This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d3149a  [SPARK-38643][ML] Validate input dataset of ml.regression
6d3149a is described below

commit 6d3149a0d5fe0652197841a589bbeb8654471e58
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 24 23:46:31 2022 -0700

    [SPARK-38643][ML] Validate input dataset of ml.regression
    
    ### What changes were proposed in this pull request?
    Validate the input dataset and fail fast when it contains invalid values.
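
    As a minimal sketch of the pattern (mirroring the `checkRegressionLabels`
    helper this change adds to `DatasetUtils`), each input column is wrapped in
    a `when`/`raise_error` expression, so the first invalid row scanned aborts
    the job instead of being silently trained on:

    ```scala
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.DoubleType

    def checkRegressionLabels(labelCol: String): Column = {
      val casted = col(labelCol).cast(DoubleType)
      // Null/NaN and infinite labels raise an error at scan time.
      when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
        .when(casted === Double.NegativeInfinity || casted === Double.PositiveInfinity,
          raise_error(concat(lit("Labels MUST NOT be Infinity, but got "), casted)))
        .otherwise(casted)
    }
    ```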
    
    ### Why are the changes needed?
    To avoid silently returning a bad model.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added test suites.
    
    Closes #35958 from zhengruifeng/regression_validate_training_dataset.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: huaxingao <huaxin_...@apple.com>
---
 .../ml/regression/AFTSurvivalRegression.scala      | 26 ++++++++++-----
 .../ml/regression/DecisionTreeRegressor.scala      | 13 ++++++--
 .../apache/spark/ml/regression/FMRegressor.scala   |  9 ++++--
 .../apache/spark/ml/regression/GBTRegressor.scala  | 14 ++++++--
 .../regression/GeneralizedLinearRegression.scala   | 28 ++++++++--------
 .../spark/ml/regression/IsotonicRegression.scala   | 16 ++++++----
 .../spark/ml/regression/LinearRegression.scala     | 16 ++++++----
 .../ml/regression/RandomForestRegressor.scala      | 12 +++++--
 .../org/apache/spark/ml/util/DatasetUtils.scala    | 12 ++++---
 .../ml/regression/AFTSurvivalRegressionSuite.scala | 37 ++++++++++++++++++++++
 .../ml/regression/DecisionTreeRegressorSuite.scala |  6 ++++
 .../spark/ml/regression/FMRegressorSuite.scala     |  5 +++
 .../spark/ml/regression/GBTRegressorSuite.scala    |  6 ++++
 .../GeneralizedLinearRegressionSuite.scala         | 31 ++++++++++++++++++
 .../ml/regression/IsotonicRegressionSuite.scala    | 32 +++++++++++++++++++
 .../ml/regression/LinearRegressionSuite.scala      |  6 ++++
 .../ml/regression/RandomForestRegressorSuite.scala |  6 ++++
 .../scala/org/apache/spark/ml/util/MLTest.scala    | 29 +++++++++++++++++
 18 files changed, 258 insertions(+), 46 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 117229b..c48fe68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.stat._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
@@ -210,14 +211,23 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
         s"then cached during training. Be careful of double caching!")
     }
 
-    val instances = dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
-      col($(censorCol)).cast(DoubleType))
-      .rdd.map { case Row(features: Vector, label: Double, censor: Double) =>
-        require(censor == 1.0 || censor == 0.0, "censor must be 1.0 or 0.0")
-        // AFT does not support instance weighting,
-        // here use Instance.weight to store censor for convenience
-        Instance(label, censor, features)
-      }.setName("training instances")
+    val validatedCensorCol = {
+      val casted = col($(censorCol)).cast(DoubleType)
+      when(casted.isNull || casted.isNaN, raise_error(lit("Censors MUST NOT be Null or NaN")))
+        .when(casted =!= 0 && casted =!= 1,
+          raise_error(concat(lit("Censors MUST be in {0, 1}, but got "), casted)))
+        .otherwise(casted)
+    }
+
+    val instances = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      validatedCensorCol,
+      checkNonNanVectors($(featuresCol))
+    ).rdd.map { case Row(l: Double, c: Double, v: Vector) =>
+      // AFT does not support instance weighting,
+      // here use Instance.weight to store censor for convenience
+      Instance(l, c, v)
+    }.setName("training instances")
 
     val summarizer = instances.treeAggregate(
       Summarizer.createSummarizerBuffer("mean", "std", "count"))(
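
A hedged usage sketch of the censor check above (assuming a local SparkSession `spark`): a censor value outside {0, 1} now aborts fit() at the first action instead of training on bad data.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

val df = spark.createDataFrame(Seq(
  (1.0, 1.0, Vectors.dense(1.0, 2.0)),
  (2.0, 3.0, Vectors.dense(3.0, 4.0))  // invalid censor: 3.0
)).toDF("label", "censor", "features")

try new AFTSurvivalRegression().fit(df)
catch { case e: Exception => assert(e.getMessage.contains("Censors MUST be in {0, 1}")) }
```
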
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6913718..d9942f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -22,16 +22,18 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StructType
 
@@ -114,7 +116,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val instances = extractInstances(dataset)
+
+    val instances = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      checkNonNegativeWeights(get(weightCol)),
+      checkNonNanVectors($(featuresCol))
+    ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+    }.setName("training instances")
+
     val strategy = getOldStrategy(categoricalFeatures)
     require(!strategy.bootstrap, "DecisionTreeRegressor does not need bootstrap sampling")
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index f70baa4..c0178ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.regression.FactorizationMachines._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.{linalg => OldLinalg}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
@@ -416,8 +417,12 @@ class FMRegressor @Since("3.0.0") (
     instr.logNumFeatures(numFeatures)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val labeledPoint = extractLabeledPoints(dataset)
-    val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features))
+
+    val data = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      checkNonNanVectors($(featuresCol))
+    ).rdd.map { case Row(l: Double, v: Vector) => (l, OldVectors.fromML(v))
+    }.setName("training instances")
 
     if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index fd8af71..10a203e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -22,16 +22,18 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, Vector}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StructType
 
@@ -164,8 +166,16 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
-    val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
 
+    def extractInstances(df: Dataset[_]) = {
+      df.select(
+        checkRegressionLabels($(labelCol)),
+        checkNonNegativeWeights(get(weightCol)),
+        checkNonNanVectors($(featuresCol))
+      ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
+    }
+
+    val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
     val (trainDataset, validationDataset) = if (withValidation) {
       (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
         extractInstances(dataset.filter(col($(validationIndicatorCol)))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 73da2af..88581d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -29,12 +29,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.{Instance, OffsetInstance}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.optim._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
@@ -397,16 +397,19 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
         "set to false. To fit a model with 0 features, fitIntercept must be set to true." )
 
-    val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol)))
-    val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
+    val validated = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      checkNonNegativeWeights(get(weightCol)),
+      if (!hasOffsetCol) lit(0.0) else checkNonNanValues($(offsetCol), "Offsets"),
+      checkNonNanVectors($(featuresCol))
+    )
 
     val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
       // TODO: Make standardizeFeatures and standardizeLabel configurable.
-      val instances: RDD[Instance] =
-        dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map {
-          case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
-            Instance(label - offset, weight, features)
-        }
+      val instances = validated.rdd.map {
+        case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
+          Instance(label - offset, weight, features)
+      }
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0,
         standardizeFeatures = true, standardizeLabel = true)
       val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr),
@@ -418,11 +421,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
         wlsModel.diagInvAtWA.toArray, 1, getSolver)
       model.setSummary(Some(trainingSummary))
     } else {
-      val instances: RDD[OffsetInstance] =
-        dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map {
-          case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
-            OffsetInstance(label, weight, offset, features)
-        }
+      val instances = validated.rdd.map {
+        case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
+          OffsetInstance(label, weight, offset, features)
+      }
       // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
       val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam),
         instr = OptionalInstrumentation.create(instr), $(aggregationDepth))
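
Both the WLS and IRLS branches now consume the same validated DataFrame. A hedged sketch of the new offset check (assuming a local SparkSession `spark`):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val df = spark.createDataFrame(Seq(
  (1.0, Double.PositiveInfinity, Vectors.dense(1.0, 2.0)),  // infinite offset
  (1.0, 1.0, Vectors.dense(3.0, 4.0))
)).toDF("label", "offset", "features")

try new GeneralizedLinearRegression().setOffsetCol("offset").fit(df)
catch { case e: Exception => assert(e.getMessage.contains("Offsets MUST NOT be Infinity")) }
```
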
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index ec2640e..f1f2179 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -22,18 +22,18 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
 import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
+import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 
@@ -81,17 +81,19 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
    */
   protected[ml] def extractWeightedLabeledPoints(
       dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
+    val l = checkRegressionLabels($(labelCol))
+
     val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
       val idx = $(featureIndex)
       val extract = udf { v: Vector => v(idx) }
-      extract(col($(featuresCol)))
+      extract(checkNonNanVectors($(featuresCol)))
     } else {
-      col($(featuresCol))
+      checkNonNanValues($(featuresCol), "Features")
     }
-    val w =
-      if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)
 
-    dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
+    val w = checkNonNegativeWeights(get(weightCol))
+
+    dataset.select(l, f, w).rdd.map {
       case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
     }
   }
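
IsotonicRegression also accepts a plain double features column, which is now routed through `checkNonNanValues`. A hedged sketch (assuming a local SparkSession `spark`):

```scala
import org.apache.spark.ml.regression.IsotonicRegression

val df = spark.createDataFrame(Seq(
  (1.0, 1.0, Double.NaN),  // NaN feature
  (2.0, 1.0, 2.0)
)).toDF("label", "weight", "features")

try new IsotonicRegression().fit(df)
catch { case e: Exception => assert(e.getMessage.contains("Features MUST NOT be Null or NaN")) }
```
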
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 12d5e59..a53ef8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.stat._
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -340,14 +341,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
+    val instances = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      checkNonNegativeWeights(get(weightCol)),
+      checkNonNanVectors($(featuresCol))
+    ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+    }.setName("training instances")
+
     if ($(loss) == SquaredError && (($(solver) == Auto &&
       numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal)) {
-      return trainWithNormal(dataset, instr)
+      return trainWithNormal(dataset, instances, instr)
     }
 
-    val instances = extractInstances(dataset)
-      .setName("training instances")
-
     val (summarizer, labelSummarizer) = Summarizer
       .getRegressionSummarizers(instances, $(aggregationDepth), Seq("mean", "std", "count"))
 
@@ -439,6 +444,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
   private def trainWithNormal(
       dataset: Dataset[_],
+      instances: RDD[Instance],
       instr: Instrumentation): LinearRegressionModel = {
     // For low dimensional data, WeightedLeastSquares is more efficient since the
     // training algorithm only requires one pass through the data. (SPARK-10668)
@@ -446,8 +452,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
       elasticNetParam = $(elasticNetParam), $(standardization), true,
       solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol))
-    val instances = extractInstances(dataset)
-      .setName("training instances")
     val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
     // When it is trained by WeightedLeastSquares, training summary does not
     // attach returned model.
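
`trainWithNormal` now receives the already-validated instances instead of re-extracting them. A hedged end-to-end sketch of the user-visible effect (assuming a local SparkSession `spark`): fitting on a NaN label fails fast rather than silently returning a bad model.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

val df = spark.createDataFrame(Seq(
  (Double.NaN, Vectors.dense(1.0, 2.0)),  // invalid label
  (1.0, Vectors.dense(3.0, 4.0))
)).toDF("label", "features")

try new LinearRegression().fit(df)
catch { case e: Exception => assert(e.getMessage.contains("Labels MUST NOT be Null or NaN")) }
```
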
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index bb74c56..f241ff3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -21,16 +21,18 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.StructType
 
@@ -135,7 +137,13 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
 
-    val instances = extractInstances(dataset)
+    val instances = dataset.select(
+      checkRegressionLabels($(labelCol)),
+      checkNonNegativeWeights(get(weightCol)),
+      checkNonNanVectors($(featuresCol))
+    ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v)
+    }.setName("training instances")
+
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
     strategy.bootstrap = $(bootstrap)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index f607a7b..48f1b70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -27,14 +27,18 @@ import org.apache.spark.sql.types._
 
 private[spark] object DatasetUtils {
 
-  private[ml] def checkRegressionLabels(labelCol: String): Column = {
-    val casted = col(labelCol).cast(DoubleType)
-    when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
+  private[ml] def checkNonNanValues(colName: String, displayed: String): Column = {
+    val casted = col(colName).cast(DoubleType)
+    when(casted.isNull || casted.isNaN, raise_error(lit(s"$displayed MUST NOT be Null or NaN")))
       .when(casted === Double.NegativeInfinity || casted === Double.PositiveInfinity,
-        raise_error(concat(lit("Labels MUST NOT be Infinity, but got "), casted)))
+        raise_error(concat(lit(s"$displayed MUST NOT be Infinity, but got "), casted)))
       .otherwise(casted)
   }
 
+  private[ml] def checkRegressionLabels(labelCol: String): Column = {
+    checkNonNanValues(labelCol, "Labels")
+  }
+
   private[ml] def checkClassificationLabels(
     labelCol: String,
     numClasses: Option[Int]): Column = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index e745e7f..c8f6926 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -93,6 +93,43 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model.hasParent)
   }
 
+  test("AFTSurvivalRegression validate input dataset") {
+    testInvalidRegressionLabels { df: DataFrame =>
+      val dfWithCensors = df.withColumn("censor", lit(1.0))
+      new AFTSurvivalRegression().fit(dfWithCensors)
+    }
+
+    testInvalidVectors { df: DataFrame =>
+      val dfWithCensors = df.withColumn("censor", lit(1.0))
+      new AFTSurvivalRegression().fit(dfWithCensors)
+    }
+
+    // censors contains NULL
+    val df1 = sc.parallelize(Seq(
+      (1.0, null, Vectors.dense(1.0, 2.0)),
+      (1.0, "1.0", Vectors.dense(1.0, 2.0))
+    )).toDF("label", "str_censor", "features")
+      .select(col("label"), col("str_censor").cast("double").as("censor"), 
col("features"))
+    val e1 = intercept[Exception](new AFTSurvivalRegression().fit(df1))
+    assert(e1.getMessage.contains("Censors MUST NOT be Null or NaN"))
+
+    // censors contains NaN
+    val df2 = sc.parallelize(Seq(
+      (1.0, Double.NaN, Vectors.dense(1.0, 2.0)),
+      (1.0, 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "censor", "features")
+    val e2 = intercept[Exception](new AFTSurvivalRegression().fit(df2))
+    assert(e2.getMessage.contains("Censors MUST NOT be Null or NaN"))
+
+    // censors contains invalid value: 3
+    val df3 = sc.parallelize(Seq(
+      (1.0, 1.0, Vectors.dense(1.0, 2.0)),
+      (1.0, 3.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "censor", "features")
+    val e3 = intercept[Exception](new AFTSurvivalRegression().fit(df3))
+    assert(e3.getMessage.contains("Censors MUST be in {0, 1}"))
+  }
+
   def generateAFTInput(
       numFeatures: Int,
       xMean: Array[Double],
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9cb03454..a7b696e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -53,6 +53,12 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
   // Tests calling train()
   /////////////////////////////////////////////////////////////////////////////
 
+  test("DecisionTreeRegressor validate input dataset") {
+    testInvalidRegressionLabels(new DecisionTreeRegressor().fit(_))
+    testInvalidWeights(new DecisionTreeRegressor().setWeightCol("weight").fit(_))
+    testInvalidVectors(new DecisionTreeRegressor().fit(_))
+  }
+
   test("Regression stump with 3-ary (ordered) categorical features") {
     val dt = new DecisionTreeRegressor()
       .setImpurity("variance")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
index 372432c..d5b7f3c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/FMRegressorSuite.scala
@@ -48,6 +48,11 @@ class FMRegressorSuite extends MLTest with DefaultReadWriteTest {
     ParamsSuite.checkParams(model)
   }
 
+  test("FMRegressor validate input dataset") {
+    testInvalidRegressionLabels(new FMRegressor().fit(_))
+    testInvalidVectors(new FMRegressor().fit(_))
+  }
+
   test("combineCoefficients") {
     val numFeatures = 2
     val factorSize = 4
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 7d84df6..7e96281 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -133,6 +133,12 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
     Utils.deleteRecursively(tempDir)
   }
 
+  test("GBTRegressor validate input dataset") {
+    testInvalidRegressionLabels(new GBTRegressor().fit(_))
+    testInvalidWeights(new GBTRegressor().setWeightCol("weight").fit(_))
+    testInvalidVectors(new GBTRegressor().fit(_))
+  }
+
   test("model support predict leaf index") {
     val model0 = new DecisionTreeRegressionModel("dtc", TreeTests.root0, 3)
     val model1 = new DecisionTreeRegressionModel("dtc", TreeTests.root1, 3)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index bfa9f4b..3acb0bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -211,6 +211,37 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
     assert(model.getLink === "identity")
   }
 
+  test("GeneralizedLinearRegression validate input dataset") {
+    testInvalidRegressionLabels(new GeneralizedLinearRegression().fit(_))
+    testInvalidWeights(new GeneralizedLinearRegression().setWeightCol("weight").fit(_))
+    testInvalidVectors(new GeneralizedLinearRegression().fit(_))
+
+    // offsets contains NULL
+    val df1 = sc.parallelize(Seq(
+      (1.0, null, Vectors.dense(1.0, 2.0)),
+      (1.0, "1.0", Vectors.dense(1.0, 2.0))
+    )).toDF("label", "str_offset", "features")
+      .select(col("label"), col("str_offset").cast("double").as("offset"), 
col("features"))
+    val e1 = intercept[Exception](new 
GeneralizedLinearRegression().setOffsetCol("offset").fit(df1))
+    assert(e1.getMessage.contains("Offsets MUST NOT be Null or NaN"))
+
+    // offsets contains NaN
+    val df2 = sc.parallelize(Seq(
+      (1.0, Double.NaN, Vectors.dense(1.0, 2.0)),
+      (1.0, 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "offset", "features")
+    val e2 = intercept[Exception](new GeneralizedLinearRegression().setOffsetCol("offset").fit(df2))
+    assert(e2.getMessage.contains("Offsets MUST NOT be Null or NaN"))
+
+    // offsets contains Infinity
+    val df3 = sc.parallelize(Seq(
+      (1.0, Double.PositiveInfinity, Vectors.dense(1.0, 2.0)),
+      (1.0, 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "offset", "features")
+    val e3 = intercept[Exception](new GeneralizedLinearRegression().setOffsetCol("offset").fit(df3))
+    assert(e3.getMessage.contains("Offsets MUST NOT be Infinity"))
+  }
+
   test("prediction on single instance") {
     val glr = new GeneralizedLinearRegression
     val model = glr.setFamily("gaussian").setLink("identity")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 18fbbce..3077a60 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.col
 
 class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
 
@@ -101,6 +102,37 @@ class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model.hasParent)
   }
 
+  test("IsotonicRegression validate input dataset") {
+    testInvalidRegressionLabels(new IsotonicRegression().fit(_))
+    testInvalidWeights(new IsotonicRegression().setWeightCol("weight").fit(_))
+    testInvalidVectors(new IsotonicRegression().fit(_))
+
+    // features contains NULL
+    val df1 = sc.parallelize(Seq(
+      (1.0, 1.0, null),
+      (1.0, 1.0, "1.0")
+    )).toDF("label", "weight", "str_features")
+      .select(col("label"), col("weight"), 
col("str_features").cast("double").as("features"))
+    val e1 = intercept[Exception](new IsotonicRegression().fit(df1))
+    assert(e1.getMessage.contains("Features MUST NOT be Null or NaN"))
+
+    // features contains NaN
+    val df2 = sc.parallelize(Seq(
+      (1.0, 1.0, 1.0),
+      (1.0, 1.0, Double.NaN)
+    )).toDF("label", "weight", "features")
+    val e2 = intercept[Exception](new IsotonicRegression().fit(df2))
+    assert(e2.getMessage.contains("Features MUST NOT be Null or NaN"))
+
+    // features contains Infinity
+    val df3 = sc.parallelize(Seq(
+      (1.0, 1.0, 1.0),
+      (1.0, 1.0, Double.PositiveInfinity)
+    )).toDF("label", "weight", "features")
+    val e3 = intercept[Exception](new IsotonicRegression().fit(df3))
+    assert(e3.getMessage.contains("Features MUST NOT be Infinity"))
+  }
+
   test("set parameters") {
     val isotonicRegression = new IsotonicRegression()
       .setIsotonic(false)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index b3098be..e4535f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -188,6 +188,12 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
     assert(model.numFeatures === numFeatures)
   }
 
+  test("LinearRegression validate input dataset") {
+    testInvalidRegressionLabels(new LinearRegression().fit(_))
+    testInvalidWeights(new LinearRegression().setWeightCol("weight").fit(_))
+    testInvalidVectors(new LinearRegression().fit(_))
+  }
+
   test("linear regression: can transform data with LinearRegressionModel") {
     withClue("training related params like loss are only validated during 
fitting phase") {
       val original = new LinearRegression().fit(datasetWithDenseFeature)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7ec30de..4047e6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -57,6 +57,12 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
   // Tests calling train()
   /////////////////////////////////////////////////////////////////////////////
 
+  test("RandomForestRegressor validate input dataset") {
+    testInvalidRegressionLabels(new RandomForestRegressor().fit(_))
+    testInvalidWeights(new RandomForestRegressor().setWeightCol("weight").fit(_))
+    testInvalidVectors(new RandomForestRegressor().fit(_))
+  }
+
   def regressionTestWithContinuousFeatures(rf: RandomForestRegressor): Unit = {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val newRF = rf
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 2b67989..b847c90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -233,6 +233,35 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
     }
   }
 
+  def testInvalidRegressionLabels(f: DataFrame => Any): Unit = {
+    import testImplicits._
+
+    // labels contains NULL
+    val df1 = sc.parallelize(Seq(
+      (null, 1.0, Vectors.dense(1.0, 2.0)),
+      ("1.0", 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("str_label", "weight", "features")
+      .select(col("str_label").cast("double").as("label"), col("weight"), 
col("features"))
+    val e1 = intercept[Exception](f(df1))
+    assert(e1.getMessage.contains("Labels MUST NOT be Null or NaN"))
+
+    // labels contains NaN
+    val df2 = sc.parallelize(Seq(
+      (Double.NaN, 1.0, Vectors.dense(1.0, 2.0)),
+      (1.0, 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "weight", "features")
+    val e2 = intercept[Exception](f(df2))
+    assert(e2.getMessage.contains("Labels MUST NOT be Null or NaN"))
+
+    // labels contains invalid value: Infinity
+    val df3 = sc.parallelize(Seq(
+      (Double.NegativeInfinity, 1.0, Vectors.dense(1.0, 2.0)),
+      (1.0, 1.0, Vectors.dense(1.0, 2.0))
+    )).toDF("label", "weight", "features")
+    val e3 = intercept[Exception](f(df3))
+    assert(e3.getMessage.contains("Labels MUST NOT be Infinity"))
+  }
+
   def testInvalidClassificationLabels(f: DataFrame => Any, numClasses: Option[Int]): Unit = {
     import testImplicits._
 
