J'ai un DataFrame généré comme suit:
df.groupBy($"Hour", $"Category")
.agg(sum($"value") as "TotalValue")
.sort($"Hour".asc, $"TotalValue".desc))
Les résultats ressemblent à:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 0| cat13| 22.1|
| 0| cat95| 19.6|
| 0| cat105| 1.3|
| 1| cat67| 28.5|
| 1| cat4| 26.8|
| 1| cat13| 12.6|
| 1| cat23| 5.3|
| 2| cat56| 39.6|
| 2| cat40| 29.7|
| 2| cat187| 27.9|
| 2| cat68| 9.8|
| 3| cat8| 35.6|
| ...| ....| ....|
+----+--------+----------+
Comme vous pouvez le constater, le DataFrame est classé par Hour
dans un ordre croissant, puis par TotalValue
par ordre décroissant.
J'aimerais sélectionner la rangée du haut de chaque groupe, c'est-à-dire.
Donc, le résultat souhaité serait:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
| ...| ...| ...|
+----+--------+----------+
Il peut être utile de pouvoir également sélectionner les N premières lignes de chaque groupe.
Toute aide est grandement appréciée.
Fonctions de la fenêtre:
Quelque chose comme ça devrait faire l'affaire:
import org.Apache.spark.sql.functions.{row_number, max, broadcast}
import org.Apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
Cette méthode sera inefficace en cas de biais important des données.
Agrégation SQL simple suivie de join
:
Sinon, vous pouvez rejoindre avec un bloc de données agrégé:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
Il conservera les valeurs en double (s'il y a plus d'une catégorie par heure avec la même valeur totale). Vous pouvez les supprimer comme suit:
dfTopByJoin
.groupBy($"hour")
.agg(
first("category").alias("category"),
first("TotalValue").alias("TotalValue"))
Utilisation de la commande sur structs
:
Astuce soignée, mais pas très bien testée, ne nécessitant ni jointure ni fonctions de fenêtre:
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
Avec API DataSet (Spark 1.6+, 2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 ou version ultérieure:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
Les deux dernières méthodes peuvent tirer parti de la combinaison côté carte et ne nécessitent pas de lecture aléatoire. La plupart du temps, elles devraient donc offrir de meilleures performances que les fonctions de fenêtre et les jointures. Celles-ci peuvent également être utilisées avec le streaming structuré en mode de sortie completed
.
N'utilisez pas:
df.orderBy(...).groupBy(...).agg(first(...), ...)
Cela peut sembler fonctionner (surtout dans le mode local
) mais il n’est pas fiable ( SPARK-16207 ). Crédits à Tzach Zohar pour reliant le problème pertinent de JIRA .
La même note s'applique à
df.orderBy(...).dropDuplicates(...)
qui utilise en interne un plan d'exécution équivalent.
Pour Spark 2.0.2 avec regroupement de plusieurs colonnes:
import org.Apache.spark.sql.functions.row_number
import org.Apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
C'est exactement la même chose que zero323 's answer mais en mode requête SQL.
En supposant que le cadre de données soit créé et enregistré en tant que
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0 |cat26 |30.9 |
//|0 |cat13 |22.1 |
//|0 |cat95 |19.6 |
//|0 |cat105 |1.3 |
//|1 |cat67 |28.5 |
//|1 |cat4 |26.8 |
//|1 |cat13 |12.6 |
//|1 |cat23 |5.3 |
//|2 |cat56 |39.6 |
//|2 |cat40 |29.7 |
//|2 |cat187 |27.9 |
//|2 |cat68 |9.8 |
//|3 |cat8 |35.6 |
//+----+--------+----------+
Fonction de la fenêtre:
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Agrégation SQL simple suivie d'une jointure:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
"(select Hour, Category, TotalValue from table tmp1 " +
"join " +
"(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
"on " +
"tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
"group by tmp3.Hour")
.show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Utilisation de la commande sur les structures:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Méthode des ensembles de données et ne pas faire s sont identiques à ceux de la réponse d'origine
La solution ci-dessous ne fait qu'un groupe et extrait les lignes de votre cadre de données contenant la valeur maxValue en une seule fois. Pas besoin d'autres jointures, ni de Windows.
import org.Apache.spark.sql.Row
import org.Apache.spark.sql.catalyst.encoders.RowEncoder
import org.Apache.spark.sql.DataFrame
//df is the dataframe with Day, Category, TotalValue
implicit val dfEnc = RowEncoder(df.schema)
val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
Ici tu peux faire comme ça -
val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")
data.withColumnRenamed("_1","Hour").show
Si le dataframe doit être groupé en plusieurs colonnes, cela peut aider
val keys = List("Hour", "Category");
val selectFirstValueOfNoneGroupedColumns =
df.columns
.filterNot(keys.toSet)
.map(_ -> "first").toMap
val grouped =
df.groupBy(keys.head, keys.tail: _*)
.agg(selectFirstValueOfNoneGroupedColumns)
J'espère que cela aide quelqu'un avec un problème similaire
Le modèle est groupe par clés => faire quelque chose à chaque groupe, par exemple. réduire => retourner au dataframe
Je pensais que l'abstraction Dataframe était un peu lourde dans ce cas, alors j'ai utilisé la fonctionnalité RDD
val rdd: RDD[Row] = originalDf
.rdd
.groupBy(row => row.getAs[String]("grouping_row"))
.map(iterableTuple => {
iterableTuple._2.reduce(reduceFunction)
})
val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
Une bonne façon de faire cela avec l’API Dataframe utilise la logique Argmax comme si
val df = Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
df.groupBy($"Hour")
.agg(max(struct($"TotalValue", $"Category")).as("argmax"))
.select($"Hour", $"argmax.*").show
+----+----------+--------+
|Hour|TotalValue|Category|
+----+----------+--------+
| 1| 28.5| cat67|
| 3| 35.6| cat8|
| 2| 39.6| cat56|
| 0| 30.9| cat26|
+----+----------+--------+