web-dev-qa-db-fra.com

Comment faire pivoter DataFrame?

Je commence à utiliser Spark DataFrames et je dois pouvoir faire pivoter les données pour créer plusieurs colonnes sur une colonne avec plusieurs lignes. Il existe des fonctionnalités intégrées dans Scalding et je crois aux Pandas en Python, mais je ne trouve rien pour le nouveau Spark Dataframe.

Je suppose que je peux écrire une fonction personnalisée pour ce faire mais je ne sais même pas comment commencer, d’autant plus que je suis novice avec Spark. Tout le monde sait comment faire cela avec des fonctionnalités intégrées ou des suggestions pour écrire quelque chose en Scala, cela est grandement apprécié. 

38
J Calbreath

Comme mentionné by David Anderson Spark fournit la fonction pivot depuis la version 1.6. La syntaxe générale ressemble à ceci:

df
  .groupBy(grouping_columns)
  .pivot(pivot_column, [values]) 
  .agg(aggregate_expressions)

Exemples d'utilisation utilisant les formats nycflights13 et csv:

Python:

from pyspark.sql.functions import avg

flights = (sqlContext
    .read
    .format("csv")
    .options(inferSchema="true", header="true")
    .load("flights.csv")
    .na.drop())

flights.registerTempTable("flights")
sqlContext.cacheTable("flights")

gexprs = ("Origin", "dest", "carrier")
aggexpr = avg("arr_delay")

flights.count()
## 336776

%timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop

Scala:

val flights = sqlContext
  .read
  .format("csv")
  .options(Map("inferSchema" -> "true", "header" -> "true"))
  .load("flights.csv")

flights
  .groupBy($"Origin", $"dest", $"carrier")
  .pivot("hour")
  .agg(avg($"arr_delay"))

Java:

import static org.Apache.spark.sql.functions.*;
import org.Apache.spark.sql.*;

Dataset<Row> df = spark.read().format("csv")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("flights.csv");

df.groupBy(col("Origin"), col("dest"), col("carrier"))
        .pivot("hour")
        .agg(avg(col("arr_delay")));

R/SparkR:

library(magrittr)

flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

flights %>% 
  groupBy("Origin", "dest", "carrier") %>% 
  pivot("hour") %>% 
  agg(avg(column("arr_delay")))

R/sparklyr

library(dplyr)

flights <- spark_read_csv(sc, "flights", "flights.csv")

avg.arr.delay <- function(gdf) {
   expr <- invoke_static(
      sc,
      "org.Apache.spark.sql.functions",
      "avg",
      "arr_delay"
    )
    gdf %>% invoke("agg", expr, list())
}

flights %>% 
  sdf_pivot(Origin + dest + carrier ~  hour, fun.aggregate=avg.arr.delay)

SQL:

CREATE TEMPORARY VIEW flights 
USING csv 
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;

 SELECT * FROM (
   SELECT Origin, dest, carrier, arr_delay, hour FROM flights
 ) PIVOT (
   avg(arr_delay)
   FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
 );

Exemple de données:

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","Origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00

Considérations sur les performances:

De manière générale, le pivotement est une opération coûteuse. 

Questions connexes:

52
zero323

J'ai surmonté cela en écrivant une boucle for pour créer dynamiquement une requête SQL. Dis que j'ai:

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175

et je veux:

id  US  UK   Can
1   50  100  125
2   75  150  175

Je peux créer une liste avec la valeur que je souhaite faire pivoter, puis créer une chaîne contenant la requête SQL dont j'ai besoin.

val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1

var query = "select *, "
for (i <- 0 to numCountries-1) {
  query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
}
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"

myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)

Je peux créer une requête similaire pour ensuite effectuer l'agrégation. Ce n'est pas une solution très élégante, mais cela fonctionne et est flexible pour n'importe quelle liste de valeurs, qui peut également être transmise en tant qu'argument lorsque votre code est appelé.

13
J Calbreath

Un opérateur pivot a été ajouté à l'API de base de données Spark et fait partie de Spark 1.6.

Voir https://github.com/Apache/spark/pull/7841 pour plus de détails.

9
David Anderson

J'ai résolu un problème similaire en utilisant des cadres de données en procédant comme suit:

Créez des colonnes pour tous vos pays, avec 'valeur' ​​comme valeur:

import org.Apache.spark.sql.functions._
val countries = List("US", "UK", "Can")
val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) =>
  if(countryToCheck == countryInRow) value else 0
}
val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) }
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")

Votre dataframe 'dfWithCountries' ressemblera à ceci:

+--+--+---+---+
|id|US| UK|Can|
+--+--+---+---+
| 1|50|  0|  0|
| 1| 0|100|  0|
| 1| 0|  0|125|
| 2|75|  0|  0|
| 2| 0|150|  0|
| 2| 0|  0|175|
+--+--+---+---+

Maintenant, vous pouvez additionner toutes les valeurs pour le résultat souhaité:

dfWithCountries.groupBy("id").sum(countries: _*).show

Résultat:

+--+-------+-------+--------+
|id|SUM(US)|SUM(UK)|SUM(Can)|
+--+-------+-------+--------+
| 1|     50|    100|     125|
| 2|     75|    150|     175|
+--+-------+-------+--------+

Ce n'est pas une solution très élégante cependant. Je devais créer une chaîne de fonctions à ajouter dans toutes les colonnes. De plus, si j'ai beaucoup de pays, je vais étendre mon jeu de données temporaire à un très large ensemble avec beaucoup de zéros.

5
Al M

Au départ, j'ai adopté la solution d'Al M. Plus tard, pris la même pensée et réécrivit cette fonction en tant que fonction de transposition.

Cette méthode transpose toutes les lignes df en colonnes de tout format de données en utilisant des colonnes clé et valeur.

pour l'entrée csv

id,tag,value
1,US,50a
1,UK,100
1,Can,125
2,US,75
2,UK,150
2,Can,175

ouput

+--+---+---+---+
|id| UK| US|Can|
+--+---+---+---+
| 2|150| 75|175|
| 1|100|50a|125|
+--+---+---+---+

méthode de transposition:

def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = {

val distinctCols =   df.select(key).distinct.map { r => r(0) }.collect().toList

val rdd = df.map { row =>
(compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] },
scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any]))
}
val pairRdd = rdd.reduceByKey(_ ++ _)
val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols))
hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols)))

}

private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = {
val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) }
val array = r._1 ++ cols
Row(array: _*)
}

private  def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = {
val idSchema = idCols.map { idCol => srcSchema.apply(idCol) }
val colSchema = srcSchema.apply(distinctCols._1)
val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) }
StructType(idSchema ++ colsSchema)
}

extrait principal

import Java.util.Date
import org.Apache.spark.SparkConf
import org.Apache.spark.SparkContext
import org.Apache.spark.sql.Row
import org.Apache.spark.sql.DataFrame
import org.Apache.spark.sql.types.StructType
import org.Apache.spark.sql.Hive.HiveContext
import org.Apache.spark.sql.types.StructField


...
...
def main(args: Array[String]): Unit = {

    val sc = new SparkContext(conf)
    val sqlContext = new org.Apache.spark.sql.SQLContext(sc)
    val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true")
    .load("data.csv")
    dfdata1.show()  
    val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value")
    dfOutput.show

}
0
Jaigates