web-dev-qa-db-fra.com

Spark SQL: applique des fonctions d'agrégation à une liste de colonnes

Existe-t-il un moyen d'appliquer une fonction d'agrégat à toutes les colonnes (ou à une liste) d'un dataframe, lors de l'exécution d'un groupBy? En d'autres termes, y a-t-il un moyen d'éviter de le faire pour chaque colonne:

df.groupBy("col1")
  .agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...)
55
lilloraffa

Il existe plusieurs façons d'appliquer des fonctions d'agrégation à plusieurs colonnes.

La classe GroupedData fournit un certain nombre de méthodes pour les fonctions les plus courantes, notamment count, max, min, mean et sum, qui peuvent être utilisé directement comme suit:

  • Python:

    df = sqlContext.createDataFrame(
        [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)],
        ("col1", "col2", "col3"))
    
    df.groupBy("col1").sum()
    
    ## +----+---------+-----------------+---------+
    ## |col1|sum(col1)|        sum(col2)|sum(col3)|
    ## +----+---------+-----------------+---------+
    ## | 1.0|      2.0|              0.8|      1.0|
    ## |-1.0|     -2.0|6.199999999999999|      0.7|
    ## +----+---------+-----------------+---------+
    
  • Scala

    val df = sc.parallelize(Seq(
      (1.0, 0.3, 1.0), (1.0, 0.5, 0.0),
      (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2))
    ).toDF("col1", "col2", "col3")
    
    df.groupBy($"col1").min().show
    
    // +----+---------+---------+---------+
    // |col1|min(col1)|min(col2)|min(col3)|
    // +----+---------+---------+---------+
    // | 1.0|      1.0|      0.3|      0.0|
    // |-1.0|     -1.0|      0.6|      0.2|
    // +----+---------+---------+---------+
    

Vous pouvez éventuellement passer une liste de colonnes à agréger

df.groupBy("col1").sum("col2", "col3")

Vous pouvez également transmettre dictionnaire/mappe avec des colonnes aux touches et fonctions en tant que valeurs:

  • Python

    exprs = {x: "sum" for x in df.columns}
    df.groupBy("col1").agg(exprs).show()
    
    ## +----+---------+
    ## |col1|avg(col3)|
    ## +----+---------+
    ## | 1.0|      0.5|
    ## |-1.0|     0.35|
    ## +----+---------+
    
  • Scala

    val exprs = df.columns.map((_ -> "mean")).toMap
    df.groupBy($"col1").agg(exprs).show()
    
    // +----+---------+------------------+---------+
    // |col1|avg(col1)|         avg(col2)|avg(col3)|
    // +----+---------+------------------+---------+
    // | 1.0|      1.0|               0.4|      0.5|
    // |-1.0|     -1.0|3.0999999999999996|     0.35|
    // +----+---------+------------------+---------+
    

Enfin, vous pouvez utiliser varargs:

  • Python

    from pyspark.sql.functions import min
    
    exprs = [min(x) for x in df.columns]
    df.groupBy("col1").agg(*exprs).show()
    
  • Scala

    import org.Apache.spark.sql.functions.sum
    
    val exprs = df.columns.map(sum(_))
    df.groupBy($"col1").agg(exprs.head, exprs.tail: _*)
    

Il existe un autre moyen d’obtenir un effet similaire, mais cela devrait suffire amplement la plupart du temps.

Voir également:

98
zero323

Un autre exemple du même concept - mais disons - vous avez 2 colonnes différentes - et vous souhaitez appliquer différentes fonctions agg à chacune d’elles, c.-à-d.

f.groupBy("col1").agg(sum("col2").alias("col2"), avg("col3").alias("col3"), ...)

Voici le moyen d'y parvenir - bien que je ne sache pas encore comment ajouter l'alias dans ce cas

Voir l'exemple ci-dessous - Utiliser des cartes

val Claim1 = StructType(Seq(StructField("pid", StringType, true),StructField("diag1", StringType, true),StructField("diag2", StringType, true), StructField("allowed", IntegerType, true), StructField("allowed1", IntegerType, true)))
val claimsData1 = Seq(("PID1", "diag1", "diag2", 100, 200), ("PID1", "diag2", "diag3", 300, 600), ("PID1", "diag1", "diag5", 340, 680), ("PID2", "diag3", "diag4", 245, 490), ("PID2", "diag2", "diag1", 124, 248))

val claimRDD1 = sc.parallelize(claimsData1)
val claimRDDRow1 = claimRDD1.map(p => Row(p._1, p._2, p._3, p._4, p._5))
val claimRDD2DF1 = sqlContext.createDataFrame(claimRDDRow1, Claim1)

val l = List("allowed", "allowed1")
val exprs = l.map((_ -> "sum")).toMap
claimRDD2DF1.groupBy("pid").agg(exprs) show false
val exprs = Map("allowed" -> "sum", "allowed1" -> "avg")

claimRDD2DF1.groupBy("pid").agg(exprs) show false
16
Sumit Pal