smurakozi closed pull request #20235: [Spark-22887][ML][TESTS][WIP] ML test for StructuredStreaming: spark.ml.fpm
URL: https://github.com/apache/spark/pull/20235
 
 
   

This is a pull request from a forked repository (a "foreign" pull request).
Because GitHub hides the original diff once such a pull request is merged,
and would not otherwise display it for a pull request from a fork, the diff
is reproduced below for the sake of provenance:

diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 87f8b9034dde8..8f36cbd8f8be5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -16,15 +16,16 @@
  */
 package org.apache.spark.ml.fpm
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
-class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with 
DefaultReadWriteTest {
+class FPGrowthSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   @transient var dataset: Dataset[_] = _
 
@@ -34,41 +35,57 @@ class FPGrowthSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
   }
 
   test("FPGrowth fit and transform with different data types") {
-    Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt 
=>
-      val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
-      val model = new FPGrowth().setMinSupport(0.5).fit(data)
-      val generatedRules = model.setMinConfidence(0.5).associationRules
-      val expectedRules = spark.createDataFrame(Seq(
-        (Array("2"), Array("1"), 1.0),
-        (Array("1"), Array("2"), 0.75)
-      )).toDF("antecedent", "consequent", "confidence")
-        .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
-        .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
-      assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
-        generatedRules.sort("antecedent").rdd.collect()))
-
-      val transformed = model.transform(data)
-      val expectedTransformed = spark.createDataFrame(Seq(
-        (0, Array("1", "2"), Array.emptyIntArray),
-        (0, Array("1", "2"), Array.emptyIntArray),
-        (0, Array("1", "2"), Array.emptyIntArray),
-        (0, Array("1", "3"), Array(2))
-      )).toDF("id", "items", "prediction")
-        .withColumn("items", col("items").cast(ArrayType(dt)))
-        .withColumn("prediction", col("prediction").cast(ArrayType(dt)))
-      assert(expectedTransformed.collect().toSet.equals(
-        transformed.collect().toSet))
+      class DataTypeWithEncoder[A](val dataType: DataType)
+                                  (implicit val encoder: Encoder[(Array[A], 
Array[A])])
+
+      Array(
+        new DataTypeWithEncoder[Int](IntegerType),
+        new DataTypeWithEncoder[String](StringType),
+        new DataTypeWithEncoder[Short](ShortType),
+        new DataTypeWithEncoder[Long](LongType)
+        // , new DataTypeWithEncoder[Byte](ByteType)
+        // TODO: using ByteType produces error, as Array[Byte] is handled as 
Binary
+        // cannot resolve 'CAST(`items` AS BINARY)' due to data type mismatch:
+        // cannot cast array<tinyint> to binary;
+      ).foreach { dt => {
+        val data = dataset.withColumn("items", 
col("items").cast(ArrayType(dt.dataType)))
+        val model = new FPGrowth().setMinSupport(0.5).fit(data)
+        val generatedRules = model.setMinConfidence(0.5).associationRules
+        val expectedRules = Seq(
+          (Array("2"), Array("1"), 1.0),
+          (Array("1"), Array("2"), 0.75)
+        ).toDF("antecedent", "consequent", "confidence")
+          .withColumn("antecedent", 
col("antecedent").cast(ArrayType(dt.dataType)))
+          .withColumn("consequent", 
col("consequent").cast(ArrayType(dt.dataType)))
+        assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
+          generatedRules.sort("antecedent").rdd.collect()))
+
+        val expectedTransformed = Seq(
+          (Array("1", "2"), Array.emptyIntArray),
+          (Array("1", "2"), Array.emptyIntArray),
+          (Array("1", "2"), Array.emptyIntArray),
+          (Array("1", "3"), Array(2))
+        ).toDF("items", "expected")
+          .withColumn("items", col("items").cast(ArrayType(dt.dataType)))
+          .withColumn("expected", col("expected").cast(ArrayType(dt.dataType)))
+
+        testTransformer(expectedTransformed, model,
+          "expected", "prediction") {
+          case Row(expected, prediction) => assert(expected === prediction,
+            s"Expected $expected but found $prediction for data type $dt")
+        }(dt.encoder)
+      }
     }
   }
 
   test("FPGrowth getFreqItems") {
     val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
-    val expectedFreq = spark.createDataFrame(Seq(
+    val expectedFreq = Seq(
       (Array("1"), 4L),
       (Array("2"), 3L),
       (Array("1", "2"), 3L),
       (Array("2", "1"), 3L) // duplicate as the items sequence is not 
guaranteed
-    )).toDF("items", "expectedFreq")
+    ).toDF("items", "expectedFreq")
     val freqItems = model.freqItemsets
 
     val checkDF = freqItems.join(expectedFreq, "items")
@@ -76,44 +93,64 @@ class FPGrowthSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
   }
 
   test("FPGrowth getFreqItems with Null") {
-    val df = spark.createDataFrame(Seq(
+    val df = Seq(
       (1, Array("1", "2", "3", "5")),
       (2, Array("1", "2", "3", "4")),
       (3, null.asInstanceOf[Array[String]])
-    )).toDF("id", "items")
+    ).toDF("id", "items")
     val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
-    val prediction = model.transform(df)
-    
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
+    testTransformerByGlobalCheckFunc[(Int, Array[String])](df, model, "id", 
"prediction") {
+      rows => {
+        val predictionForId3 = rows.filter(_.getAs[Int]("id") == 3)
+          .map(_.getAs[Seq[String]]("prediction"))
+        assert(Seq(Seq.empty) === predictionForId3,
+          s"Expected empty prediction for id 3, got $predictionForId3")
+      }
+    }
   }
 
   test("FPGrowth prediction should not contain duplicates") {
     // This should generate rule 1 -> 3, 2 -> 3
-    val dataset = spark.createDataFrame(Seq(
+    val dataset = Seq(
       Array("1", "3"),
       Array("2", "3")
-    ).map(Tuple1(_))).toDF("items")
+    ).toDF("items")
     val model = new FPGrowth().fit(dataset)
 
-    val prediction = model.transform(
-      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
-    ).first().getAs[Seq[String]]("prediction")
+    val df = Seq(Array("1", "2")).toDF("items")
 
-    assert(prediction === Seq("3"))
+    testTransformerByGlobalCheckFunc[(Array[String])](df, model, "prediction") 
{
+      rows => {
+        assert(1 === rows.size, s"Expected exactly 1 prediction, got $rows")
+        val predictions = rows.map(_.getAs[Seq[String]]("prediction"))
+        val expected = Seq(Seq("3"))
+        assert(expected === predictions, s"Expected $expected, got 
$predictions")
+      }
+    }
   }
 
   test("FPGrowthModel setMinConfidence should affect rules generation and 
transform") {
     val model = new 
FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
     val oldRulesNum = model.associationRules.count()
-    val oldPredict = model.transform(dataset)
+    val oldPredict = model.transform(dataset).withColumnRenamed("prediction", 
"oldPrediction")
 
     model.setMinConfidence(0.8765)
     assert(oldRulesNum > model.associationRules.count())
-    
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+    testTransformer[(Int, Array[String], Array[String])](oldPredict, model,
+      "oldPrediction", "prediction") {
+      case Row(oldPrediction, prediction) => assert(oldPrediction !== 
prediction,
+          "Change in minConfidence was expected to affect prediction but it 
remained the same")
+    }
 
     // association rules should stay the same for same minConfidence
     model.setMinConfidence(0.1)
     assert(oldRulesNum === model.associationRules.count())
-    
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+    testTransformer[(Int, Array[String], Array[String])](oldPredict, model,
+      "oldPrediction", "prediction") {
+      case Row(oldPrediction, prediction) => assert(oldPrediction === 
prediction,
+        "Changing minConfidence back to original value was expected to produce 
" +
+          s"original predictions. Expected $oldPrediction but found 
$prediction")
+    }
   }
 
   test("FPGrowth parameter check") {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the relevant pull request or comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to