Github user gaborgsomogyi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20235#discussion_r163241569
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala 
---
    @@ -34,86 +35,122 @@ class FPGrowthSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
       }
     
       test("FPGrowth fit and transform with different data types") {
    -    Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach 
{ dt =>
    -      val data = dataset.withColumn("items", 
col("items").cast(ArrayType(dt)))
    -      val model = new FPGrowth().setMinSupport(0.5).fit(data)
    -      val generatedRules = model.setMinConfidence(0.5).associationRules
    -      val expectedRules = spark.createDataFrame(Seq(
    -        (Array("2"), Array("1"), 1.0),
    -        (Array("1"), Array("2"), 0.75)
    -      )).toDF("antecedent", "consequent", "confidence")
    -        .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
    -        .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
    -      assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
    -        generatedRules.sort("antecedent").rdd.collect()))
    -
    -      val transformed = model.transform(data)
    -      val expectedTransformed = spark.createDataFrame(Seq(
    -        (0, Array("1", "2"), Array.emptyIntArray),
    -        (0, Array("1", "2"), Array.emptyIntArray),
    -        (0, Array("1", "2"), Array.emptyIntArray),
    -        (0, Array("1", "3"), Array(2))
    -      )).toDF("id", "items", "prediction")
    -        .withColumn("items", col("items").cast(ArrayType(dt)))
    -        .withColumn("prediction", col("prediction").cast(ArrayType(dt)))
    -      assert(expectedTransformed.collect().toSet.equals(
    -        transformed.collect().toSet))
    +      class DataTypeWithEncoder[A](val a: DataType)
    +                                  (implicit val encoder: Encoder[(Int, 
Array[A], Array[A])])
    +
    +      Array(
    +        new DataTypeWithEncoder[Int](IntegerType),
    +        new DataTypeWithEncoder[String](StringType),
    +        new DataTypeWithEncoder[Short](ShortType),
    +        new DataTypeWithEncoder[Long](LongType)
    +        // , new DataTypeWithEncoder[Byte](ByteType)
    +        // TODO: using ByteType produces error, as Array[Byte] is handled 
as Binary
    +        // cannot resolve 'CAST(`items` AS BINARY)' due to data type 
mismatch:
    +        // cannot cast array<tinyint> to binary;
    +      ).foreach { dt => {
    +        val data = dataset.withColumn("items", 
col("items").cast(ArrayType(dt.a)))
    +        val model = new FPGrowth().setMinSupport(0.5).fit(data)
    +        val generatedRules = model.setMinConfidence(0.5).associationRules
    +        val expectedRules = Seq(
    +          (Array("2"), Array("1"), 1.0),
    +          (Array("1"), Array("2"), 0.75)
    +        ).toDF("antecedent", "consequent", "confidence")
    +          .withColumn("antecedent", 
col("antecedent").cast(ArrayType(dt.a)))
    +          .withColumn("consequent", 
col("consequent").cast(ArrayType(dt.a)))
    +        assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
    +          generatedRules.sort("antecedent").rdd.collect()))
    +
    +        val expectedTransformed = Seq(
    +          (0, Array("1", "2"), Array.emptyIntArray),
    +          (0, Array("1", "2"), Array.emptyIntArray),
    +          (0, Array("1", "2"), Array.emptyIntArray),
    +          (0, Array("1", "3"), Array(2))
    +        ).toDF("id", "items", "expected")
    +          .withColumn("items", col("items").cast(ArrayType(dt.a)))
    +          .withColumn("expected", col("expected").cast(ArrayType(dt.a)))
    +
    +        testTransformer(expectedTransformed, model,
    +          "expected", "prediction") {
    +          case Row(expected, prediction) => assert(expected === prediction,
    +            s"Expected $expected but found $prediction for data type $dt")
    +        }(dt.encoder)
    +      }
         }
       }
     
       test("FPGrowth getFreqItems") {
         val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
    -    val expectedFreq = spark.createDataFrame(Seq(
    +    val expectedFreq = Seq(
           (Array("1"), 4L),
           (Array("2"), 3L),
           (Array("1", "2"), 3L),
           (Array("2", "1"), 3L) // duplicate as the items sequence is not 
guaranteed
    -    )).toDF("items", "expectedFreq")
    +    ).toDF("items", "expectedFreq")
         val freqItems = model.freqItemsets
     
         val checkDF = freqItems.join(expectedFreq, "items")
         assert(checkDF.count() == 3 && checkDF.filter(col("freq") === 
col("expectedFreq")).count() == 3)
       }
     
       test("FPGrowth getFreqItems with Null") {
    -    val df = spark.createDataFrame(Seq(
    +    val df = Seq(
           (1, Array("1", "2", "3", "5")),
           (2, Array("1", "2", "3", "4")),
           (3, null.asInstanceOf[Array[String]])
    -    )).toDF("id", "items")
    +    ).toDF("id", "items")
         val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
    -    val prediction = model.transform(df)
    -    
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
    +    testTransformerByGlobalCheckFunc[(Int, Array[String])](df, model, 
"id", "prediction") {
    +      rows => {
    +        val predictionForId3 = rows.filter(_.getAs[Int]("id") == 3)
    +          .map(_.getAs[Seq[String]]("prediction"))
    +        assert(Seq(Seq.empty) === predictionForId3,
    +          s"Expected empty prediction for id 3, got $predictionForId3")
    +      }
    +    }
       }
     
       test("FPGrowth prediction should not contain duplicates") {
         // This should generate rule 1 -> 3, 2 -> 3
    -    val dataset = spark.createDataFrame(Seq(
    +    val dataset = Seq(
           Array("1", "3"),
           Array("2", "3")
    -    ).map(Tuple1(_))).toDF("items")
    +    ).toDF("items")
         val model = new FPGrowth().fit(dataset)
     
    -    val prediction = model.transform(
    -      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
    -    ).first().getAs[Seq[String]]("prediction")
    +    val df = Seq(Array("1", "2")).toDF("items")
     
    -    assert(prediction === Seq("3"))
    +    testTransformerByGlobalCheckFunc[(Array[String])](df, model, 
"prediction") {
    +      rows => {
    +        assert(1 === rows.size, s"Expected exactly 1 prediction, got 
$rows")
    +        val predictions = rows.map(_.getAs[Seq[String]]("prediction"))
    +        val expected = Seq(Seq("3"))
    +        assert(expected === predictions, s"Expected $expected, got 
$predictions")
    +      }
    +    }
       }
     
       test("FPGrowthModel setMinConfidence should affect rules generation and 
transform") {
         val model = new 
FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
         val oldRulesNum = model.associationRules.count()
    -    val oldPredict = model.transform(dataset)
    +    val oldPredict = 
model.transform(dataset).withColumnRenamed("prediction", "oldPrediction")
    --- End diff --
    
    Other places named it `oldPrediction`; this could then also be renamed for consistency.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to