Github user dbtsai commented on a diff in the pull request: https://github.com/apache/spark/pull/20146#discussion_r183253904 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala --- @@ -79,26 +80,56 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi @Since("2.3.0") def getStringOrderType: String = $(stringOrderType) - /** Validates and transforms the input schema. */ - protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputColName = $(inputCol) + /** Returns the input and output column names corresponding in pair. */ + private[feature] def getInOutCols(): (Array[String], Array[String]) = { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols)) + + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "The number of input columns does not match output columns") + ($(inputCols), $(outputCols)) + } + } + + private def validateAndTransformField( + schema: StructType, + inputColName: String, + outputColName: String): StructField = { val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], s"The input column $inputColName must be either string type or numeric type, " + s"but got $inputDataType.") - val inputFields = schema.fields - val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), + require(schema.fields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() - StructType(outputFields) + NominalAttribute.defaultAttr.withName($(outputCol)).toStructField() + } + + /** Validates and transforms the input schema. 
 */ + protected def validateAndTransformSchema( + schema: StructType, + skipNonExistsCol: Boolean = false): StructType = { + val (inputColNames, outputColNames) = getInOutCols() + + val outputFields = for (i <- 0 until inputColNames.length) yield { --- End diff -- Nit, why not the following for readability? ```scala val outputFields = inputColNames.zip(outputColNames).flatMap { case (inputColName, outputColName) => schema.fieldNames.contains(inputColName) match { case true => Some(validateAndTransformField(schema, inputColName, outputColName)) case false if skipNonExistsCol => None case false => throw new SparkException(s"Input column $inputColName does not exist.") } } ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org