Github user MLnick commented on a diff in the pull request:
https://github.com/apache/spark/pull/19715#discussion_r155759105
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---
@@ -129,34 +156,106 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("2.1.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  /** @group setParam */
+  @Since("2.3.0")
+  def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = {
+    if (isSet(inputCols) && isSet(inputCol)) {
+      logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
+        "`QuantileDiscretizer` will only map one column specified by `inputCol`")
+      false
+    } else if (isSet(inputCols)) {
+      true
+    } else {
+      false
+    }
+  }
+
+  private[feature] def getInOutCols: (Array[String], Array[String]) = {
+    if (!isQuantileDiscretizeMultipleColumns) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "inputCols number do not match outputCols")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkNumericType(schema, $(inputCol))
-    val inputFields = schema.fields
-    require(inputFields.forall(_.name != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val (inputColNames, outputColNames) = getInOutCols
+    val existingFields = schema.fields
+    var outputFields = existingFields
+    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      SchemaUtils.checkNumericType(schema, inputColName)
+      require(existingFields.forall(_.name != outputColName),
+        s"Output column ${outputColName} already exists.")
+      val attr = NominalAttribute.defaultAttr.withName(outputColName)
+      outputFields :+= attr.toStructField()
+    }
     StructType(outputFields)
   }
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
     transformSchema(dataset.schema, logging = true)
-    val splits = dataset.stat.approxQuantile($(inputCol),
-      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
--- End diff --
Looking at this now, the `Array.fill` approach probably adds needless complexity. But the multi-bucket case can perhaps still be cleaned up. How about something like this:
```scala
override def fit(dataset: Dataset[_]): Bucketizer = {
  transformSchema(dataset.schema, logging = true)
  val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
  if (isQuantileDiscretizeMultipleColumns) {
    val splitsArray = if (isSet(numBucketsArray)) {
      val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
        (0.0 to 1.0 by 1.0 / numOfBuckets).toArray
      }
      val probabilityArray = probArrayPerCol.flatten.sorted.distinct
      val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols),
        probabilityArray, $(relativeError))
      splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
        val probSet = probs.toSet
        val idxSet = probabilityArray.zipWithIndex.collect {
          case (p, idx) if probSet(p) => idx
        }.toSet
        splits.zipWithIndex.collect {
          case (s, idx) if idxSet(idx) => s
        }
      }
    } else {
      dataset.stat.approxQuantile($(inputCols),
        (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
    }
    bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
  } else {
    val splits = dataset.stat.approxQuantile($(inputCol),
      (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
    bucketizer.setSplits(getDistinctSplits(splits))
  }
  copyValues(bucketizer.setParent(this))
}
```
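To make the index-mapping step above easier to follow, here is a standalone sketch (plain Scala, no Spark; the bucket counts and quantile values are made up) of how the per-column probability grids get merged for one `approxQuantile` pass and then pulled apart again per column:
```scala
// Standalone sketch of the probability-merge / index-mapping idea above.
// The bucket counts and the "returned quantiles" are made-up values.
val numBucketsArray = Array(2, 4)
val probArrayPerCol = numBucketsArray.map { numOfBuckets =>
  (0.0 to 1.0 by 1.0 / numOfBuckets).toArray
}
// Union of all requested probabilities, shared by a single approxQuantile call.
val probabilityArray = probArrayPerCol.flatten.sorted.distinct
// probabilityArray == Array(0.0, 0.25, 0.5, 0.75, 1.0)

// Pretend these are the quantiles returned for the first column at the
// merged probabilities.
val mergedQuantilesCol0 = Array(1.0, 2.0, 3.0, 4.0, 5.0)
val probSet = probArrayPerCol(0).toSet              // Set(0.0, 0.5, 1.0)
val idxSet = probabilityArray.zipWithIndex.collect {
  case (p, idx) if probSet(p) => idx
}.toSet                                             // Set(0, 2, 4)
val col0Splits = mergedQuantilesCol0.zipWithIndex.collect {
  case (s, idx) if idxSet(idx) => s
}
// col0Splits == Array(1.0, 3.0, 5.0): only the quantiles column 0 asked for.
```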
Then we don't need the `getSplitsForEachColumn` method (or part of the above could be factored out into a private method, if that makes sense).
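For reference, a rough usage sketch of the multi-column path being added here (the column names, bucket counts, and the input DataFrame `df` are hypothetical):
```scala
// Hypothetical usage of the multi-column API under review; the column names,
// bucket counts, and the DataFrame `df` are placeholders.
val discretizer = new QuantileDiscretizer()
  .setInputCols(Array("hour", "clicks"))
  .setOutputCols(Array("hourBucket", "clickBucket"))
  .setNumBucketsArray(Array(3, 10))
  .setHandleInvalid("keep")

val bucketizer = discretizer.fit(df)        // a Bucketizer holding per-column splits
val discretized = bucketizer.transform(df)
```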