Repository: spark Updated Branches: refs/heads/master 774398045 -> 1e6f76059
[SPARK-12375][ML] VectorIndexerModel support handle unseen categories via handleInvalid ## What changes were proposed in this pull request? Support skip/error/keep strategy, similar to `StringIndexer`. Implemented via `try...catch`, so that it can avoid possible performance impact. ## How was this patch tested? Unit test added. Author: WeichenXu <weichen...@databricks.com> Closes #19588 from WeichenXu123/handle_invalid_for_vector_indexer. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e6f7605 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e6f7605 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e6f7605 Branch: refs/heads/master Commit: 1e6f760593d81def059c514d34173bf2777d71ec Parents: 7743980 Author: WeichenXu <weichen...@databricks.com> Authored: Tue Nov 14 16:58:18 2017 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue Nov 14 16:58:18 2017 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/feature/VectorIndexer.scala | 92 +++++++++++++++++--- .../spark/ml/feature/VectorIndexerSuite.scala | 39 +++++++++ 2 files changed, 121 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index d371da7..3403ec4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.feature import java.lang.{Double => JDouble, Integer => JInt} -import java.util.{Map => JMap} +import java.util.{Map => JMap, NoSuchElementException} import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ @@ -37,7 +38,27 @@ import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet /** Private trait for params for VectorIndexer and VectorIndexerModel */ -private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { + + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Note: this param only applies to categorical features, not continuous ones. + * Options are: + * 'skip': filter out rows with invalid data. + * 'error': throw an error. + * 'keep': put invalid data in a special additional bucket, at index numCategories. + * Default value: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) /** * Threshold for the number of values a categorical feature can take. @@ -113,6 +134,10 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) @@ -148,6 +173,11 @@ class VectorIndexer @Since("1.4.0") ( @Since("1.6.0") object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): VectorIndexer = super.load(path) @@ -287,9 +317,15 @@ class VectorIndexerModel private[ml] ( while (featureIndex < numFeatures) { if (categoryMaps.contains(featureIndex)) { // categorical feature - val featureValues: Array[String] = + val rawFeatureValues: Array[String] = categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) - if (featureValues.length == 2) { + + val featureValues = if (getHandleInvalid == VectorIndexer.KEEP_INVALID) { + (rawFeatureValues.toList :+ "__unknown").toArray + } else { + rawFeatureValues + } + if (featureValues.length == 2 && getHandleInvalid != VectorIndexer.KEEP_INVALID) { attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), values = Some(featureValues)) } else { @@ -311,22 +347,39 @@ class VectorIndexerModel private[ml] ( // TODO: Check more carefully about whether this whole class will be included in a closure. /** Per-vector transform function */ - private val transformFunc: Vector => Vector = { + private lazy val transformFunc: Vector => Vector = { val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted val localVectorMap = categoryMaps val localNumFeatures = numFeatures + val localHandleInvalid = getHandleInvalid val f: Vector => Vector = { (v: Vector) => assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + s" $numFeatures but found length ${v.size}") v match { case dv: DenseVector => + var hasInvalid = false val tmpv = dv.copy localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => - tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + try { + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv(featureIndex)} on feature index ${featureIndex}. To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(featureIndex) = categoryMap.size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } } - tmpv + if (hasInvalid) null else tmpv case sv: SparseVector => // We use the fact that categorical value 0 is always mapped to index 0. + var hasInvalid = false val tmpv = sv.copy var catFeatureIdx = 0 // index into sortedCatFeatureIndices var k = 0 // index into non-zero elements of sparse vector @@ -337,12 +390,26 @@ class VectorIndexerModel private[ml] ( } else if (featureIndex > tmpv.indices(k)) { k += 1 } else { - tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + try { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + } catch { + case _: NoSuchElementException => + localHandleInvalid match { + case VectorIndexer.ERROR_INVALID => + throw new SparkException(s"VectorIndexer encountered invalid value " + + s"${tmpv.values(k)} on feature index ${featureIndex}. To handle " + + s"or skip invalid value, try setting VectorIndexer.handleInvalid.") + case VectorIndexer.KEEP_INVALID => + tmpv.values(k) = localVectorMap(featureIndex).size + case VectorIndexer.SKIP_INVALID => + hasInvalid = true + } + } catFeatureIdx += 1 k += 1 } } - tmpv + if (hasInvalid) null else tmpv } } f @@ -362,7 +429,12 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol, newField.metadata) + val ds = dataset.withColumn($(outputCol), newCol, newField.metadata) + if (getHandleInvalid == VectorIndexer.SKIP_INVALID) { + ds.na.drop(Array($(outputCol))) + } else { + ds + } } @Since("1.4.0") http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index f2cca8a..69a7b75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -38,6 +38,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext // identical, of length 3 @transient var densePoints1: DataFrame = _ @transient var sparsePoints1: DataFrame = _ + @transient var densePoints1TestInvalid: DataFrame = _ + @transient var sparsePoints1TestInvalid: DataFrame = _ @transient var point1maxes: Array[Double] = _ // identical, of length 2 @@ -55,11 +57,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext Vectors.dense(0.0, 1.0, 2.0), Vectors.dense(0.0, 0.0, -1.0), Vectors.dense(1.0, 3.0, 2.0)) + val densePoints1SeqTestInvalid = densePoints1Seq ++ Seq( + Vectors.dense(10.0, 2.0, 0.0), + Vectors.dense(0.0, 10.0, 2.0), + Vectors.dense(1.0, 3.0, 10.0)) val sparsePoints1Seq = Seq( Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), Vectors.sparse(3, Array(2), Array(-1.0)), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + val sparsePoints1SeqTestInvalid = sparsePoints1Seq ++ Seq( + Vectors.sparse(3, Array(0, 1), Array(10.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(10.0, 2.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 10.0))) point1maxes = Array(1.0, 3.0, 2.0) val densePoints2Seq = Seq( @@ -88,6 +98,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext densePoints1 = densePoints1Seq.map(FeatureData).toDF() sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF() + densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF() + sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF() densePoints2 = densePoints2Seq.map(FeatureData).toDF() sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF() badPoints = badPointsSeq.map(FeatureData).toDF() @@ -219,6 +231,33 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) } + test("handle invalid") { + for ((points, pointsTestInvalid) <- Seq((densePoints1, densePoints1TestInvalid), + (sparsePoints1, sparsePoints1TestInvalid))) { + val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") + val model = vectorIndexer.fit(points) + intercept[SparkException] { + model.transform(pointsTestInvalid).collect() + } + val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") + val model1 = vectorIndexer1.fit(points) + val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) + assert(transformed1 === invalidTransformed1) + + val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") + val model2 = vectorIndexer2.fit(points) + val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") + .collect().map(_(0)) + assert(invalidTransformed2 === transformed1 ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors.dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0)) + ) + } + } + test("Maintain sparsity for sparse vectors") { def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { val points = data.collect().map(_.getAs[Vector](0)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org