J'ai un très grand ensemble de données qui est chargé dans Hive. Il se compose d'environ 1,9 million de lignes et 1450 colonnes. J'ai besoin de déterminer la "couverture" de chacune des colonnes, c'est-à-dire la fraction de lignes qui ont des valeurs non NaN pour chaque colonne.
Voici mon code:
from pyspark import SparkContext
from pyspark.sql import HiveContext
import string as string
sc = SparkContext(appName="compute_coverages") ## Create the context
sqlContext = HiveContext(sc)
df = sqlContext.sql("select * from data_table")
nrows_tot = df.count()
covgs=sc.parallelize(df.columns)
.map(lambda x: str(x))
.map(lambda x: (x, float(df.select(x).dropna().count()) / float(nrows_tot) * 100.))
En essayant cela dans le shell pyspark, si je fais ensuite covgs.take (10), cela renvoie une pile d'erreurs assez grande. Il indique qu'il y a un problème d'enregistrement dans le fichier /usr/lib64/python2.6/pickle.py
. Ceci est la dernière partie de l'erreur:
py4j.protocol.Py4JError: An error occurred while calling o37.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.Java:333)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.Java:342)
at py4j.Gateway.invoke(Gateway.Java:252)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.Java:133)
at py4j.commands.CallCommand.execute(CallCommand.Java:79)
at py4j.GatewayConnection.run(GatewayConnection.Java:207)
at Java.lang.Thread.run(Thread.Java:745)
S'il existe une meilleure façon d'accomplir cela que celle que j'essaie, je suis ouvert aux suggestions. Je ne peux pas utiliser de pandas, cependant, car il n'est pas actuellement disponible sur le cluster sur lequel je travaille et je n'ai pas les droits pour l'installer.
Commençons par des données fictives:
from pyspark.sql import Row
row = Row("v", "x", "y", "z")
df = sc.parallelize([
row(0.0, 1, 2, 3.0), row(None, 3, 4, 5.0),
row(None, None, 6, 7.0), row(float("Nan"), 8, 9, float("NaN"))
]).toDF()
## +----+----+---+---+
## | v| x| y| z|
## +----+----+---+---+
## | 0.0| 1| 2|3.0|
## |null| 3| 4|5.0|
## |null|null| 6|7.0|
## | NaN| 8| 9|NaN|
## +----+----+---+---+
Tout ce dont vous avez besoin est une simple agrégation:
from pyspark.sql.functions import col, count, isnan, lit, sum
def count_not_null(c, nan_as_null=False):
"""Use conversion between boolean and integer
- False -> 0
- True -> 1
"""
pred = col(c).isNotNull() & (~isnan(c) if nan_as_null else lit(True))
return sum(pred.cast("integer")).alias(c)
df.agg(*[count_not_null(c) for c in df.columns]).show()
## +---+---+---+---+
## | v| x| y| z|
## +---+---+---+---+
## | 2| 3| 4| 4|
## +---+---+---+---+
ou si vous voulez traiter NaN
a NULL
:
df.agg(*[count_not_null(c, True) for c in df.columns]).show()
## +---+---+---+---+
## | v| x| y| z|
## +---+---+---+---+
## | 1| 3| 4| 3|
## +---+---+---+---
Vous pouvez également utiliser la sémantique SQL NULL
pour obtenir le même résultat sans créer de fonction personnalisée:
df.agg(*[
count(c).alias(c) # vertical (column-wise) operations in SQL ignore NULLs
for c in df.columns
]).show()
## +---+---+---+
## | x| y| z|
## +---+---+---+
## | 1| 2| 3|
## +---+---+---+
mais cela ne fonctionnera pas avec NaNs
.
Si vous préférez les fractions:
exprs = [(count_not_null(c) / count("*")).alias(c) for c in df.columns]
df.agg(*exprs).show()
## +------------------+------------------+---+
## | x| y| z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+
ou
# COUNT(*) is equivalent to COUNT(1) so NULLs won't be an issue
df.select(*[(count(c) / count("*")).alias(c) for c in df.columns]).show()
## +------------------+------------------+---+
## | x| y| z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+
Équivalent Scala:
import org.Apache.spark.sql.Column
import org.Apache.spark.sql.functions.{col, isnan, sum}
type JDouble = Java.lang.Double
val df = Seq[(JDouble, JDouble, JDouble, JDouble)](
(0.0, 1, 2, 3.0), (null, 3, 4, 5.0),
(null, null, 6, 7.0), (Java.lang.Double.NaN, 8, 9, Java.lang.Double.NaN)
).toDF()
def count_not_null(c: Column, nanAsNull: Boolean = false) = {
val pred = c.isNotNull and (if (nanAsNull) not(isnan(c)) else lit(true))
sum(pred.cast("integer"))
}
df.select(df.columns map (c => count_not_null(col(c)).alias(c)): _*).show
// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// | 2| 3| 4| 4|
// +---+---+---+---+
df.select(df.columns map (c => count_not_null(col(c), true).alias(c)): _*).show
// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// | 1| 3| 4| 3|
// +---+---+---+---+