GitHub user MrBago commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19527#discussion_r148190855
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala ---
    @@ -0,0 +1,456 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.hadoop.fs.Path
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.annotation.Since
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.attribute._
    +import org.apache.spark.ml.linalg.Vectors
    +import org.apache.spark.ml.param._
    +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, 
HasOutputCols}
    +import org.apache.spark.ml.util._
    +import org.apache.spark.sql.{DataFrame, Dataset}
    +import org.apache.spark.sql.expressions.UserDefinedFunction
    +import org.apache.spark.sql.functions.{col, lit, udf}
    +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, 
StructType}
    +
    +/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
    +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
    +    with HasInputCols with HasOutputCols {
    +
    +  /**
    +   * Param for how to handle invalid data.
    +   * Options are 'keep' (invalid data produces a vector of zeros) or 'error' (throw an error).
    +   * Default: "error"
    +   * @group param
    +   */
    +  @Since("2.3.0")
    +  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    +    "How to handle invalid data. " +
    +    "Options are 'keep' (invalid data produces a vector of zeros) or 'error' (throw an error).",
    +    ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
    +
    +  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
    +
    +  /**
    +   * Whether to drop the last category in the encoded vector (default: true)
    +   * @group param
    +   */
    +  @Since("2.3.0")
    +  final val dropLast: BooleanParam =
    +    new BooleanParam(this, "dropLast", "whether to drop the last category")
    +  setDefault(dropLast -> true)
    +
    +  /** @group getParam */
    +  @Since("2.3.0")
    +  def getDropLast: Boolean = $(dropLast)
    +
    +  /**
    +   * Validates the `inputCols` / `outputCols` params against `schema` and returns `schema` with
    +   * one output field appended per input column (in `outputCols` order).
    +   *
    +   * @param schema the input schema
    +   * @return `schema` extended with the output fields built by
    +   *   `OneHotEncoderCommon.transformOutputColumnSchema`
    +   * @throws IllegalArgumentException if the numbers of input and output columns differ, an
    +   *   output column name is repeated, an input column is not numeric, or an output column
    +   *   already exists in `schema`
    +   */
    +  protected def validateAndTransformSchema(schema: StructType): StructType = {
    +    val inputColNames = $(inputCols)
    +    val outputColNames = $(outputCols)
    +    val existingFields = schema.fields
    +
    +    require(inputColNames.length == outputColNames.length,
    +      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
    +        s"output columns ${outputColNames.length}.")
    +    // Repeated output names would otherwise produce a schema with clashing fields.
    +    require(outputColNames.distinct.length == outputColNames.length,
    +      s"Output column names must be unique but got ${outputColNames.mkString("[", ", ", "]")}.")
    +
    +    // Pure validation: nothing is collected, so iterate with foreach rather than map.
    +    inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
    +      val inputType = schema(inputColName).dataType
    +      require(inputType.isInstanceOf[NumericType],
    +        s"Input column $inputColName must be of type NumericType but got $inputType")
    +      require(!existingFields.exists(_.name == outputColName),
    +        s"Output column $outputColName already exists.")
    +    }
    +
    +    // Prepares output columns with proper attributes by examining input columns.
    +    val inputFields = inputColNames.map(schema(_))
    +
    +    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
    +      OneHotEncoderCommon.transformOutputColumnSchema(
    +        inputField, $(dropLast), outputColName)
    +    }
    +    StructType(schema.fields ++ outputFields)
    +  }
    +}
    +
    +/**
    + * A one-hot encoder that maps columns of category indices to columns of binary vectors, with
    + * at most a single one-value per row that indicates the input category index. Each column in
    + * `inputCols` is encoded into the column at the same position in `outputCols`.
    + * For example with 5 categories, an input value of 2.0 would map to an output vector of
    + * `[0.0, 0.0, 1.0, 0.0]`.
    + * The last category is not included by default (configurable via `dropLast`),
    + * because it makes the vector entries sum up to one, and hence linearly dependent.
    + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
    + *
    + * The number of categories of each column is read from its ML attribute metadata when that
    + * information is present; the remaining columns are scanned during `fit`.
    + *
    + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
    + * The output vectors are sparse.
    + *
    + * @see `StringIndexer` for converting categorical values into category indices
    + */
    +@Since("2.3.0")
    +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
    +    extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
    +
    +  @Since("2.3.0")
    +  def this() = this(Identifiable.randomUID("oneHotEncoder"))
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setDropLast(value: Boolean): this.type = set(dropLast, value)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
    +
    +  @Since("2.3.0")
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  /**
    +   * Determines the number of categories of every input column and returns a model holding
    +   * those sizes together with this estimator's param values.
    +   *
    +   * Sizes are first taken from the attribute metadata of the output columns produced by
    +   * `transformSchema`; only columns whose metadata does not carry that number trigger a pass
    +   * over `dataset`.
    +   */
    +  @Since("2.3.0")
    +  override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
    +    val transformedSchema = transformSchema(dataset.schema)
    +
    +    // Sizes known from metadata; a negative size means the attributes don't provide it.
    +    val categorySizes = $(outputCols).map { outputColName =>
    +      AttributeGroup.fromStructField(transformedSchema(outputColName)).size
    +    }
    +    val columnToScanIndices = categorySizes.indices.filter(categorySizes(_) < 0).toArray
    +
    +    // Some input columns don't have attributes or their attributes don't have necessary info.
    +    // We need to scan the data to get the number of values for each such column.
    +    if (columnToScanIndices.nonEmpty) {
    +      val inputColNames = columnToScanIndices.map($(inputCols)(_))
    +      val outputColNames = columnToScanIndices.map($(outputCols)(_))
    +      val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
    +        dataset, $(dropLast), inputColNames, outputColNames)
    +      attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
    +        categorySizes(idx) = attrGroup.size
    +      }
    +    }
    +
    +    val model = new OneHotEncoderModel(uid, categorySizes).setParent(this)
    +    copyValues(model)
    +  }
    +
    +  @Since("2.3.0")
    +  override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra)
    +}
    +
    +@Since("2.3.0")
    +object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] {
    +
    +  /** `handleInvalid` option: invalid data produces a vector of zeros. */
    +  private[feature] val KEEP_INVALID: String = "keep"
    +
    +  /** `handleInvalid` option: throw an error on invalid data (the param's default). */
    +  private[feature] val ERROR_INVALID: String = "error"
    +
    +  /** All values accepted by the `handleInvalid` param validator. */
    +  private[feature] val supportedHandleInvalids: Array[String] =
    +    Array[String](KEEP_INVALID, ERROR_INVALID)
    +
    +  @Since("2.3.0")
    +  override def load(path: String): OneHotEncoderEstimator = {
    +    super.load(path)
    +  }
    +}
    +
    +@Since("2.3.0")
    +class OneHotEncoderModel private[ml] (
    +    @Since("2.3.0") override val uid: String,
    +    @Since("2.3.0") val categorySizes: Array[Int])
    +  extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable 
{
    --- End diff --
    
    I don't have a sense of how common it is, but the use case I had in mind
was applying a OneHotEncoder to a column produced by a StringIndexerModel
(trained on a dictionary, for example). In this case the size of the vectors
produced is determined by the cardinality of the StringIndexerModel (which we
can easily get via StringIndexerModel.labels.size).
    
    That being said, this use case is easily covered by a simple Pipeline that
pairs StringIndexer and OneHotEncoderEstimator at training time, so I'm not
sure there is an issue.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to