Github user MLnick commented on a diff in the pull request:
https://github.com/apache/spark/pull/19715#discussion_r155744727
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---
@@ -129,34 +156,106 @@ final class QuantileDiscretizer @Since("1.6.0")
(@Since("1.6.0") override val ui
@Since("2.1.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid,
value)
+ /** @group setParam */
+ @Since("2.3.0")
+ def setNumBucketsArray(value: Array[Int]): this.type =
set(numBucketsArray, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setOutputCols(value: Array[String]): this.type = set(outputCols,
value)
+
+ private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = {
+ if (isSet(inputCols) && isSet(inputCol)) {
+ logWarning("Both `inputCol` and `inputCols` are set, we ignore
`inputCols` and this " +
+ "`QuantileDiscretizer` will only map one column specified by
`inputCol`")
+ false
+ } else if (isSet(inputCols)) {
+ true
+ } else {
+ false
+ }
+ }
+
+ private[feature] def getInOutCols: (Array[String], Array[String]) = {
+ if (!isQuantileDiscretizeMultipleColumns) {
+ (Array($(inputCol)), Array($(outputCol)))
+ } else {
+ require($(inputCols).length == $(outputCols).length,
+ "inputCols number do not match outputCols")
+ ($(inputCols), $(outputCols))
+ }
+ }
+
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkNumericType(schema, $(inputCol))
- val inputFields = schema.fields
- require(inputFields.forall(_.name != $(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val attr = NominalAttribute.defaultAttr.withName($(outputCol))
- val outputFields = inputFields :+ attr.toStructField()
+ val (inputColNames, outputColNames) = getInOutCols
+ val existingFields = schema.fields
+ var outputFields = existingFields
+ inputColNames.zip(outputColNames).map { case (inputColName,
outputColName) =>
--- End diff --
`map` can be `foreach` here: the collection that `map` builds is never used, so the closure runs only for its side effects.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]