Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/6466#discussion_r31299673
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala ---
    @@ -17,94 +17,152 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.SparkException
     import org.apache.spark.annotation.Experimental
    -import org.apache.spark.ml.UnaryTransformer
    -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, 
NominalAttribute}
    +import org.apache.spark.ml.Transformer
    +import org.apache.spark.ml.attribute._
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
     import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
    -import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
    +import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{DoubleType, StructType}
     
     /**
      * :: Experimental ::
    - * A one-hot encoder that maps a column of label indices to a column of 
binary vectors, with
    - * at most a single one-value. By default, the binary vector has an 
element for each category, so
    - * with 5 categories, an input value of 2.0 would map to an output vector 
of
    - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first 
category is omitted, so the
    - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) 
and an input value
    - * of 0.0 would map to a vector of all zeros. Including the first category 
makes the vector columns
    - * linearly dependent because they sum up to one.
    + * A one-hot encoder that maps a column of category indices to a column of 
binary vectors, with
    + * at most a single one-value per row that indicates the input category 
index.
    + * For example with 5 categories, an input value of 2.0 would map to an 
output vector of
    + * `[0.0, 0.0, 1.0, 0.0]`.
    + * The last category is not included by default (configurable via 
[[OneHotEncoder!.dropLast]])
    + * because including it makes the vector entries sum up to one, and hence linearly 
dependent.
    + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
    + * Note that this is different from scikit-learn's OneHotEncoder, which 
keeps all categories.
    + * The output vectors are sparse.
    + *
    + * @see [[StringIndexer]] for converting categorical values into category 
indices
      */
     @Experimental
    -class OneHotEncoder(override val uid: String)
    -  extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol 
with HasOutputCol {
    +class OneHotEncoder(override val uid: String) extends Transformer
    +  with HasInputCol with HasOutputCol {
     
       def this() = this(Identifiable.randomUID("oneHot"))
     
       /**
    -   * Whether to include a component in the encoded vectors for the first 
category, defaults to true.
    +   * Whether to drop the last category in the encoded vector (default: 
true)
        * @group param
        */
    -  final val includeFirst: BooleanParam =
    -    new BooleanParam(this, "includeFirst", "include first category")
    -  setDefault(includeFirst -> true)
    -
    -  private var categories: Array[String] = _
    +  final val dropLast: BooleanParam =
    +    new BooleanParam(this, "dropLast", "whether to drop the last category")
    +  setDefault(dropLast -> true)
     
       /** @group setParam */
    -  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
    +  def setDropLast(value: Boolean): this.type = set(dropLast, value)
     
       /** @group setParam */
    -  override def setInputCol(value: String): this.type = set(inputCol, value)
    +  def setInputCol(value: String): this.type = set(inputCol, value)
     
       /** @group setParam */
    -  override def setOutputCol(value: String): this.type = set(outputCol, 
value)
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
     
       override def transformSchema(schema: StructType): StructType = {
    -    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    -    val inputFields = schema.fields
    +    val is = "_is_"
    +    val inputColName = $(inputCol)
         val outputColName = $(outputCol)
    -    require(inputFields.forall(_.name != $(outputCol)),
    -      s"Output column ${$(outputCol)} already exists.")
     
    -    val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
    -    categories = inputColAttr match {
    +    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
    +    val inputFields = schema.fields
    +    require(!inputFields.exists(_.name == outputColName),
    +      s"Output column $outputColName already exists.")
    +
    +    val inputAttr = Attribute.fromStructField(schema(inputColName))
    +    val outputAttrNames: Option[Array[String]] = inputAttr match {
           case nominal: NominalAttribute =>
    -        nominal.values.getOrElse((0 until 
nominal.numValues.get).map(_.toString).toArray)
    -      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", 
"1"))
    +        if (nominal.values.isDefined) {
    +          nominal.values.map(_.map(v => inputColName + is + v))
    +        } else if (nominal.numValues.isDefined) {
    +          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + 
is + i))
    +        } else {
    +          None
    +        }
    +      case binary: BinaryAttribute =>
    +        if (binary.values.isDefined) {
    +          binary.values.map(_.map(v => inputColName + is + v))
    +        } else {
    +          Some(Array.tabulate(2)(i => inputColName + is + i))
    +        }
    +      case _: NumericAttribute =>
    +        throw new RuntimeException(
    +          s"The input column $inputColName cannot be numeric.")
           case _ =>
    -        throw new SparkException(s"OneHotEncoder input column 
${$(inputCol)} is not nominal")
    +        None // optimistic about unknown attributes
         }
     
    -    val attrValues = (if ($(includeFirst)) categories else 
categories.drop(1)).toArray
    -    val attr = 
NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
    -    val outputFields = inputFields :+ attr.toStructField()
    +    val filteredOutputAttrNames = outputAttrNames.map { names =>
    +      if ($(dropLast)) {
    +        require(names.length > 1,
    +          s"The input column $inputColName should have at least two 
distinct values.")
    +        names.dropRight(1)
    +      } else {
    +        names
    +      }
    +    }
    +
    +    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
    +      val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name 
=>
    +        BinaryAttribute.defaultAttr.withName(name)
    +      }
    +      new AttributeGroup($(outputCol), attrs)
    +    } else {
    +      new AttributeGroup($(outputCol))
    +    }
    +
    +    val outputFields = inputFields :+ outputAttrGroup.toStructField()
         StructType(outputFields)
       }
     
    -  protected override def createTransformFunc(): (Double) => Vector = {
    -    val first = $(includeFirst)
    -    val vecLen = if (first) categories.length else categories.length - 1
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    // schema transformation
    +    val is = "_is_"
    +    val inputColName = $(inputCol)
    +    val outputColName = $(outputCol)
    +    val shouldDropLast = $(dropLast)
    +    var outputAttrGroup = AttributeGroup.fromStructField(
    +      transformSchema(dataset.schema)(outputColName))
    +    if (outputAttrGroup.size < 0) {
    +      // If the number of attributes is unknown, we check the values from 
the input column.
    +      val numAttrs = 
dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0))
    +        .aggregate(0.0)(
    --- End diff --
    
    Relatedly, I'm not sure how many columns Spark SQL has been tested with.  
Keeping data in Vectors might be necessary for many ML datasets.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to