Je suis curieux de savoir s'il existe quelque chose de similaire à http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html pour Apache-spark dans la dernière version 2.0.1.
Jusqu'à présent, je n'ai trouvé que https://spark.Apache.org/docs/latest/mllib-statistics.html#stratified-sampling qui ne semble pas être un bon choix pour scinder un jeu de données très déséquilibré en train/échantillons de test.
Spark prend en charge les échantillons stratifiés, comme indiqué dans https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html
df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0)
Bien que cette réponse ne soit pas spécifique à Spark, dans Apache beam, je fais ceci pour diviser le train 66% et tester 33% (juste un exemple illustratif, vous pouvez personnaliser le partition_fn ci-dessous pour qu'il soit plus sophistiqué et accepter des arguments tels que le nombre de choix de compartiments ou de biais vers quelque chose ou assurer que la randomisation est juste dans toutes les dimensions, etc.):
raw_data = p | 'Read Data' >> Read(...)
clean_data = (raw_data
| "Clean Data" >> beam.ParDo(CleanFieldsFn())
def partition_fn(element):
return random.randint(0, 2)
random_buckets = (clean_data | beam.Partition(partition_fn, 3))
clean_train_data = ((random_buckets[0], random_buckets[1])
| beam.Flatten())
clean_eval_data = random_buckets[2]
Supposons que nous ayons un jeu de données comme celui-ci:
+---+-----+
| id|label|
+---+-----+
| 0| 0.0|
| 1| 1.0|
| 2| 0.0|
| 3| 1.0|
| 4| 0.0|
| 5| 1.0|
| 6| 0.0|
| 7| 1.0|
| 8| 0.0|
| 9| 1.0|
+---+-----+
Cet ensemble de données est parfaitement équilibré, mais cette approche fonctionnera également pour les données non équilibrées.
A présent, ajoutons à ce DataFrame des informations supplémentaires qui seront utiles pour déterminer les lignes qui doivent être entraînées. Les étapes sont les suivantes:
ratio
.label
, puis classer les observations de chaque étiquette à l'aide de row_number()
.Nous nous retrouvons avec le bloc de données suivant:
+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
| 6| 0.0| 1|
| 2| 0.0| 2|
| 0| 0.0| 3|
| 4| 0.0| 4|
| 8| 0.0| 5|
| 9| 1.0| 1|
| 5| 1.0| 2|
| 3| 1.0| 3|
| 1| 1.0| 4|
| 7| 1.0| 5|
+---+-----+----------+
Remarque: les lignes sont mélangées (voir: ordre aléatoire dans la colonne id
), partitionnées par libellé (voir: la colonne label
) et classées.
Supposons que nous souhaitons faire une répartition à 80%. Dans ce cas, nous aimerions que quatre étiquettes 1.0
et quatre étiquettes 0.0
soient dirigées vers un ensemble de données de formation, et une étiquette 1.0
et une étiquette 0.0
pour un ensemble de données test. Nous avons ces informations dans la colonne row_number
; nous pouvons donc maintenant les utiliser simplement dans une fonction définie par l'utilisateur (si row_number
est inférieur ou égal à quatre, l'exemple est présenté dans le jeu d'instructions).
Après application de la fonction définie par l'utilisateur, le bloc de données résultant est le suivant:
+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
| 6| 0.0| 1| true|
| 2| 0.0| 2| true|
| 0| 0.0| 3| true|
| 4| 0.0| 4| true|
| 8| 0.0| 5| false|
| 9| 1.0| 1| true|
| 5| 1.0| 2| true|
| 3| 1.0| 3| true|
| 1| 1.0| 4| true|
| 7| 1.0| 5| false|
+---+-----+----------+----------+
Maintenant, pour obtenir les données de train/test, il faut:
val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)
Ces étapes de tri et de partitionnement peuvent être prohibitives pour des jeux de données vraiment volumineux. Je suggère donc de commencer par filtrer autant que possible le jeu de données. Le plan physique est le suivant:
== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
+- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(label#5, 200)
+- *(1) Project [id#4, label#5]
+- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
+- LocalTableScan [id#4, label#5, _nondeterministic#9
Voici un exemple de travail complet (testé avec Spark 2.3.0 et Scala 2.11.12):
import org.Apache.spark.SparkConf
import org.Apache.spark.sql.expressions.Window
import org.Apache.spark.sql.{DataFrame, Row, SparkSession}
import org.Apache.spark.sql.functions.{col, row_number, udf, Rand}
class StratifiedTrainTestSplitter {
def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
df.groupBy(label).count().createOrReplaceTempView("labelCounts")
val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
import ss.implicits._
ss.sql(query)
.select("ratioLabel", "trainExamples")
.map((r: Row) => r.getDouble(0) -> r.getLong(1))
.collect()
.toMap
}
def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
val w = Window.partitionBy(col(label)).orderBy(col(label))
val rowNumPartitioner = row_number().over(w)
val dfRowNum = df.sort(Rand).select(col("*"), rowNumPartitioner as "row_number")
dfRowNum.show()
val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)
val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))
dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
}
}
object StratifiedTrainTestSplitter {
def getDf(ss: SparkSession): DataFrame = {
val data = Seq(
(0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
)
ss.createDataFrame(data).toDF("id", "label")
}
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.config(new SparkConf().setMaster("local[1]"))
.getOrCreate()
val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)
df.cache()
df.where(col("isTrainSet") === true).show()
df.where(col("isTrainSet") === false).show()
}
}
Remarque: les étiquettes sont Double
s dans ce cas. Si vos étiquettes sont String
s, vous devrez changer de type ici et là.