J'essaie de créer une nouvelle colonne de listes dans Pyspark en utilisant une agrégation groupby sur un ensemble existant de colonnes. Un exemple de trame de données d'entrée est fourni ci-dessous:
------------------------
id | date | value
------------------------
1 |2014-01-03 | 10
1 |2014-01-04 | 5
1 |2014-01-05 | 15
1 |2014-01-06 | 20
2 |2014-02-10 | 100
2 |2014-03-11 | 500
2 |2014-04-15 | 1500
Le résultat attendu est:
id | value_list
------------------------
1 | [10, 5, 15, 20]
2 | [100, 500, 1500]
Les valeurs d'une liste sont triées par date.
J'ai essayé d'utiliser collect_list comme suit:
from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
Mais collect_list ne garantit pas l'ordre même si je trie la trame de données d'entrée par date avant agrégation.
Quelqu'un pourrait-il m'aider à faire l'agrégation en préservant l'ordre basé sur une deuxième variable (date)?
Si vous collectez à la fois des dates et des valeurs sous forme de liste, vous pouvez trier la colonne résultante en fonction de la date à l'aide de et udf
, puis conserver uniquement les valeurs dans le résultat.
import operator
import pyspark.sql.functions as F
# create list column
grouped_df = input_df.groupby("id") \
.agg(F.collect_list(F.struct("date", "value")) \
.alias("list_col"))
# define udf
def sorter(l):
res = sorted(l, key=operator.itemgetter(0))
return [item[1] for item in res]
sort_udf = F.udf(sorter)
# test
grouped_df.select("id", sort_udf("list_col") \
.alias("sorted_list")) \
.show(truncate = False)
+---+----------------+
|id |sorted_list |
+---+----------------+
|1 |[10, 5, 15, 20] |
|2 |[100, 500, 1500]|
+---+----------------+
from pyspark.sql import functions as F
from pyspark.sql import Window
w = Window.partitionBy('id').orderBy('date')
sorted_list_df = input_df.withColumn(
'sorted_list', F.collect_list('value').over(w)
)\
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
Window
les exemples fournis par les utilisateurs n'expliquent souvent pas vraiment ce qui se passe alors laissez-moi le disséquer pour vous.
Comme vous le savez, en utilisant collect_list
avec groupBy
donnera une liste de valeurs non ordonnée. En effet, selon la façon dont vos données sont partitionnées, Spark ajoutera des valeurs à votre liste dès qu'il trouvera une ligne dans le groupe. L'ordre dépend ensuite de la façon dont Spark planifie votre agrégation sur les exécuteurs.
Une fonction Window
vous permet de contrôler cette situation, en regroupant les lignes par une certaine valeur afin que vous puissiez effectuer une opération over
chacun des groupes résultants:
w = Window.partitionBy('id').orderBy('date')
partitionBy
- vous voulez des groupes/partitions de lignes avec le même id
orderBy
- vous voulez que chaque ligne du groupe soit triée par date
Une fois que vous avez défini l'étendue de votre fenêtre - "lignes avec le même id
, triées par date
" -, vous pouvez l'utiliser pour effectuer une opération sur elle, dans ce cas, un collect_list
:
F.collect_list('value').over(w)
À ce stade, vous avez créé une nouvelle colonne sorted_list
avec une liste ordonnée de valeurs, triées par date, mais vous avez toujours des lignes dupliquées par id
. Pour supprimer les lignes dupliquées que vous souhaitez groupBy
id
et conserver la valeur max
pour chaque groupe:
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
La question était pour PySpark mais pourrait être utile de l'avoir aussi pour Scala Spark.
import org.Apache.spark.sql.functions._
import org.Apache.spark.sql.{DataFrame, Row, SparkSession}
import org.Apache.spark.sql.expressions.{ Window, UserDefinedFunction}
import Java.sql.Date
import Java.time.LocalDate
val spark: SparkSession = ...
// Out test data set
val data: Seq[(Int, Date, Int)] = Seq(
(1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
(1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
(1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
(1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
(2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
(2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
(2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)
// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
.toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id| date|value|
//+---+----------+-----+
//| 1|2014-01-03| 10|
//| 1|2014-01-04| 5|
//| 1|2014-01-05| 15|
//| 1|2014-01-06| 20|
//| 2|2014-02-10| 100|
//| 2|2014-02-11| 500|
//| 2|2014-02-15| 1500|
//+---+----------+-----+
// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
.agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id| date_value|
// +---+--------------------+
// | 1|[[2014-01-03,10],...|
// | 2|[[2014-02-10,100]...|
// +---+--------------------+
// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
rows.map { case Row(date: Date, value: Int) => (date, value) }
.sortBy { case (date, value) => date }
.map { case (date, value) => value }
})
// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id| value_list|
// +---+----------------+
// | 1| [10, 5, 15, 20]|
// | 2|[100, 500, 1500]|
// +---+----------------+
val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id| date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//| 1|2014-01-03| 10| [10]|
//| 1|2014-01-04| 5| [10, 5]|
//| 1|2014-01-05| 15| [10, 5, 15]|
//| 1|2014-01-06| 20| [10, 5, 15, 20]|
//| 2|2014-02-10| 100| [100]|
//| 2|2014-02-11| 500| [100, 500]|
//| 2|2014-02-15| 1500| [100, 500, 1500]|
//+---+----------+-----+---------------------+
val r2 = sortedDf.groupBy(col("id"))
.agg(max("values_sorted_by_date").as("value_list"))
r2.show()
//+---+----------------+
//| id| value_list|
//+---+----------------+
//| 1| [10, 5, 15, 20]|
//| 2|[100, 500, 1500]|
//+---+----------------+
Pour nous assurer que le tri est effectué pour chaque identifiant, nous pouvons utiliser sortWithinPartitions:
from pyspark.sql import functions as F
ordered_df = (
input_df
.repartition(input_df.id)
.sortWithinPartitions(['date'])
)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))