This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6d3149a [SPARK-38643][ML] Validate input dataset of ml.regression
6d3149a is described below
commit 6d3149a0d5fe0652197841a589bbeb8654471e58
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Mar 24 23:46:31 2022 -0700
[SPARK-38643][ML] Validate input dataset of ml.regression
### What changes were proposed in this pull request?
validate the input dataset, and fail fast when it contains invalid values
### Why are the changes needed?
to avoid returning a bad model silently
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
added test suites
Closes #35958 from zhengruifeng/regression_validate_training_dataset.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: huaxingao <[email protected]>
---
.../ml/regression/AFTSurvivalRegression.scala | 26 ++++++++++-----
.../ml/regression/DecisionTreeRegressor.scala | 13 ++++++--
.../apache/spark/ml/regression/FMRegressor.scala | 9 ++++--
.../apache/spark/ml/regression/GBTRegressor.scala | 14 ++++++--
.../regression/GeneralizedLinearRegression.scala | 28 ++++++++--------
.../spark/ml/regression/IsotonicRegression.scala | 16 ++++++----
.../spark/ml/regression/LinearRegression.scala | 16 ++++++----
.../ml/regression/RandomForestRegressor.scala | 12 +++++--
.../org/apache/spark/ml/util/DatasetUtils.scala | 12 ++++---
.../ml/regression/AFTSurvivalRegressionSuite.scala | 37 ++++++++++++++++++++++
.../ml/regression/DecisionTreeRegressorSuite.scala | 6 ++++
.../spark/ml/regression/FMRegressorSuite.scala | 5 +++
.../spark/ml/regression/GBTRegressorSuite.scala | 6 ++++
.../GeneralizedLinearRegressionSuite.scala | 31 ++++++++++++++++++
.../ml/regression/IsotonicRegressionSuite.scala | 32 +++++++++++++++++++
.../ml/regression/LinearRegressionSuite.scala | 6 ++++
.../ml/regression/RandomForestRegressorSuite.scala | 6 ++++
.../scala/org/apache/spark/ml/util/MLTest.scala | 29 +++++++++++++++++
18 files changed, 258 insertions(+), 46 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 117229b..c48fe68 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -210,14 +211,23 @@ class AFTSurvivalRegression @Since("1.6.0")
(@Since("1.6.0") override val uid: S
s"then cached during training. Be careful of double caching!")
}
- val instances = dataset.select(col($(featuresCol)),
col($(labelCol)).cast(DoubleType),
- col($(censorCol)).cast(DoubleType))
- .rdd.map { case Row(features: Vector, label: Double, censor: Double) =>
- require(censor == 1.0 || censor == 0.0, "censor must be 1.0 or 0.0")
- // AFT does not support instance weighting,
- // here use Instance.weight to store censor for convenience
- Instance(label, censor, features)
- }.setName("training instances")
+ val validatedCensorCol = {
+ val casted = col($(censorCol)).cast(DoubleType)
+ when(casted.isNull || casted.isNaN, raise_error(lit("Censors MUST NOT be
Null or NaN")))
+ .when(casted =!= 0 && casted =!= 1,
+ raise_error(concat(lit("Censors MUST be in {0, 1}, but got "),
casted)))
+ .otherwise(casted)
+ }
+
+ val instances = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ validatedCensorCol,
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, c: Double, v: Vector) =>
+ // AFT does not support instance weighting,
+ // here use Instance.weight to store censor for convenience
+ Instance(l, c, v)
+ }.setName("training instances")
val summarizer = instances.treeAggregate(
Summarizer.createSummarizerBuffer("mean", "std", "count"))(
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6913718..d9942f1 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -22,16 +22,18 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy =>
OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel =>
OldDecisionTreeModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
@@ -114,7 +116,14 @@ class DecisionTreeRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr
=>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val instances = extractInstances(dataset)
+
+ val instances = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNegativeWeights(get(weightCol)),
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+ }.setName("training instances")
+
val strategy = getOldStrategy(categoricalFeatures)
require(!strategy.bootstrap, "DecisionTreeRegressor does not need
bootstrap sampling")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index f70baa4..c0178ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.regression.FactorizationMachines._
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.{linalg => OldLinalg}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors =>
OldVectors}
@@ -416,8 +417,12 @@ class FMRegressor @Since("3.0.0") (
instr.logNumFeatures(numFeatures)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- val labeledPoint = extractLabeledPoints(dataset)
- val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label,
x.features))
+
+ val data = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, v: Vector) => (l, OldVectors.fromML(v))
+ }.setName("training instances")
if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index fd8af71..10a203e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -22,16 +22,18 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel =>
OldGBTModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
@@ -164,8 +166,16 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0")
override val uid: String)
def setWeightCol(value: String): this.type = set(weightCol, value)
override protected def train(dataset: Dataset[_]): GBTRegressionModel =
instrumented { instr =>
- val withValidation = isDefined(validationIndicatorCol) &&
$(validationIndicatorCol).nonEmpty
+ def extractInstances(df: Dataset[_]) = {
+ df.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNegativeWeights(get(weightCol)),
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w,
v) }
+ }
+
+ val withValidation = isDefined(validationIndicatorCol) &&
$(validationIndicatorCol).nonEmpty
val (trainDataset, validationDataset) = if (withValidation) {
(extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
extractInstances(dataset.filter(col($(validationIndicatorCol)))))
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 73da2af..88581d0 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -29,12 +29,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
@@ -397,16 +397,19 @@ class GeneralizedLinearRegression @Since("2.0.0")
(@Since("2.0.0") override val
"GeneralizedLinearRegression was given data with 0 features, and with
Param fitIntercept " +
"set to false. To fit a model with 0 features, fitIntercept must be
set to true." )
- val w = if (!hasWeightCol) lit(1.0) else
checkNonNegativeWeight(col($(weightCol)))
- val offset = if (!hasOffsetCol) lit(0.0) else
col($(offsetCol)).cast(DoubleType)
+ val validated = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNegativeWeights(get(weightCol)),
+ if (!hasOffsetCol) lit(0.0) else checkNonNanValues($(offsetCol),
"Offsets"),
+ checkNonNanVectors($(featuresCol))
+ )
val model = if (familyAndLink.family == Gaussian && familyAndLink.link ==
Identity) {
// TODO: Make standardizeFeatures and standardizeLabel configurable.
- val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, offset,
col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, offset: Double, features:
Vector) =>
- Instance(label - offset, weight, features)
- }
+ val instances = validated.rdd.map {
+ case Row(label: Double, weight: Double, offset: Double, features:
Vector) =>
+ Instance(label - offset, weight, features)
+ }
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
elasticNetParam = 0.0,
standardizeFeatures = true, standardizeLabel = true)
val wlsModel = optimizer.fit(instances, instr =
OptionalInstrumentation.create(instr),
@@ -418,11 +421,10 @@ class GeneralizedLinearRegression @Since("2.0.0")
(@Since("2.0.0") override val
wlsModel.diagInvAtWA.toArray, 1, getSolver)
model.setSummary(Some(trainingSummary))
} else {
- val instances: RDD[OffsetInstance] =
- dataset.select(col($(labelCol)), w, offset,
col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, offset: Double, features:
Vector) =>
- OffsetInstance(label, weight, offset, features)
- }
+ val instances = validated.rdd.map {
+ case Row(label: Double, weight: Double, offset: Double, features:
Vector) =>
+ OffsetInstance(label, weight, offset, features)
+ }
// Fit Generalized Linear Model by iteratively reweighted least squares
(IRLS).
val initialModel = familyAndLink.initialize(instances, $(fitIntercept),
$(regParam),
instr = OptionalInstrumentation.create(instr), $(aggregationDepth))
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index ec2640e..f1f2179 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -22,18 +22,18 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import
org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.regression.{IsotonicRegression =>
MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel =>
MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
+import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -81,17 +81,19 @@ private[regression] trait IsotonicRegressionBase extends
Params with HasFeatures
*/
protected[ml] def extractWeightedLabeledPoints(
dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
+ val l = checkRegressionLabels($(labelCol))
+
val f = if
(dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
val idx = $(featureIndex)
val extract = udf { v: Vector => v(idx) }
- extract(col($(featuresCol)))
+ extract(checkNonNanVectors($(featuresCol)))
} else {
- col($(featuresCol))
+ checkNonNanValues($(featuresCol), "Features")
}
- val w =
- if (hasWeightCol)
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)
- dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
+ val w = checkNonNegativeWeights(get(weightCol))
+
+ dataset.select(l, f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) => (label,
feature, weight)
}
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 12d5e59..a53ef8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -37,6 +37,7 @@ import org.apache.spark.ml.param.{DoubleParam, Param,
ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -340,14 +341,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
+ val instances = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNegativeWeights(get(weightCol)),
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+ }.setName("training instances")
+
if ($(loss) == SquaredError && (($(solver) == Auto &&
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) ==
Normal)) {
- return trainWithNormal(dataset, instr)
+ return trainWithNormal(dataset, instances, instr)
}
- val instances = extractInstances(dataset)
- .setName("training instances")
-
val (summarizer, labelSummarizer) = Summarizer
.getRegressionSummarizers(instances, $(aggregationDepth), Seq("mean",
"std", "count"))
@@ -439,6 +444,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
private def trainWithNormal(
dataset: Dataset[_],
+ instances: RDD[Instance],
instr: Instrumentation): LinearRegressionModel = {
// For low dimensional data, WeightedLeastSquares is more efficient since
the
// training algorithm only requires one pass through the data.
(SPARK-10668)
@@ -446,8 +452,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
elasticNetParam = $(elasticNetParam), $(standardization), true,
solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol =
$(tol))
- val instances = extractInstances(dataset)
- .setName("training instances")
val model = optimizer.fit(instances, instr =
OptionalInstrumentation.create(instr))
// When it is trained by WeightedLeastSquares, training summary does not
// attach returned model.
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index bb74c56..f241ff3 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -21,16 +21,18 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel =>
OldRandomForestModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
@@ -135,7 +137,13 @@ class RandomForestRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val instances = extractInstances(dataset)
+ val instances = dataset.select(
+ checkRegressionLabels($(labelCol)),
+ checkNonNegativeWeights(get(weightCol)),
+ checkNonNanVectors($(featuresCol))
+ ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+ }.setName("training instances")
+
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0,
OldAlgo.Regression, getOldImpurity)
strategy.bootstrap = $(bootstrap)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index f607a7b..48f1b70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -27,14 +27,18 @@ import org.apache.spark.sql.types._
private[spark] object DatasetUtils {
- private[ml] def checkRegressionLabels(labelCol: String): Column = {
- val casted = col(labelCol).cast(DoubleType)
- when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be
Null or NaN")))
+ private[ml] def checkNonNanValues(colName: String, displayed: String):
Column = {
+ val casted = col(colName).cast(DoubleType)
+ when(casted.isNull || casted.isNaN, raise_error(lit(s"$displayed MUST NOT
be Null or NaN")))
.when(casted === Double.NegativeInfinity || casted ===
Double.PositiveInfinity,
- raise_error(concat(lit("Labels MUST NOT be Infinity, but got "),
casted)))
+ raise_error(concat(lit(s"$displayed MUST NOT be Infinity, but got "),
casted)))
.otherwise(casted)
}
+ private[ml] def checkRegressionLabels(labelCol: String): Column = {
+ checkNonNanValues(labelCol, "Labels")
+ }
+
private[ml] def checkClassificationLabels(
labelCol: String,
numClasses: Option[Int]): Column = {
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index e745e7f..c8f6926 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -93,6 +93,43 @@ class AFTSurvivalRegressionSuite extends MLTest with
DefaultReadWriteTest {
assert(model.hasParent)
}
+ test("AFTSurvivalRegression validate input dataset") {
+ testInvalidRegressionLabels { df: DataFrame =>
+ val dfWithCensors = df.withColumn("censor", lit(1.0))
+ new AFTSurvivalRegression().fit(dfWithCensors)
+ }
+
+ testInvalidVectors { df: DataFrame =>
+ val dfWithCensors = df.withColumn("censor", lit(1.0))
+ new AFTSurvivalRegression().fit(dfWithCensors)
+ }
+
+ // censors contains NULL
+ val df1 = sc.parallelize(Seq(
+ (1.0, null, Vectors.dense(1.0, 2.0)),
+ (1.0, "1.0", Vectors.dense(1.0, 2.0))
+ )).toDF("label", "str_censor", "features")
+ .select(col("label"), col("str_censor").cast("double").as("censor"),
col("features"))
+ val e1 = intercept[Exception](new AFTSurvivalRegression().fit(df1))
+ assert(e1.getMessage.contains("Censors MUST NOT be Null or NaN"))
+
+ // censors contains NaN
+ val df2 = sc.parallelize(Seq(
+ (1.0, Double.NaN, Vectors.dense(1.0, 2.0)),
+ (1.0, 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "censor", "features")
+ val e2 = intercept[Exception](new AFTSurvivalRegression().fit(df2))
+ assert(e2.getMessage.contains("Censors MUST NOT be Null or NaN"))
+
+ // censors contains invalid value: 3
+ val df3 = sc.parallelize(Seq(
+ (1.0, 1.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "censor", "features")
+ val e3 = intercept[Exception](new AFTSurvivalRegression().fit(df3))
+ assert(e3.getMessage.contains("Censors MUST be in {0, 1}"))
+ }
+
def generateAFTInput(
numFeatures: Int,
xMean: Array[Double],
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9cb03454..a7b696e 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -53,6 +53,12 @@ class DecisionTreeRegressorSuite extends MLTest with
DefaultReadWriteTest {
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
+ test("DecisionTreeRegressor validate input dataset") {
+ testInvalidRegressionLabels(new DecisionTreeRegressor().fit(_))
+ testInvalidWeights(new
DecisionTreeRegressor().setWeightCol("weight").fit(_))
+ testInvalidVectors(new DecisionTreeRegressor().fit(_))
+ }
+
test("Regression stump with 3-ary (ordered) categorical features") {
val dt = new DecisionTreeRegressor()
.setImpurity("variance")
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
index 372432c..d5b7f3c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
@@ -48,6 +48,11 @@ class FMRegressorSuite extends MLTest with
DefaultReadWriteTest {
ParamsSuite.checkParams(model)
}
+ test("FMRegressor validate input dataset") {
+ testInvalidRegressionLabels(new FMRegressor().fit(_))
+ testInvalidVectors(new FMRegressor().fit(_))
+ }
+
test("combineCoefficients") {
val numFeatures = 2
val factorSize = 4
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 7d84df6..7e96281 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -133,6 +133,12 @@ class GBTRegressorSuite extends MLTest with
DefaultReadWriteTest {
Utils.deleteRecursively(tempDir)
}
+ test("GBTRegressor validate input dataset") {
+ testInvalidRegressionLabels(new GBTRegressor().fit(_))
+ testInvalidWeights(new GBTRegressor().setWeightCol("weight").fit(_))
+ testInvalidVectors(new GBTRegressor().fit(_))
+ }
+
test("model support predict leaf index") {
val model0 = new DecisionTreeRegressionModel("dtc", TreeTests.root0, 3)
val model1 = new DecisionTreeRegressionModel("dtc", TreeTests.root1, 3)
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index bfa9f4b..3acb0bc 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -211,6 +211,37 @@ class GeneralizedLinearRegressionSuite extends MLTest with
DefaultReadWriteTest
assert(model.getLink === "identity")
}
+ test("GeneralizedLinearRegression validate input dataset") {
+ testInvalidRegressionLabels(new GeneralizedLinearRegression().fit(_))
+ testInvalidWeights(new
GeneralizedLinearRegression().setWeightCol("weight").fit(_))
+ testInvalidVectors(new GeneralizedLinearRegression().fit(_))
+
+ // offsets contains NULL
+ val df1 = sc.parallelize(Seq(
+ (1.0, null, Vectors.dense(1.0, 2.0)),
+ (1.0, "1.0", Vectors.dense(1.0, 2.0))
+ )).toDF("label", "str_offset", "features")
+ .select(col("label"), col("str_offset").cast("double").as("offset"),
col("features"))
+ val e1 = intercept[Exception](new
GeneralizedLinearRegression().setOffsetCol("offset").fit(df1))
+ assert(e1.getMessage.contains("Offsets MUST NOT be Null or NaN"))
+
+ // offsets contains NaN
+ val df2 = sc.parallelize(Seq(
+ (1.0, Double.NaN, Vectors.dense(1.0, 2.0)),
+ (1.0, 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "offset", "features")
+ val e2 = intercept[Exception](new
GeneralizedLinearRegression().setOffsetCol("offset").fit(df2))
+ assert(e2.getMessage.contains("Offsets MUST NOT be Null or NaN"))
+
+ // offsets contains Infinity
+ val df3 = sc.parallelize(Seq(
+ (1.0, Double.PositiveInfinity, Vectors.dense(1.0, 2.0)),
+ (1.0, 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "offset", "features")
+ val e3 = intercept[Exception](new
GeneralizedLinearRegression().setOffsetCol("offset").fit(df3))
+ assert(e3.getMessage.contains("Offsets MUST NOT be Infinity"))
+ }
+
test("prediction on single instance") {
val glr = new GeneralizedLinearRegression
val model = glr.setFamily("gaussian").setLink("identity")
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 18fbbce..3077a60 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.col
class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
@@ -101,6 +102,37 @@ class IsotonicRegressionSuite extends MLTest with
DefaultReadWriteTest {
assert(model.hasParent)
}
+ test("IsotonicRegression validate input dataset") {
+ testInvalidRegressionLabels(new IsotonicRegression().fit(_))
+ testInvalidWeights(new IsotonicRegression().setWeightCol("weight").fit(_))
+ testInvalidVectors(new IsotonicRegression().fit(_))
+
+ // features contains NULL
+ val df1 = sc.parallelize(Seq(
+ (1.0, 1.0, null),
+ (1.0, 1.0, "1.0")
+ )).toDF("label", "weight", "str_features")
+ .select(col("label"), col("weight"),
col("str_features").cast("double").as("features"))
+ val e1 = intercept[Exception](new IsotonicRegression().fit(df1))
+ assert(e1.getMessage.contains("Features MUST NOT be Null or NaN"))
+
+ // features contains NaN
+ val df2 = sc.parallelize(Seq(
+ (1.0, 1.0, 1.0),
+ (1.0, 1.0, Double.NaN)
+ )).toDF("label", "weight", "features")
+ val e2 = intercept[Exception](new IsotonicRegression().fit(df2))
+ assert(e2.getMessage.contains("Features MUST NOT be Null or NaN"))
+
+ // features contains Infinity
+ val df3 = sc.parallelize(Seq(
+ (1.0, 1.0, 1.0),
+ (1.0, 1.0, Double.PositiveInfinity)
+ )).toDF("label", "weight", "features")
+ val e3 = intercept[Exception](new IsotonicRegression().fit(df3))
+ assert(e3.getMessage.contains("Features MUST NOT be Infinity"))
+ }
+
test("set parameters") {
val isotonicRegression = new IsotonicRegression()
.setIsotonic(false)
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index b3098be..e4535f3 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -188,6 +188,12 @@ class LinearRegressionSuite extends MLTest with
DefaultReadWriteTest with PMMLRe
assert(model.numFeatures === numFeatures)
}
+ test("LinearRegression validate input dataset") {
+ testInvalidRegressionLabels(new LinearRegression().fit(_))
+ testInvalidWeights(new LinearRegression().setWeightCol("weight").fit(_))
+ testInvalidVectors(new LinearRegression().fit(_))
+ }
+
test("linear regression: can transform data with LinearRegressionModel") {
withClue("training related params like loss are only validated during
fitting phase") {
val original = new LinearRegression().fit(datasetWithDenseFeature)
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7ec30de..4047e6d 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -57,6 +57,12 @@ class RandomForestRegressorSuite extends MLTest with
DefaultReadWriteTest{
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
+ test("RandomForestRegressor validate input dataset") {
+ testInvalidRegressionLabels(new RandomForestRegressor().fit(_))
+ testInvalidWeights(new
RandomForestRegressor().setWeightCol("weight").fit(_))
+ testInvalidVectors(new RandomForestRegressor().fit(_))
+ }
+
def regressionTestWithContinuousFeatures(rf: RandomForestRegressor): Unit = {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val newRF = rf
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 2b67989..b847c90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -233,6 +233,35 @@ trait MLTest extends StreamTest with TempDirectory { self:
Suite =>
}
}
+ def testInvalidRegressionLabels(f: DataFrame => Any): Unit = {
+ import testImplicits._
+
+ // labels contains NULL
+ val df1 = sc.parallelize(Seq(
+ (null, 1.0, Vectors.dense(1.0, 2.0)),
+ ("1.0", 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("str_label", "weight", "features")
+ .select(col("str_label").cast("double").as("label"), col("weight"),
col("features"))
+ val e1 = intercept[Exception](f(df1))
+ assert(e1.getMessage.contains("Labels MUST NOT be Null or NaN"))
+
+ // labels contains NaN
+ val df2 = sc.parallelize(Seq(
+ (Double.NaN, 1.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "weight", "features")
+ val e2 = intercept[Exception](f(df2))
+ assert(e2.getMessage.contains("Labels MUST NOT be Null or NaN"))
+
+ // labels contains invalid value: Infinity
+ val df3 = sc.parallelize(Seq(
+ (Double.NegativeInfinity, 1.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 1.0, Vectors.dense(1.0, 2.0))
+ )).toDF("label", "weight", "features")
+ val e3 = intercept[Exception](f(df3))
+ assert(e3.getMessage.contains("Labels MUST NOT be Infinity"))
+ }
+
def testInvalidClassificationLabels(f: DataFrame => Any, numClasses:
Option[Int]): Unit = {
import testImplicits._
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]