Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/20146#discussion_r239992360 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala --- @@ -79,26 +81,53 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi @Since("2.3.0") def getStringOrderType: String = $(stringOrderType) - /** Validates and transforms the input schema. */ - protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputColName = $(inputCol) + /** Returns the input and output column names corresponding in pair. */ + private[feature] def getInOutCols(): (Array[String], Array[String]) = { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols)) + + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "The number of input columns does not match output columns") + ($(inputCols), $(outputCols)) + } + } + + private def validateAndTransformField( + schema: StructType, + inputColName: String, + outputColName: String): StructField = { val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], s"The input column $inputColName must be either string type or numeric type, " + s"but got $inputDataType.") - val inputFields = schema.fields - val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), + require(schema.fields.forall(_.name != outputColName), --- End diff -- Sounds good to me. Added.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org