This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0cbe863e77c0 [SPARK-45547][ML] Validate Vectors with built-in function
0cbe863e77c0 is described below
commit 0cbe863e77c00e8987ddb170bdac5db4508173d7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 24 07:58:11 2023 +0800
[SPARK-45547][ML] Validate Vectors with built-in function
### What changes were proposed in this pull request?
Validate `Vector` columns with built-in SQL functions (`exists`, `unwrap_udt`, `array_size`, etc.) instead of Scala UDFs.
### Why are the changes needed?
With built-in functions the validation logic might be optimized further, and it avoids deserializing each row into a `Vector` inside a UDF. A rough sketch of the before/after pattern is shown below.
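A minimal sketch (not part of this patch) contrasting the removed UDF-based check with the built-in-function version; `validateNonNegative` is a hypothetical helper name used only for illustration:

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

// Before: a Scala UDF that deserializes every row into a Vector on the JVM side.
val checkNonNegativeVector = udf { vector: Vector =>
  vector match {
    case dv: DenseVector => dv.values.forall(v => v >= 0 && !v.isInfinity)
    case sv: SparseVector => sv.values.forall(v => v >= 0 && !v.isInfinity)
  }
}

// After: built-in expressions only. unwrap_udt exposes the VectorUDT's underlying
// struct (type, size, indices, values) and exists() scans the values array directly.
def validateNonNegative(vecCol: Column): Column =
  when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
    .when(exists(unwrap_udt(vecCol).getField("values"),
      v => v.isNaN || v < 0 || v === Double.PositiveInfinity),
      raise_error(concat(
        lit("Vector values MUST NOT be Negative, NaN or Infinity, but got "),
        vecCol.cast(StringType))))
    .otherwise(vecCol)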
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43380 from zhengruifeng/ml_vec_validate.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../spark/ml/classification/NaiveBayes.scala | 23 +++--------
.../apache/spark/ml/feature/VectorSizeHint.scala | 47 +++++++++-------------
.../org/apache/spark/ml/util/DatasetUtils.scala | 12 +-----
3 files changed, 27 insertions(+), 55 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 16176136a7e8..b7f9f97585fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -156,38 +156,27 @@ class NaiveBayes @Since("1.5.0") (
val validatedWeightCol = checkNonNegativeWeights(get(weightCol))
+ val vecCol = col($(featuresCol))
val validatedfeaturesCol = $(modelType) match {
case Multinomial | Complement =>
- val checkNonNegativeVector = udf { vector: Vector =>
- vector match {
- case dv: DenseVector => dv.values.forall(v => v >= 0 && !v.isInfinity)
- case sv: SparseVector => sv.values.forall(v => v >= 0 && !v.isInfinity)
- }
- }
- val vecCol = col($(featuresCol))
when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
- .when(!checkNonNegativeVector(vecCol),
+ .when(exists(unwrap_udt(vecCol).getField("values"),
+ v => v.isNaN || v < 0 || v === Double.PositiveInfinity),
raise_error(concat(
lit("Vector values MUST NOT be Negative, NaN or Infinity, but
got "),
vecCol.cast(StringType))))
.otherwise(vecCol)
case Bernoulli =>
- val checkBinaryVector = udf { vector: Vector =>
- vector match {
- case dv: DenseVector => dv.values.forall(v => v == 0 || v == 1)
- case sv: SparseVector => sv.values.forall(v => v == 0 || v == 1)
- }
- }
- val vecCol = col($(featuresCol))
when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
- .when(!checkBinaryVector(vecCol),
+ .when(exists(unwrap_udt(vecCol).getField("values"),
+ v => v =!= 0 && v =!= 1),
raise_error(concat(
lit("Vector values MUST be in {0, 1}, but got "),
vecCol.cast(StringType))))
.otherwise(vecCol)
- case _ => checkNonNanVectors($(featuresCol))
+ case _ => checkNonNanVectors(vecCol)
}
val validated = dataset.select(
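One subtlety in the Bernoulli branch above: scanning only the stored `values` array is sufficient, because the entries a sparse vector omits are implicitly 0.0 and already satisfy the {0, 1} constraint. A standalone sketch (hypothetical `validateBinary` helper, not part of this patch):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

// Reject any explicitly stored value that is neither 0 nor 1; omitted sparse
// entries are 0.0 by definition and need no check.
def validateBinary(vecCol: Column): Column =
  when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
    .when(exists(unwrap_udt(vecCol).getField("values"), v => v =!= 0 && v =!= 1),
      raise_error(concat(
        lit("Vector values MUST be in {0, 1}, but got "),
        vecCol.cast(StringType))))
    .otherwise(vecCol)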
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 2cf440efae85..5c96d07e0ca9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -17,17 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructType}
/**
* A feature transformer that adds size information to the metadata of a vector column.
@@ -104,33 +103,25 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
dataset.toDF()
} else {
- val newCol: Column = localHandleInvalid match {
- case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
+ val vecCol = col(localInputCol)
+ val sizeCol = coalesce(unwrap_udt(vecCol).getField("size"),
+ array_size(unwrap_udt(vecCol).getField("values")))
+ val newVecCol = localHandleInvalid match {
+ case VectorSizeHint.OPTIMISTIC_INVALID => vecCol
case VectorSizeHint.ERROR_INVALID =>
- val checkVectorSizeUDF = udf { vector: Vector =>
- if (vector == null) {
- throw new SparkException(s"Got null vector in VectorSizeHint,
set `handleInvalid` " +
- s"to 'skip' to filter invalid rows.")
- }
- if (vector.size != localSize) {
- throw new SparkException(s"VectorSizeHint Expecting a vector of
size $localSize but" +
- s" got ${vector.size}")
- }
- vector
- }.asNondeterministic()
- checkVectorSizeUDF(col(localInputCol))
+ when(vecCol.isNull, raise_error(
+ lit("Got null vector in VectorSizeHint, set `handleInvalid` to
'skip' to " +
+ "filter invalid rows.")))
+ .when(sizeCol =!= localSize, raise_error(concat(
+ lit(s"VectorSizeHint Expecting a vector of size $localSize but
got "),
+ sizeCol.cast(StringType))))
+ .otherwise(vecCol)
case VectorSizeHint.SKIP_INVALID =>
- val checkVectorSizeUDF = udf { vector: Vector =>
- if (vector != null && vector.size == localSize) {
- vector
- } else {
- null
- }
- }
- checkVectorSizeUDF(col(localInputCol))
+ when(!vecCol.isNull && sizeCol === localSize, vecCol)
+ .otherwise(lit(null))
}
- val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
+ val res = dataset.withColumn(localInputCol, newVecCol, newGroup.toMetadata())
if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
res.na.drop(Array(localInputCol))
} else {
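On the size computation above: in VectorUDT's underlying struct, sparse vectors carry an explicit `size` field while dense vectors leave it null and store every element in `values`, hence the `coalesce` fallback to the length of `values`. A minimal sketch (hypothetical `vectorSize` helper, not part of this patch):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// Sparse vectors populate the `size` field of the unwrapped struct; dense vectors
// leave it null, so fall back to the number of stored values.
def vectorSize(vecCol: Column): Column =
  coalesce(unwrap_udt(vecCol).getField("size"),
    array_size(unwrap_udt(vecCol).getField("values")))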
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index 08ecdaf0196c..b3cb9c7f2dd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -83,7 +83,8 @@ private[spark] object DatasetUtils extends Logging {
private[ml] def checkNonNanVectors(vectorCol: Column): Column = {
when(vectorCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
- .when(!validateVector(vectorCol),
+ .when(exists(unwrap_udt(vectorCol).getField("values"),
+ v => v.isNaN || v === Double.NegativeInfinity || v === Double.PositiveInfinity),
raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but got "),
vectorCol.cast(StringType))))
.otherwise(vectorCol)
@@ -93,15 +94,6 @@ private[spark] object DatasetUtils extends Logging {
checkNonNanVectors(col(vectorCol))
}
- private lazy val validateVector = udf { vector: Vector =>
- vector match {
- case dv: DenseVector =>
- dv.values.forall(v => !v.isNaN && !v.isInfinity)
- case sv: SparseVector =>
- sv.values.forall(v => !v.isNaN && !v.isInfinity)
- }
- }
-
private[ml] def extractInstances(
p: PredictorParams,
df: Dataset[_],
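As an end-to-end illustration of the pattern used throughout this patch, a small self-contained sketch (assuming a local SparkSession; the column, alias, and app names are arbitrary and not part of the patch):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").appName("unwrap-udt-demo").getOrCreate()

val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, Double.NaN)),
  (1, Vectors.sparse(3, Array(1), Array(2.0)))
)).toDF("id", "features")

// unwrap_udt exposes VectorUDT's storage struct: type (tinyint), size (int),
// indices (array<int>), values (array<double>).
df.select(unwrap_udt(col("features")).alias("raw")).printSchema()

// The same exists()-based predicate used by checkNonNanVectors flags the NaN row.
df.select(col("id"),
  exists(unwrap_udt(col("features")).getField("values"), _.isNaN).alias("hasNaN")).show()

spark.stop()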
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]