Github user huaxingao commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19715#discussion_r150450222
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---
    @@ -129,34 +152,95 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
       @Since("2.1.0")
       def setHandleInvalid(value: String): this.type = set(handleInvalid, 
value)
     
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setNumBucketsArray(value: Array[Int]): this.type = 
set(numBucketsArray, value)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, 
value)
    +
    +  private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = {
    +    if (isSet(inputCols) && isSet(inputCol)) {
    +      logWarning("Both `inputCol` and `inputCols` are set, we ignore 
`inputCols` and this " +
    +        "`QuantileDiscretize` only map one column specified by `inputCol`")
    +      false
    +    } else if (isSet(inputCols)) {
    +      true
    +    } else {
    +      false
    +    }
    +  }
    +
    +  private[feature] def getInOutCols: (Array[String], Array[String]) = {
    +    if (!isQuantileDiscretizeMultipleColumns) {
    +      (Array($(inputCol)), Array($(outputCol)))
    +    } else {
    +      require($(inputCols).length == $(outputCols).length,
    +        "inputCols number do not match outputCols")
    +      ($(inputCols), $(outputCols))
    +    }
    +  }
    +
       @Since("1.6.0")
       override def transformSchema(schema: StructType): StructType = {
    -    SchemaUtils.checkNumericType(schema, $(inputCol))
    -    val inputFields = schema.fields
    -    require(inputFields.forall(_.name != $(outputCol)),
    -      s"Output column ${$(outputCol)} already exists.")
    -    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
    -    val outputFields = inputFields :+ attr.toStructField()
    +    val (inputColNames, outputColNames) = getInOutCols
    +    val existingFields = schema.fields
    +    var outputFields = existingFields
    +    inputColNames.zip(outputColNames).map { case (inputColName, 
outputColName) =>
    +      SchemaUtils.checkNumericType(schema, inputColName)
    +      require(existingFields.forall(_.name != outputColName),
    +        s"Output column ${outputColName} already exists.")
    +      val attr = NominalAttribute.defaultAttr.withName(outputColName)
    +      outputFields :+= attr.toStructField()
    +    }
         StructType(outputFields)
       }
     
       @Since("2.0.0")
       override def fit(dataset: Dataset[_]): Bucketizer = {
         transformSchema(dataset.schema, logging = true)
    -    val splits = dataset.stat.approxQuantile($(inputCol),
    -      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
    +    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
    +    if (isQuantileDiscretizeMultipleColumns) {
    +      var bucketArray = Array.empty[Int]
    +      if (isSet(numBucketsArray)) {
    +        bucketArray = $(numBucketsArray)
    +      }
    +      else {
    +        bucketArray = Array($(numBuckets))
    +      }
    +      val probabilityArray = bucketArray.toSeq.flatMap { numOfBucket =>
    +        (0.0 to 1.0 by 1.0 / numOfBucket)
    +      }
    +      val splitsArray = dataset.stat.approxQuantile($(inputCols),
    +        probabilityArray.sorted.toArray.distinct, $(relativeError))
    +      val distinctSplitsArray = splitsArray.toSeq.map { splits =>
    +        getDistinctSplits(splits)
    +      }
    +      bucketizer.setSplitsArray(distinctSplitsArray.toArray)
    +      copyValues(bucketizer.setParent(this))
    +    }
    +    else {
    --- End diff --
    
    Will fix this, and fix the same problem in the other place as well.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to