Repository: spark
Updated Branches:
  refs/heads/master cfbe11e81 -> 3d8837e59


[SPARK-22397][ML] add multiple columns support to QuantileDiscretizer

## What changes were proposed in this pull request?

Add multiple-column support to QuantileDiscretizer.
When calculating the splits, we can either merge all the probabilities into 
one array and call approxQuantile on multiple columns at once, or call 
approxQuantile separately for each column. After comparing the performance 
of the two approaches, we found it is better to call approxQuantile on 
multiple columns at once.
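
The merged-probabilities idea can be sketched in plain Scala without Spark. This is a minimal illustration under stated assumptions, not the code in this patch: `exactQuantiles` is a hypothetical stand-in for `approxQuantile` that simply sorts the data, and the object and method names are made up for the example.

```scala
object MergedProbabilitiesSketch {
  // Hypothetical stand-in for approxQuantile: exact quantiles obtained by sorting.
  def exactQuantiles(data: Array[Double], probs: Array[Double]): Array[Double] = {
    val sorted = data.sorted
    probs.map { p =>
      // Rank of the p-quantile, clamped to a valid index.
      val idx = math.min(sorted.length - 1, math.max(0, math.ceil(p * sorted.length).toInt - 1))
      sorted(idx)
    }
  }

  // Evenly spaced probabilities 0, 1/n, ..., 1 for n buckets.
  def probabilities(numBuckets: Int): Array[Double] =
    (0 to numBuckets).map(_.toDouble / numBuckets).toArray

  def splitsPerColumn(
      columns: Array[Array[Double]],
      numBucketsArray: Array[Int]): Array[Array[Double]] = {
    val probsPerCol = numBucketsArray.map(probabilities)
    // One merged, sorted, de-duplicated probability array shared by all columns.
    val merged = probsPerCol.flatten.sorted.distinct
    columns.zip(probsPerCol).map { case (col, probs) =>
      // Query every column with the same merged array ...
      val all = exactQuantiles(col, merged)
      val wanted = probs.toSet
      // ... then keep only the quantiles whose probability this column asked for.
      merged.zip(all).collect { case (p, q) if wanted(p) => q }
    }
  }
}
```

Each column is queried with the same merged probability array, so a single multi-column call suffices, and the per-column splits are recovered afterwards by filtering on the probabilities that column requested.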

Here is how we measured the running time:
```
    // Average wall-clock time of fit() over 10 runs, in seconds.
    var duration = 0.0
    for (i <- 0 until 10) {
      val start = System.nanoTime()
      discretizer.fit(df)
      val end = System.nanoTime()
      duration += (end - start) / 1e9  // nanoseconds to seconds
    }
    println(duration / 10)
```
Here is the performance test result:

| numCols | numRows | Compute approxQuantile separately for each column (s) | Compute approxQuantile on multiple columns at once (s) |
|---------|---------|--------------------------------------------------------|---------------------------------------------------------|
| 10      | 60      | 0.3623195839                                           | 0.1626658607                                            |
| 10      | 6000    | 0.7537239841                                           | 0.3869370046                                            |
| 22      | 6000    | 1.6497598557                                           | 0.4767903059                                            |
| 50      | 6000    | 3.2268305752                                           | 0.7217818396                                            |
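
To make the comparison easier to read, the speedup of the multi-column call over the per-column calls can be computed from the table above. The following is only a small sketch; the `Measurement` case class and its field names are illustrative and not part of this patch.

```scala
object SpeedupSketch {
  // One row of the benchmark table; times are average seconds per fit().
  final case class Measurement(numCols: Int, numRows: Int, separate: Double, atOnce: Double) {
    // How many times faster the multi-column call is than the per-column calls.
    def speedup: Double = separate / atOnce
  }

  val measurements: Seq[Measurement] = Seq(
    Measurement(10, 60, 0.3623195839, 0.1626658607),
    Measurement(10, 6000, 0.7537239841, 0.3869370046),
    Measurement(22, 6000, 1.6497598557, 0.4767903059),
    Measurement(50, 6000, 3.2268305752, 0.7217818396)
  )

  def main(args: Array[String]): Unit =
    measurements.foreach { m =>
      println(f"numCols=${m.numCols}%d numRows=${m.numRows}%d speedup=${m.speedup}%.2fx")
    }
}
```

On these numbers the multi-column call is roughly 2x faster with 10 columns and about 4.5x faster with 50 columns, so the benefit grows with the number of columns.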

## How was this patch tested?

Added unit tests to QuantileDiscretizerSuite covering multiple-column support.

Author: Huaxin Gao <huax...@us.ibm.com>

Closes #19715 from huaxingao/spark_22397.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d8837e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d8837e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d8837e5

Branch: refs/heads/master
Commit: 3d8837e59aadd726805371041567ceff375194c0
Parents: cfbe11e
Author: Huaxin Gao <huax...@us.ibm.com>
Authored: Sun Dec 31 14:39:24 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Sun Dec 31 14:39:24 2017 +0200

----------------------------------------------------------------------
 .../spark/ml/feature/QuantileDiscretizer.scala  | 120 +++++++--
 .../ml/feature/QuantileDiscretizerSuite.scala   | 265 +++++++++++++++++++
 2 files changed, 369 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d8837e5/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 95e8830..1ec5f8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -51,9 +51,27 @@ private[feature] trait QuantileDiscretizerBase extends Params
   def getNumBuckets: Int = getOrDefault(numBuckets)
 
   /**
+   * Array of number of buckets (quantiles, or categories) into which data points are grouped.
+   * Each value must be greater than or equal to 2
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
+   * @group param
+   */
+  val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " +
+    "(quantiles, or categories) into which data points are grouped. This is for multiple " +
+    "columns input. If transforming multiple columns and numBucketsArray is not set, but " +
+    "numBuckets is set, then numBuckets will be applied across all columns.",
+    (arrayOfNumBuckets: Array[Int]) => arrayOfNumBuckets.forall(ParamValidators.gtEq(2)))
+
+  /** @group getParam */
+  def getNumBucketsArray: Array[Int] = $(numBucketsArray)
+
+  /**
    * Relative error (see documentation for
    * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description)
    * Must be in the range [0, 1].
+   * Note that in multiple columns case, relative error is applied to all columns.
    * default: 0.001
    * @group param
    */
@@ -68,7 +86,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
   /**
    * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
    * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
-   * additional bucket).
+   * additional bucket). Note that in the multiple columns case, the invalid handling is applied
+   * to all columns. That said for 'error' it will throw an error if any invalids are found in
+   * any column, for 'skip' it will skip rows with any invalids in any columns, etc.
    * Default: "error"
    * @group param
    */
@@ -86,6 +106,11 @@ private[feature] trait QuantileDiscretizerBase extends Params
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
  * possible that the number of buckets used will be smaller than this value, for example, if there
  * are too few distinct values of the input to create enough distinct quantiles.
+ * Since 2.3.0, `QuantileDiscretizer` can map multiple columns at once by setting the `inputCols`
+ * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be
+ * thrown. To specify the number of buckets for each column, the `numBucketsArray` parameter can
+ * be set, or if the number of buckets should be the same across columns, `numBuckets` can be
+ * set as a convenience.
  *
  * NaN handling:
  * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This
@@ -104,7 +129,8 @@ private[feature] trait QuantileDiscretizerBase extends Params
  */
 @Since("1.6.0")
 final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
-  extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable {
+  extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable
+    with HasInputCols with HasOutputCols {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("quantileDiscretizer"))
@@ -129,34 +155,96 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("2.1.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private[feature] def getInOutCols: (Array[String], Array[String]) = {
+    require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
+      (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
+      "QuantileDiscretizer only supports setting either inputCol/outputCol or" +
+        "inputCols/outputCols."
+    )
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "inputCols number do not match outputCols")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkNumericType(schema, $(inputCol))
-    val inputFields = schema.fields
-    require(inputFields.forall(_.name != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val (inputColNames, outputColNames) = getInOutCols
+    val existingFields = schema.fields
+    var outputFields = existingFields
+    inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
+      SchemaUtils.checkNumericType(schema, inputColName)
+      require(existingFields.forall(_.name != outputColName),
+        s"Output column ${outputColName} already exists.")
+      val attr = NominalAttribute.defaultAttr.withName(outputColName)
+      outputFields :+= attr.toStructField()
+    }
     StructType(outputFields)
   }
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
     transformSchema(dataset.schema, logging = true)
-    val splits = dataset.stat.approxQuantile($(inputCol),
-      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
+    if (isSet(inputCols)) {
+      val splitsArray = if (isSet(numBucketsArray)) {
+        val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
+          (0.0 to 1.0 by 1.0 / numOfBuckets).toArray
+        }
+
+        val probabilityArray = probArrayPerCol.flatten.sorted.distinct
+        val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols),
+          probabilityArray, $(relativeError))
+
+        splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
+          val probSet = probs.toSet
+          val idxSet = probabilityArray.zipWithIndex.collect {
+            case (p, idx) if probSet(p) =>
+              idx
+          }.toSet
+          splits.zipWithIndex.collect {
+            case (s, idx) if idxSet(idx) =>
+              s
+          }
+        }
+      } else {
+        dataset.stat.approxQuantile($(inputCols),
+          (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+      }
+      bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
+    } else {
+      val splits = dataset.stat.approxQuantile($(inputCol),
+        (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+      bucketizer.setSplits(getDistinctSplits(splits))
+    }
+    copyValues(bucketizer.setParent(this))
+  }
+
+  private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
-
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
         s" buckets as a result.")
     }
-    val bucketizer = new Bucketizer(uid)
-      .setSplits(distinctSplits.sorted)
-      .setHandleInvalid($(handleInvalid))
-    copyValues(bucketizer.setParent(this))
+    distinctSplits.sorted
   }
 
   @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/3d8837e5/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index f219f77..e9a75e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql._
@@ -146,4 +147,268 @@ class QuantileDiscretizerSuite
     val model = discretizer.fit(df)
     assert(model.hasParent)
   }
+
+  test("Multiple Columns: Test observed number of buckets and their sizes match expected values") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val datasetSize = 100000
+    val numBuckets = 5
+    val data1 = Array.range(1, 100001, 1).map(_.toDouble)
+    val data2 = Array.range(1, 200000, 2).map(_.toDouble)
+    val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setNumBuckets(numBuckets)
+    val result = discretizer.fit(df).transform(df)
+
+    val relativeError = discretizer.getRelativeError
+    val isGoodBucket = udf {
+      (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
+    }
+
+    for (i <- 1 to 2) {
+      val observedNumBuckets = result.select("result" + i).distinct.count
+      assert(observedNumBuckets === numBuckets,
+        "Observed number of buckets does not equal expected number of buckets.")
+
+      val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count
+      assert(numGoodBuckets === numBuckets,
+        "Bucket sizes are not within expected relative error tolerance.")
+    }
+  }
+
+  test("Multiple Columns: Test on data with high proportion of duplicated values") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val numBuckets = 5
+    val expectedNumBucket = 3
+    val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)
+    val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 2.0)
+    val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setNumBuckets(numBuckets)
+    val result = discretizer.fit(df).transform(df)
+    for (i <- 1 to 2) {
+      val observedNumBuckets = result.select("result" + i).distinct.count
+      assert(observedNumBuckets == expectedNumBucket,
+        s"Observed number of buckets are not correct." +
+          s" Expected $expectedNumBucket but found ($observedNumBuckets")
+    }
+  }
+
+  test("Multiple Columns: Test transform on data with NaN value") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val numBuckets = 3
+    val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
+    val expectedKeep1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0)
+    val expectedSkip1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
+    val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, Double.NaN, Double.NaN, Double.NaN)
+    val expectedKeep2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0)
+    val expectedSkip2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0)
+
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setNumBuckets(numBuckets)
+
+    withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
+      val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2")
+      intercept[SparkException] {
+        discretizer.fit(dataFrame).transform(dataFrame).collect()
+      }
+    }
+
+    List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach {
+      case (u, v, w) =>
+        discretizer.setHandleInvalid(u)
+        val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map {
+          case (((a, b), c), d) => (a, b, c, d)
+        }.toSeq.toDF("input1", "input2", "expected1", "expected2")
+        val result = discretizer.fit(dataFrame).transform(dataFrame)
+        result.select("result1", "expected1", "result2", "expected2").collect().foreach {
+          case Row(x: Double, y: Double, z: Double, w: Double) =>
+            assert(x === y && w === z)
+        }
+    }
+  }
+
+  test("Multiple Columns: Test numBucketsArray") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val numBucketsArray: Array[Int] = Array(2, 5, 10)
+    val data1 = Array.range(1, 21, 1).map(_.toDouble)
+    val expected1 = Array (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
+      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
+    val data2 = Array.range(1, 40, 2).map(_.toDouble)
+    val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0,
+      2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0)
+    val data3 = Array.range(1, 60, 3).map(_.toDouble)
+    val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0,
+      5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0)
+    val data = (0 until 20).map { idx =>
+      (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx))
+    }
+    val df =
+      data.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3")
+
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setNumBucketsArray(numBucketsArray)
+
+    discretizer.fit(df).transform(df).
+      select("result1", "expected1", "result2", "expected2", "result3", "expected3")
+      .collect().foreach {
+      case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) =>
+        assert(r1 === e1,
+          s"The result value is not correct after bucketing. Expected $e1 but found $r1")
+        assert(r2 === e2,
+          s"The result value is not correct after bucketing. Expected $e2 but found $r2")
+        assert(r3 === e3,
+          s"The result value is not correct after bucketing. Expected $e3 but found $r3")
+    }
+  }
+
+  test("Multiple Columns: Compare single/multiple column(s) QuantileDiscretizer in pipeline") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val numBucketsArray: Array[Int] = Array(2, 5, 10)
+    val data1 = Array.range(1, 21, 1).map(_.toDouble)
+    val data2 = Array.range(1, 40, 2).map(_.toDouble)
+    val data3 = Array.range(1, 60, 3).map(_.toDouble)
+    val data = (0 until 20).map { idx =>
+      (data1(idx), data2(idx), data3(idx))
+    }
+    val df =
+      data.toDF("input1", "input2", "input3")
+
+    val multiColsDiscretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setNumBucketsArray(numBucketsArray)
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsDiscretizer))
+      .fit(df)
+
+    val discretizerForCol1 = new QuantileDiscretizer()
+      .setInputCol("input1")
+      .setOutputCol("result1")
+      .setNumBuckets(numBucketsArray(0))
+
+    val discretizerForCol2 = new QuantileDiscretizer()
+      .setInputCol("input2")
+      .setOutputCol("result2")
+      .setNumBuckets(numBucketsArray(1))
+
+    val discretizerForCol3 = new QuantileDiscretizer()
+      .setInputCol("input3")
+      .setOutputCol("result3")
+      .setNumBuckets(numBucketsArray(2))
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3))
+      .fit(df)
+
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("result1", "result2", "result3")
+      .collect()
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("result1", "result2", "result3")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+      case (rowForSingle, rowForMultiCols) =>
+        assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+          rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+          rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2))
+    }
+  }
+
+  test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " +
+    "explicitly with identical values") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val data1 = Array.range(1, 21, 1).map(_.toDouble)
+    val data2 = Array.range(1, 40, 2).map(_.toDouble)
+    val data3 = Array.range(1, 60, 3).map(_.toDouble)
+    val data = (0 until 20).map { idx =>
+      (data1(idx), data2(idx), data3(idx))
+    }
+    val df =
+      data.toDF("input1", "input2", "input3")
+
+    val discretizerSingleNumBuckets = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setNumBuckets(10)
+
+    val discretizerNumBucketsArray = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setNumBucketsArray(Array(10, 10, 10))
+
+    val result1 = discretizerSingleNumBuckets.fit(df).transform(df)
+      .select("result1", "result2", "result3")
+      .collect()
+    val result2 = discretizerNumBucketsArray.fit(df).transform(df)
+      .select("result1", "result2", "result3")
+      .collect()
+
+    result1.zip(result2).foreach {
+      case (row1, row2) =>
+        assert(row1.getDouble(0) == row2.getDouble(0) &&
+          row1.getDouble(1) == row2.getDouble(1) &&
+          row1.getDouble(2) == row2.getDouble(2))
+    }
+  }
+
+  test("Multiple Columns: read/write") {
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setNumBucketsArray(Array(5, 10))
+    testDefaultReadWrite(discretizer)
+  }
+
+  test("Multiple Columns: Both inputCol and inputCols are set") {
+    val spark = this.spark
+    import spark.implicits._
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(3)
+      .setInputCols(Array("input1", "input2"))
+    val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+      .map(Tuple1.apply).toDF("input")
+    // When both inputCol and inputCols are set, we throw Exception.
+    intercept[IllegalArgumentException] {
+      discretizer.fit(df)
+    }
+  }
+
+  test("Multiple Columns: Mismatched sizes of inputCols / outputCols") {
+    val spark = this.spark
+    import spark.implicits._
+    val discretizer = new QuantileDiscretizer()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("result1", "result2"))
+      .setNumBuckets(3)
+    val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+      .map(Tuple1.apply).toDF("input")
+    intercept[IllegalArgumentException] {
+      discretizer.fit(df)
+    }
+  }
 }

