J'ai des données au format suivant (RDD ou Spark DataFrame):
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
rdd = sc.parallelize([('X01',41,'US',3),
('X01',41,'UK',1),
('X01',41,'CA',2),
('X02',72,'US',4),
('X02',72,'UK',6),
('X02',72,'CA',7),
('X02',72,'XX',8)])
# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
StructField('Age', IntegerType(), True),
StructField('Country', StringType(), True),
StructField('Score', IntegerType(), True)])
df = sqlContext.createDataFrame(rdd, schema)
Ce que je voudrais faire, c'est "remodeler" les données, convertir certaines lignes du pays (en particulier les États-Unis, le Royaume-Uni et la Californie) en colonnes:
ID Age US UK CA
'X01' 41 3 1 2
'X02' 72 4 6 7
Essentiellement, j'ai besoin de quelque chose dans le sens du workflow pivot
de Python:
categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID',
columns = 'Country',
values = 'Score')
Mon ensemble de données est assez volumineux, donc je ne peux pas vraiment collect()
et ingérer les données en mémoire pour effectuer le remodelage dans Python lui-même. Existe-t-il un moyen de convertir .pivot()
dans une fonction invocable lors du mappage d'un RDD ou d'un Spark DataFrame? Toute aide serait appréciée!
Depuis Spark 1.6 vous pouvez utiliser la fonction pivot
sur GroupedData
et fournir une expression agrégée.
pivoted = (df
.groupBy("ID", "Age")
.pivot(
"Country",
['US', 'UK', 'CA']) # Optional list of levels
.sum("Score")) # alternatively you can use .agg(expr))
pivoted.show()
## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41| 3| 1| 2|
## |X02| 72| 4| 6| 7|
## +---+---+---+---+---+
Les niveaux peuvent être omis, mais s'ils sont fournis, ils peuvent à la fois améliorer les performances et servir de filtre interne.
Cette méthode est encore relativement lente mais bat certainement les données de passage manuel manuellement entre JVM et Python.
Tout d'abord, ce n'est probablement pas une bonne idée, car vous n'obtenez aucune information supplémentaire, mais vous vous liez avec un schéma fixe (c'est-à-dire que vous devez savoir combien de pays vous attendez, et bien sûr, un pays supplémentaire signifie changement de code)
Cela dit, il s'agit d'un problème SQL, qui est illustré ci-dessous. Mais au cas où vous supposeriez que ce n'est pas trop "logiciel comme" (sérieusement, j'ai entendu cela !!), alors vous pouvez vous référer à la première solution.
Solution 1:
def reshape(t):
out = []
out.append(t[0])
out.append(t[1])
for v in brc.value:
if t[2] == v:
out.append(t[3])
else:
out.append(0)
return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
if t[2] in brc.value:
return t
else:
pass
def addtup(t1,t2):
j=()
for k,v in enumerate(t1):
j=j+(t1[k]+t2[k],)
return j
def seq(tIntrm,tNext):
return addtup(tIntrm,tNext)
def comb(tP,tF):
return addtup(tP,tF)
countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
print i
Maintenant, Solution 2: Bien sûr, mieux car SQL est le bon outil pour cela
callRow = calls.map(lambda t:
Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
from (select userid,age,\
case when country='CA' then nbrCalls else 0 end ca,\
case when country='UK' then nbrCalls else 0 end uk,\
case when country='US' then nbrCalls else 0 end us,\
case when country='XX' then nbrCalls else 0 end xx \
from calls) x \
group by userid,age")
res.show()
configuration des données:
data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']
Résultat:
De la 1ère solution
(('X02', 72), (7, 6, 4, 8))
(('X01', 41), (2, 1, 3, 0))
De la 2ème solution:
root |-- age: long (nullable = true)
|-- country: string (nullable = true)
|-- nbrCalls: long (nullable = true)
|-- userid: string (nullable = true)
userid age ca uk us xx
X02 72 7 6 4 8
X01 41 2 1 3 0
Veuillez me faire savoir si cela fonctionne ou non :)
Meilleur Ayan
Voici une approche native Spark qui ne câble pas les noms des colonnes. Elle est basée sur aggregateByKey
, et utilise un dictionnaire pour collecter les colonnes qui apparaissent pour chaque clé. Ensuite, nous rassemblons tous les noms de colonnes pour créer la trame de données finale. [La version précédente utilisait jsonRDD après avoir émis un dictionnaire pour chaque enregistrement, mais c'est plus efficace.] Restreindre à une liste spécifique de colonnes, ou exclure celles comme XX
serait une modification facile.
Les performances semblent bonnes même sur des tables assez grandes. J'utilise une variation qui compte le nombre de fois que chacun d'un nombre variable d'événements se produit pour chaque ID, générant une colonne par type d'événement. Le code est fondamentalement le même sauf qu'il utilise un collections.Counter au lieu d'un dict dans le seqFn
pour compter les occurrences.
from pyspark.sql.types import *
rdd = sc.parallelize([('X01',41,'US',3),
('X01',41,'UK',1),
('X01',41,'CA',2),
('X02',72,'US',4),
('X02',72,'UK',6),
('X02',72,'CA',7),
('X02',72,'XX',8)])
schema = StructType([StructField('ID', StringType(), True),
StructField('Age', IntegerType(), True),
StructField('Country', StringType(), True),
StructField('Score', IntegerType(), True)])
df = sqlCtx.createDataFrame(rdd, schema)
def seqPivot(u, v):
if not u:
u = {}
u[v.Country] = v.Score
return u
def cmbPivot(u1, u2):
u1.update(u2)
return u1
pivot = (
df
.rdd
.keyBy(lambda row: row.ID)
.aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
pivot
.values()
.map(lambda u: set(u.keys()))
.reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
pivot
.map(lambda (k, u): [k] + [u.get(c) for c in columns]),
schema=StructType(
[StructField('ID', StringType())] +
[StructField(c, IntegerType()) for c in columns]
)
)
result.show()
Produit:
ID CA UK US XX
X02 7 6 4 8
X01 2 1 3 null
Donc, tout d'abord, j'ai dû apporter cette correction à votre RDD (qui correspond à votre sortie réelle):
rdd = sc.parallelize([('X01',41,'US',3),
('X01',41,'UK',1),
('X01',41,'CA',2),
('X02',72,'US',4),
('X02',72,'UK',6),
('X02',72,'CA',7),
('X02',72,'XX',8)])
Une fois que j'ai fait cette correction, cela a fait l'affaire:
df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
$"ID" === $"usID" and $"C1" === "US"
)
.join(
df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
$"ID" === $"ukID" and $"C2" === "UK"
)
.join(
df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"),
$"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")
Certainement pas aussi élégant que votre pivot.
Juste quelques commentaires sur la réponse très utile de patricksurry:
Voici le code légèrement modifié:
from pyspark.sql.types import *
rdd = sc.parallelize([('X01',41,'US',3),
('X01',41,'UK',1),
('X01',41,'CA',2),
('X02',72,'US',4),
('X02',72,'UK',6),
('X02',72,'CA',7),
('X02',72,'XX',8)])
schema = StructType([StructField('ID', StringType(), True),
StructField('Age', IntegerType(), True),
StructField('Country', StringType(), True),
StructField('Score', IntegerType(), True)])
df = sqlCtx.createDataFrame(rdd, schema)
# u is a dictionarie
# v is a Row
def seqPivot(u, v):
if not u:
u = {}
u[v.Country] = v.Score
# In the original posting the Age column was not specified
u["Age"] = v.Age
return u
# u1
# u2
def cmbPivot(u1, u2):
u1.update(u2)
return u1
pivot = (
rdd
.map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2], Score=row[3]))
.keyBy(lambda row: row.ID)
.aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
pivot
.values()
.map(lambda u: set(u.keys()))
.reduce(lambda s,t: s.union(t))
)
columns_ord = sorted(columns)
result = sqlCtx.createDataFrame(
pivot
.map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
schema=StructType(
[StructField('ID', StringType())] +
[StructField(c, IntegerType()) for c in columns_ord]
)
)
print result.show()
Enfin, la sortie doit être
+---+---+---+---+---+----+
| ID|Age| CA| UK| US| XX|
+---+---+---+---+---+----+
|X02| 72| 7| 6| 4| 8|
|X01| 41| 2| 1| 3|null|
+---+---+---+---+---+----+
Il y a un JIRA dans Hive pour PIVOT pour le faire en natif, sans une énorme instruction CASE pour chaque valeur:
https://issues.Apache.org/jira/browse/Hive-3776
Veuillez voter pour cette JIRA afin qu'elle soit mise en œuvre plus tôt. Une fois qu'il est dans Hive SQL, Spark ne manque généralement pas trop derrière et il sera finalement implémenté dans Spark également).