Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/20132#discussion_r159157608
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
---
@@ -205,60 +210,58 @@ class OneHotEncoderModel private[ml] (
import OneHotEncoderModel._
- // Returns the category size for a given index with `dropLast` and
`handleInvalid`
+ // Returns the category size for each index with `dropLast` and
`handleInvalid`
// taken into account.
- private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+ private def getConfigedCategorySizes: Array[Int] = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid ==
OneHotEncoderEstimator.KEEP_INVALID
if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as
last category
// for invalid data.
- orgCategorySize + 1
+ categorySizes.map(_ + 1)
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed.
- orgCategorySize - 1
+ categorySizes.map(_ - 1)
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra
category for invalid
// data is removed. Thus, it is the same as the plain number of
categories.
- orgCategorySize
+ categorySizes
}
}
private def encoder: UserDefinedFunction = {
- val oneValue = Array(1.0)
- val emptyValues = Array.empty[Double]
- val emptyIndices = Array.empty[Int]
- val dropLast = getDropLast
- val handleInvalid = getHandleInvalid
- val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+ val keepInvalid = getHandleInvalid ==
OneHotEncoderEstimator.KEEP_INVALID
+ val configedSizes = getConfigedCategorySizes
+ val localCategorySizes = categorySizes
// The udf performed on input data. The first parameter is the input
value. The second
- // parameter is the index of input.
- udf { (label: Double, idx: Int) =>
- val plainNumCategories = categorySizes(idx)
- val size = configedCategorySize(plainNumCategories, idx)
-
- if (label < 0) {
- throw new SparkException(s"Negative value: $label. Input can't be
negative.")
- } else if (label == size && dropLast && !keepInvalid) {
- // When `dropLast` is true and `handleInvalid` is not "keep",
- // the last category is removed.
- Vectors.sparse(size, emptyIndices, emptyValues)
- } else if (label >= plainNumCategories && keepInvalid) {
- // When `handleInvalid` is "keep", encodes invalid data to last
category (and removed
- // if `dropLast` is true)
- if (dropLast) {
- Vectors.sparse(size, emptyIndices, emptyValues)
+ // parameter is the index in inputCols of the column being encoded.
+ udf { (label: Double, colIdx: Int) =>
+ val origCategorySize = localCategorySizes(colIdx)
+ // idx: index in vector of the single 1-valued element
+ val idx = if (label >= 0 && label < origCategorySize) {
+ label
+ } else {
+ if (keepInvalid) {
+ origCategorySize
} else {
- Vectors.sparse(size, Array(size - 1), oneValue)
+ if (label < 0) {
+ throw new SparkException(s"Negative value: $label. Input can't
be negative. " +
--- End diff --
I have a question. Since we don't allow negative values when fitting, should
we allow them in transforming, even when handleInvalid is KEEP_INVALID?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]