srowen commented on a change in pull request #20146: [SPARK-11215][ML] Add 
multiple columns support to StringIndexer
URL: https://github.com/apache/spark/pull/20146#discussion_r245674258
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
 ##########
 @@ -79,26 +86,56 @@ private[feature] trait StringIndexerBase extends Params 
with HasHandleInvalid wi
   @Since("2.3.0")
   def getStringOrderType: String = $(stringOrderType)
 
-  /** Validates and transforms the input schema. */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
+  /** Returns the input and output column names corresponding in pair. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), 
Seq(outputCols))
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "The number of input columns does not match output columns")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
+  private def validateAndTransformField(
+      schema: StructType,
+      inputColName: String,
+      outputColName: String): StructField = {
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || 
inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be either string type or numeric 
type, " +
         s"but got $inputDataType.")
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
+    require(schema.fields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    NominalAttribute.defaultAttr.withName($(outputCol)).toStructField()
+  }
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols()
+
+    require(outputColNames.distinct.length == outputColNames.length,
+      s"Output columns should not be duplicate.")
+
+    val outputFields = inputColNames.zip(outputColNames).flatMap {
+      case (inputColName, outputColName) =>
+        schema.fieldNames.contains(inputColName) match {
+          case true => Some(validateAndTransformField(schema, inputColName, 
outputColName))
+          case false if skipNonExistsCol => None
+          case _ => throw new SparkException(s"Input column $inputColName does 
not exist.")
 
 Review comment:
   While I prefer two if statements to a case match here, I think this is OK.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to