Github user gaborgsomogyi commented on a diff in the pull request:
https://github.com/apache/spark/pull/20235#discussion_r163241569
--- Diff: mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
---
@@ -34,86 +35,122 @@ class FPGrowthSuite extends SparkFunSuite with
MLlibTestSparkContext with Defaul
}
test("FPGrowth fit and transform with different data types") {
- Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach
{ dt =>
- val data = dataset.withColumn("items",
col("items").cast(ArrayType(dt)))
- val model = new FPGrowth().setMinSupport(0.5).fit(data)
- val generatedRules = model.setMinConfidence(0.5).associationRules
- val expectedRules = spark.createDataFrame(Seq(
- (Array("2"), Array("1"), 1.0),
- (Array("1"), Array("2"), 0.75)
- )).toDF("antecedent", "consequent", "confidence")
- .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
- .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
- assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
- generatedRules.sort("antecedent").rdd.collect()))
-
- val transformed = model.transform(data)
- val expectedTransformed = spark.createDataFrame(Seq(
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "2"), Array.emptyIntArray),
- (0, Array("1", "3"), Array(2))
- )).toDF("id", "items", "prediction")
- .withColumn("items", col("items").cast(ArrayType(dt)))
- .withColumn("prediction", col("prediction").cast(ArrayType(dt)))
- assert(expectedTransformed.collect().toSet.equals(
- transformed.collect().toSet))
+ class DataTypeWithEncoder[A](val a: DataType)
+ (implicit val encoder: Encoder[(Int,
Array[A], Array[A])])
+
+ Array(
+ new DataTypeWithEncoder[Int](IntegerType),
+ new DataTypeWithEncoder[String](StringType),
+ new DataTypeWithEncoder[Short](ShortType),
+ new DataTypeWithEncoder[Long](LongType)
+ // , new DataTypeWithEncoder[Byte](ByteType)
+ // TODO: using ByteType produces error, as Array[Byte] is handled
as Binary
+ // cannot resolve 'CAST(`items` AS BINARY)' due to data type
mismatch:
+ // cannot cast array<tinyint> to binary;
+ ).foreach { dt => {
+ val data = dataset.withColumn("items",
col("items").cast(ArrayType(dt.a)))
+ val model = new FPGrowth().setMinSupport(0.5).fit(data)
+ val generatedRules = model.setMinConfidence(0.5).associationRules
+ val expectedRules = Seq(
+ (Array("2"), Array("1"), 1.0),
+ (Array("1"), Array("2"), 0.75)
+ ).toDF("antecedent", "consequent", "confidence")
+ .withColumn("antecedent",
col("antecedent").cast(ArrayType(dt.a)))
+ .withColumn("consequent",
col("consequent").cast(ArrayType(dt.a)))
+ assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
+ generatedRules.sort("antecedent").rdd.collect()))
+
+ val expectedTransformed = Seq(
+ (0, Array("1", "2"), Array.emptyIntArray),
+ (0, Array("1", "2"), Array.emptyIntArray),
+ (0, Array("1", "2"), Array.emptyIntArray),
+ (0, Array("1", "3"), Array(2))
+ ).toDF("id", "items", "expected")
+ .withColumn("items", col("items").cast(ArrayType(dt.a)))
+ .withColumn("expected", col("expected").cast(ArrayType(dt.a)))
+
+ testTransformer(expectedTransformed, model,
+ "expected", "prediction") {
+ case Row(expected, prediction) => assert(expected === prediction,
+ s"Expected $expected but found $prediction for data type $dt")
+ }(dt.encoder)
+ }
}
}
test("FPGrowth getFreqItems") {
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
- val expectedFreq = spark.createDataFrame(Seq(
+ val expectedFreq = Seq(
(Array("1"), 4L),
(Array("2"), 3L),
(Array("1", "2"), 3L),
(Array("2", "1"), 3L) // duplicate as the items sequence is not
guaranteed
- )).toDF("items", "expectedFreq")
+ ).toDF("items", "expectedFreq")
val freqItems = model.freqItemsets
val checkDF = freqItems.join(expectedFreq, "items")
assert(checkDF.count() == 3 && checkDF.filter(col("freq") ===
col("expectedFreq")).count() == 3)
}
test("FPGrowth getFreqItems with Null") {
- val df = spark.createDataFrame(Seq(
+ val df = Seq(
(1, Array("1", "2", "3", "5")),
(2, Array("1", "2", "3", "4")),
(3, null.asInstanceOf[Array[String]])
- )).toDF("id", "items")
+ ).toDF("id", "items")
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
- val prediction = model.transform(df)
-
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
+ testTransformerByGlobalCheckFunc[(Int, Array[String])](df, model,
"id", "prediction") {
+ rows => {
+ val predictionForId3 = rows.filter(_.getAs[Int]("id") == 3)
+ .map(_.getAs[Seq[String]]("prediction"))
+ assert(Seq(Seq.empty) === predictionForId3,
+ s"Expected empty prediction for id 3, got $predictionForId3")
+ }
+ }
}
test("FPGrowth prediction should not contain duplicates") {
// This should generate rule 1 -> 3, 2 -> 3
- val dataset = spark.createDataFrame(Seq(
+ val dataset = Seq(
Array("1", "3"),
Array("2", "3")
- ).map(Tuple1(_))).toDF("items")
+ ).toDF("items")
val model = new FPGrowth().fit(dataset)
- val prediction = model.transform(
- spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
- ).first().getAs[Seq[String]]("prediction")
+ val df = Seq(Array("1", "2")).toDF("items")
- assert(prediction === Seq("3"))
+ testTransformerByGlobalCheckFunc[(Array[String])](df, model,
"prediction") {
+ rows => {
+ assert(1 === rows.size, s"Expected exactly 1 prediction, got
$rows")
+ val predictions = rows.map(_.getAs[Seq[String]]("prediction"))
+ val expected = Seq(Seq("3"))
+ assert(expected === predictions, s"Expected $expected, got
$predictions")
+ }
+ }
}
test("FPGrowthModel setMinConfidence should affect rules generation and
transform") {
val model = new
FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
val oldRulesNum = model.associationRules.count()
- val oldPredict = model.transform(dataset)
+ val oldPredict =
model.transform(dataset).withColumnRenamed("prediction", "oldPrediction")
--- End diff --
Other places named it `oldPrediction`; for consistency, this one could then be renamed as well.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]