web-dev-qa-db-fra.com

Colonne GroupBy et lignes de filtre avec valeur maximale dans Pyspark

Je suis presque certain que cela a déjà été demandé, mais ne recherche via stackoverflow n'a pas répondu à ma question. Pas un doublon de [2] puisque je veux la valeur maximale, pas l'article le plus fréquent. Je suis nouveau sur pyspark et j'essaie de faire quelque chose de très simple: je veux grouper la colonne "A" et ne conserver que la ligne de chaque groupe qui a la valeur maximale dans la colonne "B". Comme ça:

df_cleaned = df.groupBy("A").agg(F.max("B"))

Malheureusement, cela jette toutes les autres colonnes - df_cleaned ne contient que les colonnes "A" et la valeur maximale de B. Comment puis-je conserver les lignes? ("A", "B", "C" ...)

16
Thomas

Vous pouvez le faire sans udf en utilisant un Window.

Prenons l'exemple suivant:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Créez un Window à partitionner par colonne A et utilisez-le pour calculer le maximum de chaque groupe. Filtrez ensuite les lignes de sorte que la valeur de la colonne B soit égale au maximum.

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

Ou de manière équivalente en utilisant pyspark-sql:

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+
21
pault

Une autre approche possible consiste à appliquer la jonction de la trame de données avec elle-même en spécifiant "leftsemi". Ce type de jointure inclut toutes les colonnes de la trame de données sur le côté gauche et aucune colonne sur le côté droit.

Par exemple:

import pyspark.sql.functions as f
data = [
    ('a', 5, 'c'),
    ('a', 8, 'd'),
    ('a', 7, 'e'),
    ('b', 1, 'f'),
    ('b', 3, 'g')
]
df = sqlContext.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

La valeur maximale de la colonne B par la colonne A peut être sélectionnée en faisant:

df.groupBy('A').agg(f.max('B')
+---+---+
|  A|  B|
+---+---+
|  a|  8|
|  b|  3|
+---+---+

En utilisant cette expression comme un côté droit dans une semi-jointure gauche et en renommant la colonne obtenue max(B) en son nom d'origine B, nous pouvons obtenir le résultat requis:

df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').show()
+---+---+---+
|  B|  A|  C|
+---+---+---+
|  3|  b|  g|
|  8|  a|  d|
+---+---+---+

Le plan physique derrière cette solution et celui de la réponse acceptée sont différents et il n'est toujours pas clair pour moi lequel fonctionnera mieux sur les grandes trames de données.

Le même résultat peut être obtenu en utilisant spark syntaxe SQL faisant:

df.registerTempTable('table')
q = '''SELECT *
FROM table a LEFT SEMI
JOIN (
    SELECT 
        A,
        max(B) as max_B
    FROM table
    GROUP BY A
    ) t
ON a.A=t.A AND a.B=t.max_B
'''
sqlContext.sql(q).show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+
4
ndricca