Repository: spark
Updated Branches:
  refs/heads/master 5955a2d0f -> 994065d89


[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator

## What changes were proposed in this pull request?

This patch adds a new class `OneHotEncoderEstimator` that extends `Estimator`.
Its `fit` method returns a `OneHotEncoderModel`.

Methods shared by the existing `OneHotEncoder` and the new
`OneHotEncoderEstimator`, such as schema transformation, are extracted into
`OneHotEncoderCommon` to reduce code duplication.
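
As an illustration, a minimal usage sketch (the DataFrame `df` and its
`categoryIndex` column are assumed, e.g. produced by `StringIndexer`; the
setters and `fit`/`transform` calls match the new API in this patch):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator

// fit() returns a OneHotEncoderModel holding the learned category sizes.
val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryVec"))
val model = encoder.fit(df)

// transform() appends the encoded sparse-vector column "categoryVec".
val encoded = model.transform(df)
```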

### Multi-column support

`OneHotEncoderEstimator` adds simpler multi-column support because, as a new
API, it is free from backward-compatibility constraints.
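
For example, a sketch of encoding two columns in one pass (`df` and the
column names are illustrative):

```scala
// Input and output columns are paired by position in the arrays,
// and each pair is encoded independently.
val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("colorIndex", "sizeIndex"))
  .setOutputCols(Array("colorVec", "sizeVec"))
val encoded = encoder.fit(df).transform(df)
```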

### handleInvalid Param support

`OneHotEncoderEstimator` supports the `handleInvalid` Param with two options:
`error` and `keep`.
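
A sketch of the `keep` behavior (data and column names are illustrative, and
`spark.implicits._` is assumed in scope): with three categories seen during
fitting and `dropLast = false`, an unseen index maps to an extra last category.

```scala
val training = Seq(0.0, 1.0, 2.0).map(Tuple1.apply).toDF("input")
val test = Seq(Tuple1(3.0)).toDF("input") // 3.0 was not seen while fitting

val model = new OneHotEncoderEstimator()
  .setInputCols(Array("input"))
  .setOutputCols(Array("output"))
  .setHandleInvalid("keep") // the default, "error", would throw on 3.0
  .setDropLast(false)
  .fit(training)

// 3.0 is encoded as the extra "invalid" category: (4,[3],[1.0])
model.transform(test).show()
```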

## How was this patch tested?

Added new test suite `OneHotEncoderEstimatorSuite`.

Author: Liang-Chi Hsieh <[email protected]>

Closes #19527 from viirya/SPARK-13030.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/994065d8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/994065d8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/994065d8

Branch: refs/heads/master
Commit: 994065d891a23ed89a09b3f95bc3f1f986793e0d
Parents: 5955a2d
Author: Liang-Chi Hsieh <[email protected]>
Authored: Sun Dec 31 15:28:59 2017 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Sun Dec 31 15:28:59 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/OneHotEncoder.scala |  83 +--
 .../ml/feature/OneHotEncoderEstimator.scala     | 522 +++++++++++++++++++
 .../feature/OneHotEncoderEstimatorSuite.scala   | 421 +++++++++++++++
 3 files changed, 960 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/994065d8/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index a669da1..5ab6c2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
  * The output vectors are sparse.
  *
  * @see `StringIndexer` for converting categorical values into category indices
+ * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
+ * will be removed in 3.0.0.
  */
 @Since("1.4.0")
+@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" +
+  " will be removed in 3.0.0.", "2.3.0")
 class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer
   with HasInputCol with HasOutputCol with DefaultParamsWritable {
 
@@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
   override def transformSchema(schema: StructType): StructType = {
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
+    val inputFields = schema.fields
 
     require(schema(inputColName).dataType.isInstanceOf[NumericType],
       s"Input column must be of type NumericType but got 
${schema(inputColName).dataType}")
-    val inputFields = schema.fields
     require(!inputFields.exists(_.name == outputColName),
       s"Output column $outputColName already exists.")
 
-    val inputAttr = Attribute.fromStructField(schema(inputColName))
-    val outputAttrNames: Option[Array[String]] = inputAttr match {
-      case nominal: NominalAttribute =>
-        if (nominal.values.isDefined) {
-          nominal.values
-        } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
-        } else {
-          None
-        }
-      case binary: BinaryAttribute =>
-        if (binary.values.isDefined) {
-          binary.values
-        } else {
-          Some(Array.tabulate(2)(_.toString))
-        }
-      case _: NumericAttribute =>
-        throw new RuntimeException(
-          s"The input column $inputColName cannot be numeric.")
-      case _ =>
-        None // optimistic about unknown attributes
-    }
-
-    val filteredOutputAttrNames = outputAttrNames.map { names =>
-      if ($(dropLast)) {
-        require(names.length > 1,
-          s"The input column $inputColName should have at least two distinct 
values.")
-        names.dropRight(1)
-      } else {
-        names
-      }
-    }
-
-    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
-      val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
-        BinaryAttribute.defaultAttr.withName(name)
-      }
-      new AttributeGroup($(outputCol), attrs)
-    } else {
-      new AttributeGroup($(outputCol))
-    }
-
-    val outputFields = inputFields :+ outputAttrGroup.toStructField()
+    val outputField = OneHotEncoderCommon.transformOutputColumnSchema(
+      schema(inputColName), outputColName, $(dropLast))
+    val outputFields = inputFields :+ outputField
     StructType(outputFields)
   }
 
@@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
     // schema transformation
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
-    val shouldDropLast = $(dropLast)
-    var outputAttrGroup = AttributeGroup.fromStructField(
+
+    val outputAttrGroupFromSchema = AttributeGroup.fromStructField(
       transformSchema(dataset.schema)(outputColName))
-    if (outputAttrGroup.size < 0) {
-      // If the number of attributes is unknown, we check the values from the input column.
-      val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
-        .treeAggregate(0.0)(
-          (m, x) => {
-            assert(x <= Int.MaxValue,
-              s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but 
got $x")
-            assert(x >= 0.0 && x == x.toInt,
-              s"Values from column $inputColName must be indices, but got $x.")
-            math.max(m, x)
-          },
-          (m0, m1) => {
-            math.max(m0, m1)
-          }
-        ).toInt + 1
-      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
-      val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
-      val outputAttrs: Array[Attribute] =
-        filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
-      outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
+
+    val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
+      OneHotEncoderCommon.getOutputAttrGroupFromData(
+        dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0)
+    } else {
+      outputAttrGroupFromSchema
     }
+
     val metadata = outputAttrGroup.toMetadata()
 
     // data transformation

http://git-wip-us.apache.org/repos/asf/spark/blob/994065d8/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
new file mode 100644
index 0000000..074622d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
@@ -0,0 +1,522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.functions.{col, lit, udf}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
+
+/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
+private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
+    with HasInputCols with HasOutputCols {
+
+  /**
+   * Param for how to handle invalid data.
+   * Options are 'keep' (invalid data presented as an extra categorical feature) or
+   * 'error' (throw an error).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data " +
+    "Options are 'keep' (invalid data presented as an extra categorical 
feature) " +
+    "or error (throw an error).",
+    ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
+
+  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
+
+  /**
+   * Whether to drop the last category in the encoded vector (default: true)
+   * @group param
+   */
+  @Since("2.3.0")
+  final val dropLast: BooleanParam =
+    new BooleanParam(this, "dropLast", "whether to drop the last category")
+  setDefault(dropLast -> true)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getDropLast: Boolean = $(dropLast)
+
+  protected def validateAndTransformSchema(
+      schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+    val existingFields = schema.fields
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same 
as the number of " +
+        s"output columns ${outputColNames.length}.")
+
+    // Input columns must be NumericType.
+    inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
+
+    // Prepares output columns with proper attributes by examining input columns.
+    val inputFields = $(inputCols).map(schema(_))
+
+    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
+      OneHotEncoderCommon.transformOutputColumnSchema(
+        inputField, outputColName, dropLast, keepInvalid)
+    }
+    outputFields.foldLeft(schema) { case (newSchema, outputField) =>
+      SchemaUtils.appendColumn(newSchema, outputField)
+    }
+  }
+}
+
+/**
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via `dropLast`),
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ *
+ * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
+ * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
+ * vector.
+ *
+ * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
+ * come in pairs, specified by the order in the arrays, and each pair is treated independently.
+ *
+ * @see `StringIndexer` for converting categorical values into category indices
+ */
+@Since("2.3.0")
+class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+    extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("oneHotEncoder"))
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setDropLast(value: Boolean): this.type = set(dropLast, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
+    validateAndTransformSchema(schema, dropLast = $(dropLast),
+      keepInvalid = keepInvalid)
+  }
+
+  @Since("2.3.0")
+  override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
+    transformSchema(dataset.schema)
+
+    // Compute the plain number of categories without `handleInvalid` and
+    // `dropLast` taken into account.
+    val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false,
+      keepInvalid = false)
+    val categorySizes = new Array[Int]($(outputCols).length)
+
+    val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) =>
+      val numOfAttrs = AttributeGroup.fromStructField(
+        transformedSchema(outputColName)).size
+      if (numOfAttrs < 0) {
+        Some(idx)
+      } else {
+        categorySizes(idx) = numOfAttrs
+        None
+      }
+    }
+
+    // Some input columns don't have attributes or their attributes don't have necessary info.
+    // We need to scan the data to get the number of values for each column.
+    if (columnToScanIndices.length > 0) {
+      val inputColNames = columnToScanIndices.map($(inputCols)(_))
+      val outputColNames = columnToScanIndices.map($(outputCols)(_))
+
+      // When fitting data, we want the plain number of categories without `handleInvalid` and
+      // `dropLast` taken into account.
+      val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
+        dataset, inputColNames, outputColNames, dropLast = false)
+      attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
+        categorySizes(idx) = attrGroup.size
+      }
+    }
+
+    val model = new OneHotEncoderModel(uid, categorySizes).setParent(this)
+    copyValues(model)
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra)
+}
+
+@Since("2.3.0")
+object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] {
+
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
+
+  @Since("2.3.0")
+  override def load(path: String): OneHotEncoderEstimator = super.load(path)
+}
+
+@Since("2.3.0")
+class OneHotEncoderModel private[ml] (
+    @Since("2.3.0") override val uid: String,
+    @Since("2.3.0") val categorySizes: Array[Int])
+  extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable {
+
+  import OneHotEncoderModel._
+
+  // Returns the category size for a given index with `dropLast` and `handleInvalid`
+  // taken into account.
+  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+    val dropLast = getDropLast
+    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+
+    if (!dropLast && keepInvalid) {
+      // When `handleInvalid` is "keep", an extra category is added as last category
+      // for invalid data.
+      orgCategorySize + 1
+    } else if (dropLast && !keepInvalid) {
+      // When `dropLast` is true, the last category is removed.
+      orgCategorySize - 1
+    } else {
+      // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
+      // data is removed. Thus, it is the same as the plain number of categories.
+      orgCategorySize
+    }
+  }
+
+  private def encoder: UserDefinedFunction = {
+    val oneValue = Array(1.0)
+    val emptyValues = Array.empty[Double]
+    val emptyIndices = Array.empty[Int]
+    val dropLast = getDropLast
+    val handleInvalid = getHandleInvalid
+    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+
+    // The udf performed on input data. The first parameter is the input value. The second
+    // parameter is the index of input.
+    udf { (label: Double, idx: Int) =>
+      val plainNumCategories = categorySizes(idx)
+      val size = configedCategorySize(plainNumCategories, idx)
+
+      if (label < 0) {
+        throw new SparkException(s"Negative value: $label. Input can't be 
negative.")
+      } else if (label == size && dropLast && !keepInvalid) {
+        // When `dropLast` is true and `handleInvalid` is not "keep",
+        // the last category is removed.
+        Vectors.sparse(size, emptyIndices, emptyValues)
+      } else if (label >= plainNumCategories && keepInvalid) {
+        // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
+        // if `dropLast` is true)
+        if (dropLast) {
+          Vectors.sparse(size, emptyIndices, emptyValues)
+        } else {
+          Vectors.sparse(size, Array(size - 1), oneValue)
+        }
+      } else if (label < plainNumCategories) {
+        Vectors.sparse(size, Array(label.toInt), oneValue)
+      } else {
+        assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
+        throw new SparkException(s"Unseen value: $label. To handle unseen 
values, " +
+          s"set Param handleInvalid to 
${OneHotEncoderEstimator.KEEP_INVALID}.")
+      }
+    }
+  }
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setDropLast(value: Boolean): this.type = set(dropLast, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+
+    require(inputColNames.length == categorySizes.length,
+      s"The number of input columns ${inputColNames.length} must be the same 
as the number of " +
+        s"features ${categorySizes.length} during fitting.")
+
+    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
+    val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast),
+      keepInvalid = keepInvalid)
+    verifyNumOfValues(transformedSchema)
+  }
+
+  /**
+   * If the metadata of input columns also specifies the number of categories, we need to
+   * compare with expected category number with `handleInvalid` and `dropLast` taken into
+   * account. Mismatched numbers will cause exception.
+   */
+  private def verifyNumOfValues(schema: StructType): StructType = {
+    $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
+      val inputColName = $(inputCols)(idx)
+      val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
+
+      // If the input metadata specifies number of category for output column,
+      // comparing with expected category number with `handleInvalid` and
+      // `dropLast` taken into account.
+      if (attrGroup.attributes.nonEmpty) {
+        val numCategories = configedCategorySize(categorySizes(idx), idx)
+        require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
+          s"$numCategories categorical values for input column ${inputColName}, " +
+            s"but the input column had metadata specifying ${attrGroup.size} values.")
+      }
+    }
+    schema
+  }
+
+  @Since("2.3.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val transformedSchema = transformSchema(dataset.schema, logging = true)
+    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
+
+    val encodedColumns = (0 until $(inputCols).length).map { idx =>
+      val inputColName = $(inputCols)(idx)
+      val outputColName = $(outputCols)(idx)
+
+      val outputAttrGroupFromSchema =
+        AttributeGroup.fromStructField(transformedSchema(outputColName))
+
+      val metadata = if (outputAttrGroupFromSchema.size < 0) {
+        OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName,
+          categorySizes(idx), $(dropLast), keepInvalid).toMetadata()
+      } else {
+        outputAttrGroupFromSchema.toMetadata()
+      }
+
+      encoder(col(inputColName).cast(DoubleType), lit(idx))
+        .as(outputColName, metadata)
+    }
+    dataset.withColumns($(outputCols), encodedColumns)
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): OneHotEncoderModel = {
+    val copied = new OneHotEncoderModel(uid, categorySizes)
+    copyValues(copied, extra).setParent(parent)
+  }
+
+  @Since("2.3.0")
+  override def write: MLWriter = new OneHotEncoderModelWriter(this)
+}
+
+@Since("2.3.0")
+object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
+
+  private[OneHotEncoderModel]
+  class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter {
+
+    private case class Data(categorySizes: Array[Int])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.categorySizes)
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] {
+
+    private val className = classOf[OneHotEncoderModel].getName
+
+    override def load(path: String): OneHotEncoderModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath)
+        .select("categorySizes")
+        .head()
+      val categorySizes = data.getAs[Seq[Int]](0).toArray
+      val model = new OneHotEncoderModel(metadata.uid, categorySizes)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("2.3.0")
+  override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader
+
+  @Since("2.3.0")
+  override def load(path: String): OneHotEncoderModel = super.load(path)
+}
+
+/**
+ * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`.
+ */
+private[feature] object OneHotEncoderCommon {
+
+  private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = {
+    val inputAttr = Attribute.fromStructField(inputCol)
+    inputAttr match {
+      case nominal: NominalAttribute =>
+        if (nominal.values.isDefined) {
+          nominal.values
+        } else if (nominal.numValues.isDefined) {
+          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
+        } else {
+          None
+        }
+      case binary: BinaryAttribute =>
+        if (binary.values.isDefined) {
+          binary.values
+        } else {
+          Some(Array.tabulate(2)(_.toString))
+        }
+      case _: NumericAttribute =>
+        throw new RuntimeException(
+          s"The input column ${inputCol.name} cannot be continuous-value.")
+      case _ =>
+        None // optimistic about unknown attributes
+    }
+  }
+
+  /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */
+  private def genOutputAttrGroup(
+      outputAttrNames: Option[Array[String]],
+      outputColName: String): AttributeGroup = {
+    outputAttrNames.map { attrNames =>
+      val attrs: Array[Attribute] = attrNames.map { name =>
+        BinaryAttribute.defaultAttr.withName(name)
+      }
+      new AttributeGroup(outputColName, attrs)
+    }.getOrElse{
+      new AttributeGroup(outputColName)
+    }
+  }
+
+  /**
+   * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column.
+   */
+  def transformOutputColumnSchema(
+      inputCol: StructField,
+      outputColName: String,
+      dropLast: Boolean,
+      keepInvalid: Boolean = false): StructField = {
+    val outputAttrNames = genOutputAttrNames(inputCol)
+    val filteredOutputAttrNames = outputAttrNames.map { names =>
+      if (dropLast && !keepInvalid) {
+        require(names.length > 1,
+          s"The input column ${inputCol.name} should have at least two 
distinct values.")
+        names.dropRight(1)
+      } else if (!dropLast && keepInvalid) {
+        names ++ Seq("invalidValues")
+      } else {
+        names
+      }
+    }
+
+    genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField()
+  }
+
+  /**
+   * This method is called when we want to generate `AttributeGroup` from actual data for
+   * one-hot encoder.
+   */
+  def getOutputAttrGroupFromData(
+      dataset: Dataset[_],
+      inputColNames: Seq[String],
+      outputColNames: Seq[String],
+      dropLast: Boolean): Seq[AttributeGroup] = {
+    // The RDD approach has advantage of early-stop if any values are invalid. It seems that
+    // DataFrame ops don't have equivalent functions.
+    val columns = inputColNames.map { inputColName =>
+      col(inputColName).cast(DoubleType)
+    }
+    val numOfColumns = columns.length
+
+    val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
+      (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
+    }.treeAggregate(new Array[Double](numOfColumns))(
+      (maxValues, curValues) => {
+        (0 until numOfColumns).foreach { idx =>
+          val x = curValues(idx)
+          assert(x <= Int.MaxValue,
+            s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but 
got $x.")
+          assert(x >= 0.0 && x == x.toInt,
+            s"Values from column ${inputColNames(idx)} must be indices, but 
got $x.")
+          maxValues(idx) = math.max(maxValues(idx), x)
+        }
+        maxValues
+      },
+      (m0, m1) => {
+        (0 until numOfColumns).foreach { idx =>
+          m0(idx) = math.max(m0(idx), m1(idx))
+        }
+        m0
+      }
+    ).map(_.toInt + 1)
+
+    outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) =>
+      createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false)
+    }
+  }
+
+  /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */
+  def createAttrGroupForAttrNames(
+      outputColName: String,
+      numAttrs: Int,
+      dropLast: Boolean,
+      keepInvalid: Boolean): AttributeGroup = {
+    val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
+    val filtered = if (dropLast && !keepInvalid) {
+      outputAttrNames.dropRight(1)
+    } else if (!dropLast && keepInvalid) {
+      outputAttrNames ++ Seq("invalidValues")
+    } else {
+      outputAttrNames
+    }
+    genOutputAttrGroup(Some(filtered), outputColName)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/994065d8/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
new file mode 100644
index 0000000..1d3f845
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
+
+class OneHotEncoderEstimatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("params") {
+    ParamsSuite.checkParams(new OneHotEncoderEstimator)
+  }
+
+  test("OneHotEncoderEstimator dropLast = false") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+    assert(encoder.getDropLast === true)
+    encoder.setDropLast(false)
+    assert(encoder.getDropLast === false)
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("OneHotEncoderEstimator dropLast = true") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(2, Seq())),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(2, Seq())))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("input column with ML attribute") {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", attr.toMetadata()))
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("size"))
+      .setOutputCols(Array("encoded"))
+    val model = encoder.fit(df)
+    val output = model.transform(df)
+    val group = AttributeGroup.fromStructField(output.schema("encoded"))
+    assert(group.size === 2)
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+  }
+
+  test("input column without ML attribute") {
+    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("index"))
+      .setOutputCols(Array("encoded"))
+    val model = encoder.fit(df)
+    val output = model.transform(df)
+    val group = AttributeGroup.fromStructField(output.schema("encoded"))
+    assert(group.size === 2)
+    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+  }
+
+  test("read/write") {
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("index"))
+      .setOutputCols(Array("encoded"))
+    testDefaultReadWrite(encoder)
+  }
+
+  test("OneHotEncoderModel read/write") {
+    val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3))
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.categorySizes === instance.categorySizes)
+  }
+
+  test("OneHotEncoderEstimator with varying types") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val dfWithTypes = df
+      .withColumn("shortInput", df("input").cast(ShortType))
+      .withColumn("longInput", df("input").cast(LongType))
+      .withColumn("intInput", df("input").cast(IntegerType))
+      .withColumn("floatInput", df("input").cast(FloatType))
+      .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
+
+    val cols = Array("input", "shortInput", "longInput", "intInput",
+      "floatInput", "decimalInput")
+    for (col <- cols) {
+      val encoder = new OneHotEncoderEstimator()
+        .setInputCols(Array(col))
+        .setOutputCols(Array("output"))
+        .setDropLast(false)
+
+      val model = encoder.fit(dfWithTypes)
+      val encoded = model.transform(dfWithTypes)
+
+      encoded.select("output", "expected").rdd.map { r =>
+        (r.getAs[Vector](0), r.getAs[Vector](1))
+      }.collect().foreach { case (vec1, vec2) =>
+        assert(vec1 === vec2)
+      }
+    }
+  }
+
+  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = 
false") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input1", DoubleType),
+        StructField("expected1", new VectorUDT),
+        StructField("input2", DoubleType),
+        StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+    assert(encoder.getDropLast === true)
+    encoder.setDropLast(false)
+    assert(encoder.getDropLast === false)
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r 
=>
+      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), 
r.getAs[Vector](3))
+    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
+      assert(vec1 === vec2)
+      assert(vec3 === vec4)
+    }
+  }
+
+  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = 
true") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())),
+      Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input1", DoubleType),
+        StructField("expected1", new VectorUDT),
+        StructField("input2", DoubleType),
+        StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r 
=>
+      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), 
r.getAs[Vector](3))
+    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
+      assert(vec1 === vec2)
+      assert(vec3 === vec4)
+    }
+  }
+
+  test("Throw error on invalid values") {
+    val trainingData = Seq((0, 0), (1, 1), (2, 2))
+    val trainingDF = trainingData.toDF("id", "a")
+    val testData = Seq((0, 0), (1, 2), (1, 3))
+    val testDF = testData.toDF("id", "a")
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("a"))
+      .setOutputCols(Array("encoded"))
+
+    val model = encoder.fit(trainingDF)
+    val err = intercept[SparkException] {
+      model.transform(testDF).show
+    }
+    err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+  }
+
+  test("Can't transform on negative input") {
+    val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b")
+    val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b")
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("a"))
+      .setOutputCols(Array("encoded"))
+
+    val model = encoder.fit(trainingDF)
+    val err = intercept[SparkException] {
+      model.transform(testDF).collect()
+    }
+    err.getMessage.contains("Negative value: -1.0. Input can't be negative")
+  }
+
+  test("Keep on invalid values: dropLast = false") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+      .setHandleInvalid("keep")
+      .setDropLast(false)
+
+    val model = encoder.fit(trainingDF)
+    val encoded = model.transform(testDF)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("Keep on invalid values: dropLast = true") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(3, Seq())))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+      .setHandleInvalid("keep")
+      .setDropLast(true)
+
+    val model = encoder.fit(trainingDF)
+    val encoded = model.transform(testDF)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("OneHotEncoderModel changes dropLast") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected1", new VectorUDT),
+        StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+
+    val model = encoder.fit(df)
+
+    model.setDropLast(false)
+    val encoded1 = model.transform(df)
+    encoded1.select("output", "expected1").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+
+    model.setDropLast(true)
+    val encoded2 = model.transform(df)
+    encoded2.select("output", "expected2").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("OneHotEncoderModel changes handleInvalid") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
+
+    val schema = StructType(Array(
+        StructField("input", DoubleType),
+        StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+
+    val model = encoder.fit(trainingDF)
+    model.setHandleInvalid("error")
+
+    val err = intercept[SparkException] {
+      model.transform(testDF).collect()
+    }
+    err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+
+    model.setHandleInvalid("keep")
+    model.transform(testDF).collect()
+  }
+
+  test("Transforming on mismatched attributes") {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", attr.toMetadata()))
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("size"))
+      .setOutputCols(Array("encoded"))
+    val model = encoder.fit(df)
+
+    val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
+    val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", testAttr.toMetadata()))
+    val err = intercept[Exception] {
+      model.transform(testDF).collect()
+    }
+    err.getMessage.contains("OneHotEncoderModel expected 2 categorical values")
+  }
+}

