web-dev-qa-db-fra.com

Appliquer une fonction aux données groupBy avec pyspark

J'essaie d'obtenir le nombre de mots à partir d'un fichier CSV lors du regroupement sur une autre colonne. Mon csv a trois colonnes: id, message et user_id. Je lis ceci, puis je divise le message et stocke une liste d'unigrammes:

+-----------------+--------------------+--------------------+
|               id|             message|             user_id|
+-----------------+--------------------+--------------------+
|10100720363468236|[i'm, sad, to, mi...|dceafb541a1b8e894...|
|10100718944611636|[what, does, the,...|dceafb541a1b8e894...|
|10100718890699676|[at, the, oecd, w...|dceafb541a1b8e894...|
+-----------------+--------------------+--------------------+

Ensuite, étant donné mon dataframe df, je veux regrouper par user_id, puis obtenir les décomptes pour chacun des unigrammes. Comme un premier passage simple, j'ai essayé de regrouper par user_id et obtenez la longueur du champ de message groupé:

from collections import Counter
from pyspark.sql.types import ArrayType, StringType, IntegerType
from pyspark.sql.functions import udf

df = self.session.read.csv(self.corptable, header=True,
        mode="DROPMALFORMED",)

# split my messages ....
# message is now ArrayType(StringType())

grouped = df.groupBy(df["user_id"])
counter = udf(lambda l: len(l), ArrayType(StringType()))
grouped.agg(counter(df["message"]))
print(grouped.collect())

J'obtiens l'erreur suivante:

pyspark.sql.utils.AnalysisException: "expression '`message`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;"

Vous ne savez pas comment contourner cette erreur. En général, comment appliquer une fonction à une colonne lors du regroupement d'une autre? Dois-je toujours créer une fonction définie par l'utilisateur? Très nouveau pour Spark.

Edit: Voici comment j'ai résolu cela, étant donné un tokenizer dans un fichier séparé Python:

group_field = "user_id"
message_field = "message"

context = SparkContext()
session = SparkSession\
        .builder\
        .appName("dlastk")\
        .getOrCreate()

# add tokenizer
context.addPyFile(tokenizer_path)
from tokenizer import Tokenizer
tokenizer = Tokenizer()
spark_tokenizer = udf(tokenizer.tokenize, ArrayType(StringType()))

df = session.read.csv("myFile.csv", header=True,)
df = df[group_field, message_field]

# tokenize the message field
df = df.withColumn(message_field, spark_tokenizer(df[message_field]))

# create ngrams from tokenized messages
n = 1
grouped = df.rdd.map(lambda row: (row[0], Counter([" ".join(x) for x in Zip(*[row[1][i:] for i in range(n)])]))).reduceByKey(add)

# flatten the rdd so that each row contains (group_id, ngram, count, relative frequency
flat = grouped.flatMap(lambda row: [[row[0], x,y, y/sum(row[1].values())] for x,y in row[1].items()])

# rdd -> DF
flat = flat.toDF()
flat.write.csv("myNewCSV.csv")

Les données ressemblent à:

# after read
+--------------------+--------------------+
|             user_id|             message|
+--------------------+--------------------+
|00035fb0dcfbeaa8b...|To the douchebag ...|
|00035fb0dcfbeaa8b...|   T minus 1 week...|
|00035fb0dcfbeaa8b...|Last full day of ...|
+--------------------+--------------------+

# after tokenize
+--------------------+--------------------+
|             user_id|             message|
+--------------------+--------------------+
|00035fb0dcfbeaa8b...|[to, the, doucheb...|
|00035fb0dcfbeaa8b...|[t, minus, 1, wee...|
|00035fb0dcfbeaa8b...|[last, full, day,...|
+--------------------+--------------------+

# grouped: after 1grams extracted and Counters added
[('00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', Counter({'!': 545, '.': 373, 'the': 306, '"': 225, ...

# flat: after calculating sum and relative frequency for each 1gram
[['00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', 'face', 3, 0.000320547066994337], ['00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', 'was', 26, 0.002778074580617587] ....

# after flat RDD to DF
+--------------------+---------+---+--------------------+
|                  _1|       _2| _3|                  _4|
+--------------------+---------+---+--------------------+
|00035fb0dcfbeaa8b...|     face|  3| 3.20547066994337E-4|
|00035fb0dcfbeaa8b...|      was| 26|0.002778074580617587|
|00035fb0dcfbeaa8b...|      how| 22|0.002350678491291...|
+--------------------+---------+---+--------------------+
9
Sal

Une approche naturelle pourrait être de regrouper les mots dans une liste, puis d'utiliser la fonction python Counter() pour générer le nombre de mots. Pour les deux étapes, nous utiliserons udf. Tout d'abord, celui qui aplatira la liste imbriquée résultant de collect_list() de plusieurs tableaux:

unpack_udf = udf(
    lambda l: [item for sublist in l for item in sublist]
)

Deuxièmement, celui qui génère les tuples de comptage de mots, ou dans notre cas struct:

from pyspark.sql.types import *
from collections import Counter

# We need to specify the schema of the return object
schema_count = ArrayType(StructType([
    StructField("Word", StringType(), False),
    StructField("count", IntegerType(), False)
]))

count_udf = udf(
    lambda s: Counter(s).most_common(), 
    schema_count
)

Mettre tous ensemble:

from pyspark.sql.functions import collect_list

(df.groupBy("id")
 .agg(collect_list("message").alias("message"))
 .withColumn("message", unpack_udf("message"))
 .withColumn("message", count_udf("message"))).show(truncate = False)
+-----------------+------------------------------------------------------+
|id               |message                                               |
+-----------------+------------------------------------------------------+
|10100718890699676|[[oecd,1], [the,1], [with,1], [at,1]]                 |
|10100720363468236|[[what,3], [me,1], [sad,1], [to,1], [does,1], [the,1]]|
+-----------------+------------------------------------------------------+

Données:

df = sc.parallelize([(10100720363468236,["what", "sad", "to", "me"]),
                     (10100720363468236,["what", "what", "does", "the"]),
                     (10100718890699676,["at", "the", "oecd", "with"])]).toDF(["id", "message"])
14
mtoto

Essayer:

from  pyspark.sql.functions import *

df.withColumn("Word", explode("message")) \
  .groupBy("user_id", "Word").count() \
  .groupBy("user_id") \
  .agg(collect_list(struct("Word", "count")))
1
user6022341