Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19527#discussion_r159025626
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala ---
@@ -0,0 +1,519 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.functions.{col, lit, udf}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
+
+/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
+private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
+    with HasInputCols with HasOutputCols {
+
+  /**
+   * Param for how to handle invalid data.
+   * Options are 'keep' (invalid data presented as an extra categorical feature) or
+   * 'error' (throw an error).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data. " +
+    "Options are 'keep' (invalid data presented as an extra categorical feature) " +
+    "or 'error' (throw an error).",
+    ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
+
+  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
+
+  /**
+   * Whether to drop the last category in the encoded vector (default: true)
+   * @group param
+   */
+  @Since("2.3.0")
+  final val dropLast: BooleanParam =
+    new BooleanParam(this, "dropLast", "whether to drop the last category")
+  setDefault(dropLast -> true)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getDropLast: Boolean = $(dropLast)
+
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      dropLast: Boolean,
+      keepInvalid: Boolean): StructType = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+    val existingFields = schema.fields
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+      s"output columns ${outputColNames.length}.")
+
+    // Input columns must be NumericType.
+    inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
+
+    // Prepares output columns with proper attributes by examining input columns.
+    val inputFields = $(inputCols).map(schema(_))
+
+    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
+      OneHotEncoderCommon.transformOutputColumnSchema(
+        inputField, outputColName, dropLast, keepInvalid)
+    }
+    outputFields.foldLeft(schema) { case (newSchema, outputField) =>
+      SchemaUtils.appendColumn(newSchema, outputField)
+    }
+  }
+}
+
+/**
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via `dropLast`),
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ *
+ * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
+ * added as the last category. So when `dropLast` is true, invalid values are encoded as an
+ * all-zeros vector.
+ *
+ * @note When encoding multiple columns using the `inputCols` and `outputCols` params, the input
+ * and output columns come in pairs, specified by their order in the arrays, and each pair is
+ * treated independently.
+ *
+ * @see `StringIndexer` for converting categorical values into category indices
+ */
+@Since("2.3.0")
+class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+  extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("oneHotEncoder"))
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setDropLast(value: Boolean): this.type = set(dropLast, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    // When fitting data, we want the plain number of categories without `handleInvalid` and
--- End diff ---
There is a misunderstanding. The model stores the raw numCategories now; I
don't want to change that.
For now, the transformSchema logic shared between the Model and the Estimator
is the same: it uses the metadata from the input schema to compute the output
schema, either taking the Params into account or not. The numCategories is
then determined from this schema. Because you require
Estimator.transformSchema to return a schema that takes the Params into
account, the numCategories derived from that schema is no longer raw. But we
want to record the raw numCategories in the model, so I need to compute the
raw numCategories back from the numCategories derived from the schema.
Another approach is to let Estimator.transformSchema return a schema that
takes the Params into account, and when I need to record the raw
numCategories, call validateAndTransformSchema again to get the raw
numCategories without taking the Params into account.
---