Repository: spark Updated Branches: refs/heads/master aba9492d2 -> 900f14f6f
[SPARK-21729][ML][TEST] Generic test for ProbabilisticClassifier to ensure consistent output columns ## What changes were proposed in this pull request? Add a generic test that runs prediction with the model under every combination of output columns (rawPrediction, probability, prediction) turned on/off, and checks that the output column values match by comparing them against the case with all 3 output columns turned on. ## How was this patch tested? Updated the existing classifier test suites to call the new helper. Author: WeichenXu <weichen...@databricks.com> Author: WeichenXu <weichenxu...@outlook.com> Closes #19065 from WeichenXu123/generic_test_for_prob_classifier. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/900f14f6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/900f14f6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/900f14f6 Branch: refs/heads/master Commit: 900f14f6fad50369aa849922447f60d7cf06cf2f Parents: aba9492 Author: WeichenXu <weichen...@databricks.com> Authored: Fri Sep 1 17:32:33 2017 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Fri Sep 1 17:32:33 2017 -0700 ---------------------------------------------------------------------- .../DecisionTreeClassifierSuite.scala | 3 + .../ml/classification/GBTClassifierSuite.scala | 3 + .../LogisticRegressionSuite.scala | 6 ++ .../MultilayerPerceptronClassifierSuite.scala | 2 + .../ml/classification/NaiveBayesSuite.scala | 6 ++ .../ProbabilisticClassifierSuite.scala | 60 ++++++++++++++++++++ .../RandomForestClassifierSuite.scala | 2 + 7 files changed, 82 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 918ab27..98c879e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, "probability prediction mismatch") } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, DecisionTreeClassificationModel](newTree, newData) } test("training with 1-category categorical feature") { http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 1f79e0d..8000143 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, GBTClassificationModel](gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 6bf1253..d43c7cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -502,6 +502,9 @@ class LogisticRegressionSuite resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, LogisticRegressionModel](model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -556,6 +559,9 @@ class LogisticRegressionSuite resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { case (pred1, pred2) => assert(pred1 === pred2) } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, LogisticRegressionModel](model, smallBinaryDataset) } test("coefficients and intercept methods") { 
http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index c294e4a..d3141ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, MultilayerPerceptronClassificationModel](model, strongDataset) } test("test model probability") { http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 3a2be23..9730dd6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val featureAndProbabilities = model.transform(validationDataset) .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "multinomial") + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, NaiveBayesModel](model, testDataset) } test("Naive Bayes with weighted samples") { @@ -213,6 +216,9 @@ class NaiveBayesSuite 
extends SparkFunSuite with MLlibTestSparkContext with Defa val featureAndProbabilities = model.transform(validationDataset) .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, NaiveBayesModel](model, testDataset) } test("detect negative values") { http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 172c64a..4ecd5a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.{Dataset, Row} final class TestProbabilisticClassificationModel( override val uid: String, @@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite { "thresholds" -> Array(0.4, 0.6) ) + /** + * Helper for testing that a ProbabilisticClassificationModel computes + * the same predictions across all combinations of output columns + * (rawPrediction/probability/prediction) turned on/off. Makes sure the + * output column values match by comparing vs. the case with all 3 output + * columns turned on. 
+ */ + def testPredictMethods[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]]( + model: M, testData: Dataset[_]): Unit = { + + val allColModel = model.copy(ParamMap.empty) + .setRawPredictionCol("rawPredictionAll") + .setProbabilityCol("probabilityAll") + .setPredictionCol("predictionAll") + val allColResult = allColModel.transform(testData) + + for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { + for (probabilityCol <- Seq("", "probabilitySingle")) { + for (predictionCol <- Seq("", "predictionSingle")) { + val newModel = model.copy(ParamMap.empty) + .setRawPredictionCol(rawPredictionCol) + .setProbabilityCol(probabilityCol) + .setPredictionCol(predictionCol) + + val result = newModel.transform(allColResult) + + import org.apache.spark.sql.functions._ + + val resultRawPredictionCol = + if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) + val resultProbabilityCol = + if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) + val resultPredictionCol = + if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol) + + result.select( + resultRawPredictionCol, col("rawPredictionAll"), + resultProbabilityCol, col("probabilityAll"), + resultPredictionCol, col("predictionAll") + ).collect().foreach { + case Row( + rawPredictionSingle: Vector, rawPredictionAll: Vector, + probabilitySingle: Vector, probabilityAll: Vector, + predictionSingle: Double, predictionAll: Double + ) => { + assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3) + assert(probabilitySingle ~== probabilityAll relTol 1E-3) + assert(predictionSingle ~== predictionAll relTol 1E-3) + } + } + } + } + } + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/900f14f6/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ca2954d..2cca2e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,6 +155,8 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, RandomForestClassificationModel](model, df) } test("Fitting without numClasses in metadata") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org