Repository: spark
Updated Branches:
refs/heads/master cfbe11e81 -> 3d8837e59
[SPARK-22397][ML] add multiple columns support to QuantileDiscretizer
## What changes were proposed in this pull request?
Add multiple-column support to QuantileDiscretizer.
When calculating the splits, we can either merge all the per-column
probabilities into one array and compute approxQuantile on multiple
columns at once, or compute approxQuantile separately for each column. After
a performance comparison, we found it is better to compute
approxQuantile on multiple columns at once; a sketch of that merged strategy follows.
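For reference, here is a minimal sketch of the merged strategy (paraphrasing the patch; `df`, `inputCols`, `numBucketsArray`, and `relativeError` stand in for the discretizer's actual dataset and params):
```
// Build each column's probability grid; e.g. numBuckets = 4 gives
// Array(0.0, 0.25, 0.5, 0.75, 1.0).
val probsPerCol = numBucketsArray.map { n => (0.0 to 1.0 by 1.0 / n).toArray }
// Merge the grids and query approxQuantile once over all input columns.
val mergedProbs = probsPerCol.flatten.sorted.distinct
val rawSplits = df.stat.approxQuantile(inputCols, mergedProbs, relativeError)
// Recover each column's own splits from the shared result by matching
// its probabilities against the merged grid.
val splitsPerCol = rawSplits.zip(probsPerCol).map { case (splits, probs) =>
  val wanted = probs.toSet
  splits.zip(mergedProbs).collect { case (s, p) if wanted(p) => s }
}
```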
Here is how we measured the fit time:
```
// Average fit() wall-clock time, in seconds, over 10 runs.
var duration = 0.0
for (i <- 0 until 10) {
  val start = System.nanoTime()
  discretizer.fit(df)
  val end = System.nanoTime()
  duration += (end - start) / 1e9
}
println(duration / 10)
```
Here are the performance test results (average fit time in seconds):
|numCols|numRows|compute each column's approxQuantile separately|compute all columns' approxQuantile at once|
|-------|-------|-----------------------------------------------|-------------------------------------------|
|10     |60     |0.3623195839                                   |0.1626658607                               |
|10     |6000   |0.7537239841                                   |0.3869370046                               |
|22     |6000   |1.6497598557                                   |0.4767903059                               |
|50     |6000   |3.2268305752                                   |0.7217818396                               |
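For illustration, here is a minimal usage sketch of the resulting multi-column API, assuming a DataFrame `df` with numeric columns `input1` and `input2`:
```
import org.apache.spark.ml.feature.QuantileDiscretizer

val discretizer = new QuantileDiscretizer()
  .setInputCols(Array("input1", "input2"))
  .setOutputCols(Array("result1", "result2"))
  .setNumBucketsArray(Array(5, 10)) // or setNumBuckets(n) to share one value
// fit() still returns a single Bucketizer; with inputCols set it carries
// per-column splits and bucketizes all columns in one transform.
val result = discretizer.fit(df).transform(df)
```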
## How was this patch tested?
Added unit tests in QuantileDiscretizerSuite to cover the multiple-column support.
Author: Huaxin Gao <[email protected]>
Closes #19715 from huaxingao/spark_22397.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d8837e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d8837e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d8837e5
Branch: refs/heads/master
Commit: 3d8837e59aadd726805371041567ceff375194c0
Parents: cfbe11e
Author: Huaxin Gao <[email protected]>
Authored: Sun Dec 31 14:39:24 2017 +0200
Committer: Nick Pentreath <[email protected]>
Committed: Sun Dec 31 14:39:24 2017 +0200
----------------------------------------------------------------------
.../spark/ml/feature/QuantileDiscretizer.scala | 120 +++++++--
.../ml/feature/QuantileDiscretizerSuite.scala | 265 +++++++++++++++++++
2 files changed, 369 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3d8837e5/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 95e8830..1ec5f8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -51,9 +51,27 @@ private[feature] trait QuantileDiscretizerBase extends Params
   def getNumBuckets: Int = getOrDefault(numBuckets)

   /**
+   * Array of number of buckets (quantiles, or categories) into which data points are grouped.
+   * Each value must be greater than or equal to 2
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
+   * @group param
+   */
+  val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " +
+    "(quantiles, or categories) into which data points are grouped. This is for multiple " +
+    "columns input. If transforming multiple columns and numBucketsArray is not set, but " +
+    "numBuckets is set, then numBuckets will be applied across all columns.",
+    (arrayOfNumBuckets: Array[Int]) => arrayOfNumBuckets.forall(ParamValidators.gtEq(2)))
+
+  /** @group getParam */
+  def getNumBucketsArray: Array[Int] = $(numBucketsArray)
+
+  /**
    * Relative error (see documentation for
    * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description)
    * Must be in the range [0, 1].
+   * Note that in multiple columns case, relative error is applied to all columns.
    * default: 0.001
    * @group param
    */
@@ -68,7 +86,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
   /**
    * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
    * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
-   * additional bucket).
+   * additional bucket). Note that in the multiple columns case, the invalid handling is applied
+   * to all columns. That said for 'error' it will throw an error if any invalids are found in
+   * any column, for 'skip' it will skip rows with any invalids in any columns, etc.
    * Default: "error"
    * @group param
    */
@@ -86,6 +106,11 @@ private[feature] trait QuantileDiscretizerBase extends Params
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
  * possible that the number of buckets used will be smaller than this value, for example, if there
  * are too few distinct values of the input to create enough distinct quantiles.
+ * Since 2.3.0, `QuantileDiscretizer` can map multiple columns at once by setting the `inputCols`
+ * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be
+ * thrown. To specify the number of buckets for each column, the `numBucketsArray` parameter can
+ * be set, or if the number of buckets should be the same across columns, `numBuckets` can be
+ * set as a convenience.
  *
  * NaN handling:
  * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This
@@ -104,7 +129,8 @@ private[feature] trait QuantileDiscretizerBase extends Params
  */
 @Since("1.6.0")
 final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
-  extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable {
+  extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable
+    with HasInputCols with HasOutputCols {

   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("quantileDiscretizer"))
@@ -129,34 +155,96 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("2.1.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private[feature] def getInOutCols: (Array[String], Array[String]) = {
+    require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
+      (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
+      "QuantileDiscretizer only supports setting either inputCol/outputCol or" +
+      "inputCols/outputCols."
+    )
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "inputCols number do not match outputCols")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkNumericType(schema, $(inputCol))
-    val inputFields = schema.fields
-    require(inputFields.forall(_.name != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val (inputColNames, outputColNames) = getInOutCols
+    val existingFields = schema.fields
+    var outputFields = existingFields
+    inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
+      SchemaUtils.checkNumericType(schema, inputColName)
+      require(existingFields.forall(_.name != outputColName),
+        s"Output column ${outputColName} already exists.")
+      val attr = NominalAttribute.defaultAttr.withName(outputColName)
+      outputFields :+= attr.toStructField()
+    }
     StructType(outputFields)
   }

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
     transformSchema(dataset.schema, logging = true)
-    val splits = dataset.stat.approxQuantile($(inputCol),
-      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
+    if (isSet(inputCols)) {
+      val splitsArray = if (isSet(numBucketsArray)) {
+        val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
+          (0.0 to 1.0 by 1.0 / numOfBuckets).toArray
+        }
+
+        val probabilityArray = probArrayPerCol.flatten.sorted.distinct
+        val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols),
+          probabilityArray, $(relativeError))
+
+        splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
+          val probSet = probs.toSet
+          val idxSet = probabilityArray.zipWithIndex.collect {
+            case (p, idx) if probSet(p) =>
+              idx
+          }.toSet
+          splits.zipWithIndex.collect {
+            case (s, idx) if idxSet(idx) =>
+              s
+          }
+        }
+      } else {
+        dataset.stat.approxQuantile($(inputCols),
+          (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+      }
+      bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
+    } else {
+      val splits = dataset.stat.approxQuantile($(inputCol),
+        (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+      bucketizer.setSplits(getDistinctSplits(splits))
+    }
+    copyValues(bucketizer.setParent(this))
+  }
+
+  private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
-
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
         s" buckets as a result.")
     }
-    val bucketizer = new Bucketizer(uid)
-      .setSplits(distinctSplits.sorted)
-      .setHandleInvalid($(handleInvalid))
-    copyValues(bucketizer.setParent(this))
+    distinctSplits.sorted
   }

   @Since("1.6.0")
http://git-wip-us.apache.org/repos/asf/spark/blob/3d8837e5/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index f219f77..e9a75e9 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql._
@@ -146,4 +147,268 @@ class QuantileDiscretizerSuite
val model = discretizer.fit(df)
assert(model.hasParent)
}
+
+ test("Multiple Columns: Test observed number of buckets and their sizes
match expected values") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val datasetSize = 100000
+ val numBuckets = 5
+ val data1 = Array.range(1, 100001, 1).map(_.toDouble)
+ val data2 = Array.range(1, 200000, 2).map(_.toDouble)
+ val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
+
+ val relativeError = discretizer.getRelativeError
+ val isGoodBucket = udf {
+ (size: Int) => math.abs(size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
+ }
+
+ for (i <- 1 to 2) {
+ val observedNumBuckets = result.select("result" + i).distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of
buckets.")
+
+ val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
+ }
+ }
+
+ test("Multiple Columns: Test on data with high proportion of duplicated
values") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val numBuckets = 5
+ val expectedNumBucket = 3
+ val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)
+ val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 2.0)
+ val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
+ for (i <- 1 to 2) {
+ val observedNumBuckets = result.select("result" + i).distinct.count
+ assert(observedNumBuckets == expectedNumBucket,
+ s"Observed number of buckets are not correct." +
+ s" Expected $expectedNumBucket but found ($observedNumBuckets")
+ }
+ }
+
+ test("Multiple Columns: Test transform on data with NaN value") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val numBuckets = 3
+ val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
+ val expectedKeep1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0)
+ val expectedSkip1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
+ val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, Double.NaN, Double.NaN, Double.NaN)
+ val expectedKeep2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0, 3.0, 3.0, 3.0)
+ val expectedSkip2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0)
+
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBuckets(numBuckets)
+
+ withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
+ val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2")
+ intercept[SparkException] {
+ discretizer.fit(dataFrame).transform(dataFrame).collect()
+ }
+ }
+
+ List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1,
expectedSkip2)).foreach {
+ case (u, v, w) =>
+ discretizer.setHandleInvalid(u)
+ val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map {
+ case (((a, b), c), d) => (a, b, c, d)
+ }.toSeq.toDF("input1", "input2", "expected1", "expected2")
+ val result = discretizer.fit(dataFrame).transform(dataFrame)
+ result.select("result1", "expected1", "result2", "expected2").collect().foreach {
+ case Row(x: Double, y: Double, z: Double, w: Double) =>
+ assert(x === y && w === z)
+ }
+ }
+ }
+
+ test("Multiple Columns: Test numBucketsArray") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val numBucketsArray: Array[Int] = Array(2, 5, 10)
+ val data1 = Array.range(1, 21, 1).map(_.toDouble)
+ val expected1 = Array (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
+ val data2 = Array.range(1, 40, 2).map(_.toDouble)
+ val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0,
+ 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0)
+ val data3 = Array.range(1, 60, 3).map(_.toDouble)
+ val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0,
+ 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0)
+ val data = (0 until 20).map { idx =>
+ (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx))
+ }
+ val df =
+ data.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3")
+
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("result1", "result2", "result3"))
+ .setNumBucketsArray(numBucketsArray)
+
+ discretizer.fit(df).transform(df).
+ select("result1", "expected1", "result2", "expected2", "result3", "expected3")
+ .collect().foreach {
+ case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) =>
+ assert(r1 === e1,
+ s"The result value is not correct after bucketing. Expected $e1 but found $r1")
+ assert(r2 === e2,
+ s"The result value is not correct after bucketing. Expected $e2 but found $r2")
+ assert(r3 === e3,
+ s"The result value is not correct after bucketing. Expected $e3 but found $r3")
+ }
+ }
+
+ test("Multiple Columns: Compare single/multiple column(s)
QuantileDiscretizer in pipeline") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val numBucketsArray: Array[Int] = Array(2, 5, 10)
+ val data1 = Array.range(1, 21, 1).map(_.toDouble)
+ val data2 = Array.range(1, 40, 2).map(_.toDouble)
+ val data3 = Array.range(1, 60, 3).map(_.toDouble)
+ val data = (0 until 20).map { idx =>
+ (data1(idx), data2(idx), data3(idx))
+ }
+ val df =
+ data.toDF("input1", "input2", "input3")
+
+ val multiColsDiscretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("result1", "result2", "result3"))
+ .setNumBucketsArray(numBucketsArray)
+ val plForMultiCols = new Pipeline()
+ .setStages(Array(multiColsDiscretizer))
+ .fit(df)
+
+ val discretizerForCol1 = new QuantileDiscretizer()
+ .setInputCol("input1")
+ .setOutputCol("result1")
+ .setNumBuckets(numBucketsArray(0))
+
+ val discretizerForCol2 = new QuantileDiscretizer()
+ .setInputCol("input2")
+ .setOutputCol("result2")
+ .setNumBuckets(numBucketsArray(1))
+
+ val discretizerForCol3 = new QuantileDiscretizer()
+ .setInputCol("input3")
+ .setOutputCol("result3")
+ .setNumBuckets(numBucketsArray(2))
+
+ val plForSingleCol = new Pipeline()
+ .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3))
+ .fit(df)
+
+ val resultForMultiCols = plForMultiCols.transform(df)
+ .select("result1", "result2", "result3")
+ .collect()
+
+ val resultForSingleCol = plForSingleCol.transform(df)
+ .select("result1", "result2", "result3")
+ .collect()
+
+ resultForSingleCol.zip(resultForMultiCols).foreach {
+ case (rowForSingle, rowForMultiCols) =>
+ assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+ rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+ rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2))
+ }
+ }
+
+ test("Multiple Columns: Comparing setting numBuckets with setting
numBucketsArray " +
+ "explicitly with identical values") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val data1 = Array.range(1, 21, 1).map(_.toDouble)
+ val data2 = Array.range(1, 40, 2).map(_.toDouble)
+ val data3 = Array.range(1, 60, 3).map(_.toDouble)
+ val data = (0 until 20).map { idx =>
+ (data1(idx), data2(idx), data3(idx))
+ }
+ val df =
+ data.toDF("input1", "input2", "input3")
+
+ val discretizerSingleNumBuckets = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("result1", "result2", "result3"))
+ .setNumBuckets(10)
+
+ val discretizerNumBucketsArray = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2", "input3"))
+ .setOutputCols(Array("result1", "result2", "result3"))
+ .setNumBucketsArray(Array(10, 10, 10))
+
+ val result1 = discretizerSingleNumBuckets.fit(df).transform(df)
+ .select("result1", "result2", "result3")
+ .collect()
+ val result2 = discretizerNumBucketsArray.fit(df).transform(df)
+ .select("result1", "result2", "result3")
+ .collect()
+
+ result1.zip(result2).foreach {
+ case (row1, row2) =>
+ assert(row1.getDouble(0) == row2.getDouble(0) &&
+ row1.getDouble(1) == row2.getDouble(1) &&
+ row1.getDouble(2) == row2.getDouble(2))
+ }
+ }
+
+ test("Multiple Columns: read/write") {
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBucketsArray(Array(5, 10))
+ testDefaultReadWrite(discretizer)
+ }
+
+ test("Multiple Columns: Both inputCol and inputCols are set") {
+ val spark = this.spark
+ import spark.implicits._
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(3)
+ .setInputCols(Array("input1", "input2"))
+ val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+ .map(Tuple1.apply).toDF("input")
+ // When both inputCol and inputCols are set, we throw Exception.
+ intercept[IllegalArgumentException] {
+ discretizer.fit(df)
+ }
+ }
+
+ test("Multiple Columns: Mismatched sizes of inputCols / outputCols") {
+ val spark = this.spark
+ import spark.implicits._
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBuckets(3)
+ val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+ .map(Tuple1.apply).toDF("input")
+ intercept[IllegalArgumentException] {
+ discretizer.fit(df)
+ }
+ }
}