huaxingao commented on a change in pull request #26064: [SPARK-23578][ML][PYSPARK] Binarizer support multi-column
URL: https://github.com/apache/spark/pull/26064#discussion_r333727044
 
 

 ##########
 File path: mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
 ##########
 @@ -122,5 +122,118 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest {
       .setOutputCol("myOutputCol")
       .setThreshold(0.1)
     testDefaultReadWrite(t)
+
+    val t2 = new Binarizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setThresholds(Array(30.0, 30.0, 30.0))
+    testDefaultReadWrite(t2)
+  }
+
+  test("Multiple Columns: Test thresholds") {
+    val thresholds = Array(10.0, -0.5, 0.0)
+
+    val data1 = Seq(5.0, 11.0)
+    val expected1 = Seq(0.0, 1.0)
+    val data2 = Seq(Vectors.sparse(3, Array(1), Array(0.5)),
+      Vectors.dense(Array(0.0, 0.5, 0.0)))
+    val expected2 = Seq(Vectors.dense(Array(1.0, 1.0, 1.0)),
+      Vectors.dense(Array(1.0, 1.0, 1.0)))
+    val data3 = Seq(0.0, 1.0)
+    val expected3 = Seq(0.0, 1.0)
+
+    val df = Seq(0, 1).map { idx =>
+      (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), 
expected3(idx))
+    }.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3")
+
+    val binarizer = new Binarizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setThresholds(thresholds)
+
+    binarizer.transform(df)
+      .select("result1", "expected1", "result2", "expected2", "result3", 
"expected3")
+      .collect().foreach {
+      case Row(r1: Double, e1: Double, r2: Vector, e2: Vector, r3: Double, e3: 
Double) =>
+        assert(r1 === e1,
+          s"The result value is not correct after bucketing. Expected $e1 but 
found $r1")
+        assert(r2 === e2,
+          s"The result value is not correct after bucketing. Expected $e2 but 
found $r2")
+        assert(r3 === e3,
+          s"The result value is not correct after bucketing. Expected $e3 but 
found $r3")
+    }
+  }
+
+  test("Multiple Columns: Comparing setting threshold with setting thresholds 
" +
+    "explicitly with identical values") {
+    val data1 = Array.range(1, 21, 1).map(_.toDouble)
+    val data2 = Array.range(1, 40, 2).map(_.toDouble)
+    val data3 = Array.range(1, 60, 3).map(_.toDouble)
+    val df = (0 until 20).map { idx =>
+      (data1(idx), data2(idx), data3(idx))
+    }.toDF("input1", "input2", "input3")
+
+    val binarizerSingleThreshold = new Binarizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setThreshold(30.0)
+
+    val df2 = binarizerSingleThreshold.transform(df)
+
+    val binarizerMultiThreshold = new Binarizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("expected1", "expected2", "expected3"))
+      .setThresholds(Array(30.0, 30.0, 30.0))
+
+    binarizerMultiThreshold.transform(df2)
+      .select("result1", "expected1", "result2", "expected2", "result3", 
"expected3")
+      .collect().foreach {
+      case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: 
Double) =>
+        assert(r1 === e1,
+          s"The result value is not correct after bucketing. Expected $e1 but 
found $r1")
+        assert(r2 === e2,
+          s"The result value is not correct after bucketing. Expected $e2 but 
found $r2")
+        assert(r3 === e3,
+          s"The result value is not correct after bucketing. Expected $e3 but 
found $r3")
+    }
+  }
+
+  test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
+    val binarizer = new Binarizer()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("result1", "result2"))
+      .setThreshold(1.0)
+    val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+      .map(Tuple1.apply).toDF("input")
+    intercept[IllegalArgumentException] {
+      binarizer.transform(df).count()
+    }
+  }
+
+  test("Multiple Columns: Mismatched sizes of inputCols/thresholds") {
+    val binarizer = new Binarizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setThresholds(Array(1.0, 2.0, 3.0))
+    val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
+    val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
+    val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      binarizer.transform(df).count()
+    }
+  }
+
+  test("Multiple Columns: Set both of threshold/thresholds") {
+    val binarizer = new Binarizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setThresholds(Array(1.0, 2.0))
+      .setThreshold(1.0)
+    val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
+    val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
+    val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      binarizer.transform(df).count()
+    }
 
 Review comment:
   Maybe add a negative test for setting both inputCol and thresholds? 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to