Github user MLnick commented on a diff in the pull request:
https://github.com/apache/spark/pull/19715#discussion_r153776672
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---
@@ -129,34 +152,119 @@ final class QuantileDiscretizer @Since("1.6.0")
(@Since("1.6.0") override val ui
@Since("2.1.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid,
value)
+ /** @group setParam */
+ @Since("2.3.0")
+ def setNumBucketsArray(value: Array[Int]): this.type =
set(numBucketsArray, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setOutputCols(value: Array[String]): this.type = set(outputCols,
value)
+
+ private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = {
+ if (isSet(inputCols) && isSet(inputCol)) {
+ logWarning("Both `inputCol` and `inputCols` are set, we ignore
`inputCols` and this " +
+ "`QuantileDiscretize` only map one column specified by `inputCol`")
+ false
+ } else if (isSet(inputCols)) {
+ true
+ } else {
+ false
+ }
+ }
+
+ private[feature] def getInOutCols: (Array[String], Array[String]) = {
+ if (!isQuantileDiscretizeMultipleColumns) {
+ (Array($(inputCol)), Array($(outputCol)))
+ } else {
+ require($(inputCols).length == $(outputCols).length,
+ "inputCols number do not match outputCols")
+ ($(inputCols), $(outputCols))
+ }
+ }
+
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkNumericType(schema, $(inputCol))
- val inputFields = schema.fields
- require(inputFields.forall(_.name != $(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val attr = NominalAttribute.defaultAttr.withName($(outputCol))
- val outputFields = inputFields :+ attr.toStructField()
+ val (inputColNames, outputColNames) = getInOutCols
+ val existingFields = schema.fields
+ var outputFields = existingFields
+ inputColNames.zip(outputColNames).map { case (inputColName,
outputColName) =>
+ SchemaUtils.checkNumericType(schema, inputColName)
+ require(existingFields.forall(_.name != outputColName),
+ s"Output column ${outputColName} already exists.")
+ val attr = NominalAttribute.defaultAttr.withName(outputColName)
+ outputFields :+= attr.toStructField()
+ }
StructType(outputFields)
}
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Bucketizer = {
transformSchema(dataset.schema, logging = true)
- val splits = dataset.stat.approxQuantile($(inputCol),
- (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
+ if (isQuantileDiscretizeMultipleColumns) {
+ var bucketArray = Array.empty[Int]
--- End diff --
```scala
val bucketSeq = if (isSet(numBucketsArray)) {
  $(numBucketsArray).toSeq
} else {
  Seq($(numBuckets))
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]