Repository: spark
Updated Branches:
refs/heads/master 816963043 -> 2ea17afb6
[SPARK-22881][ML][TEST] ML regression package testsuite add StructuredStreaming test
## What changes were proposed in this pull request?
This PR adds StructuredStreaming tests to the ML regression package test suite.
To make the test suites easier to modify, a new helper function is added to
`MLTest`:
```
def testTransformerByGlobalCheckFunc[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(globalCheckFunction: Seq[Row] => Unit): Unit
```
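As a usage sketch (adapted from the IsotonicRegressionSuite change in this
patch; `dataset` and `model` are assumed to come from the enclosing test), a
suite passes the result column names plus a function that checks all collected
output rows at once, and the helper exercises the transformer on both the
batch DataFrame and a streaming MemoryStream source:
```
// Sketch only: `dataset` and `model` come from the enclosing test.
// The same global assertion runs against the batch output and the
// streaming output, so one test covers both execution paths.
testTransformerByGlobalCheckFunc[(Double, Double, Double)](dataset, model,
  "prediction") { rows: Seq[Row] =>
  val predictions = rows.map(_.getDouble(0))
  assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
}
```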
## How was this patch tested?
N/A
Author: WeichenXu <[email protected]>
Author: Bago Amirbekian <[email protected]>
Closes #19979 from WeichenXu123/ml_stream_test.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ea17afb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ea17afb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ea17afb
Branch: refs/heads/master
Commit: 2ea17afb63f976500273518bf1b32f9efe250812
Parents: 8169630
Author: WeichenXu <[email protected]>
Authored: Fri Dec 29 20:06:56 2017 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri Dec 29 20:06:56 2017 -0800
----------------------------------------------------------------------
.../regression/AFTSurvivalRegressionSuite.scala | 19 ++++----
.../regression/DecisionTreeRegressorSuite.scala | 43 +++++++++---------
.../spark/ml/regression/GBTRegressorSuite.scala | 23 +++++-----
.../GeneralizedLinearRegressionSuite.scala | 47 ++++++++++----------
.../ml/regression/IsotonicRegressionSuite.scala | 43 +++++++-----------
.../ml/regression/LinearRegressionSuite.scala | 25 ++++++-----
.../scala/org/apache/spark/ml/util/MLTest.scala | 39 ++++++++++++----
.../org/apache/spark/ml/util/MLTestSuite.scala | 12 ++++-
.../apache/spark/sql/streaming/StreamTest.scala | 27 ++++++-----
9 files changed, 147 insertions(+), 131 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 02e5c6d..4e4ff71 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -19,19 +19,16 @@ package org.apache.spark.ml.regression
import scala.util.Random
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
-class AFTSurvivalRegressionSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -191,8 +188,8 @@ class AFTSurvivalRegressionSuite
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
- .collect().foreach {
+ testTransformer[(Vector, Double, Double)](datasetUnivariate, model,
+ "features", "prediction", "quantiles") {
case Row(features: Vector, prediction: Double, quantiles: Vector) =>
assert(prediction ~== model.predict(features) relTol 1E-5)
assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -261,8 +258,8 @@ class AFTSurvivalRegressionSuite
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
- .collect().foreach {
+ testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+ "features", "prediction", "quantiles") {
case Row(features: Vector, prediction: Double, quantiles: Vector) =>
assert(prediction ~== model.predict(features) relTol 1E-5)
assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -331,8 +328,8 @@ class AFTSurvivalRegressionSuite
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
- .collect().foreach {
+ testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+ "features", "prediction", "quantiles") {
case Row(features: Vector, prediction: Double, quantiles: Vector) =>
assert(prediction ~== model.predict(features) relTol 1E-5)
assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 642f266..68a1218 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -21,19 +21,18 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-class DecisionTreeRegressorSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
import DecisionTreeRegressorSuite.compareAPIs
+ import testImplicits._
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
@@ -89,14 +88,11 @@ class DecisionTreeRegressorSuite
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
val model = dt.fit(df)
- val predictions = model.transform(df)
- .select(model.getFeaturesCol, model.getVarianceCol)
- .collect()
-
- predictions.foreach { case Row(features: Vector, variance: Double) =>
- val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
- assert(variance === expectedVariance,
- s"Expected variance $expectedVariance but got $variance.")
+ testTransformer[(Vector, Double)](df, model, "features", "variance") {
+ case Row(features: Vector, variance: Double) =>
+ val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+ assert(variance === expectedVariance,
+ s"Expected variance $expectedVariance but got $variance.")
}
val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
@@ -104,18 +100,19 @@ class DecisionTreeRegressorSuite
dt.setMaxDepth(1)
.setMaxBins(6)
.setSeed(0)
- val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
- val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
- case Row(variance: Double) => variance
- }
- // Since max depth is set to 1, the best split point is that which splits the data
- // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
- // data point in the left node is 0.667 and for each data point in the right node
- // is 2.667
- val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
- calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
- assert(actual ~== expected absTol 1e-3)
+ testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
+ "variance") { case rows: Seq[Row] =>
+ val calculatedVariances = rows.map(_.getDouble(0))
+
+ // Since max depth is set to 1, the best split point is that which splits the data
+ // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
+ // data point in the left node is 0.667 and for each data point in the right node
+ // is 2.667
+ val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+ calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+ assert(actual ~== expected absTol 1e-3)
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index ecbb571..11c593b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,22 +19,20 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.Utils
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
import GBTRegressorSuite.compareAPIs
import testImplicits._
@@ -91,11 +89,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val model = gbt.fit(df)
MLTestingUtils.checkCopyAndUids(gbt, model)
- val preds = model.transform(df)
- val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
- // Checks based on SPARK-8736 (to ensure it is not doing classification)
- assert(predictions.max() > 2)
- assert(predictions.min() < -1)
+
+ testTransformerByGlobalCheckFunc[(Double, Vector)](df, model, "prediction") {
+ case rows: Seq[Row] =>
+ val predictions = rows.map(_.getDouble(0))
+ // Checks based on SPARK-8736 (to ensure it is not doing classification)
+ assert(predictions.max > 2)
+ assert(predictions.min < -1)
+ }
}
test("Checkpointing") {
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index df7dee8..ef2ff94 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.feature.{LabeledPoint, RFormula}
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -33,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType
-class GeneralizedLinearRegressionSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -268,8 +267,8 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = FamilyAndLink(trainer)
- model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Vector)](dataset, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -278,7 +277,7 @@ class GeneralizedLinearRegressionSuite
s"gaussian family, $link link and fitIntercept =
$fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with gaussian family, $link link and fitIntercept =
$fitIntercept.")
- }
+ }
idx += 1
}
@@ -384,8 +383,8 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = FamilyAndLink(trainer)
- model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Vector)](dataset, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -394,7 +393,7 @@ class GeneralizedLinearRegressionSuite
s"binomial family, $link link and fitIntercept =
$fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with binomial family, $link link and fitIntercept =
$fitIntercept.")
- }
+ }
idx += 1
}
@@ -456,8 +455,8 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = FamilyAndLink(trainer)
- model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Vector)](dataset, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -466,7 +465,7 @@ class GeneralizedLinearRegressionSuite
s"poisson family, $link link and fitIntercept =
$fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with poisson family, $link link and fitIntercept =
$fitIntercept.")
- }
+ }
idx += 1
}
@@ -562,8 +561,8 @@ class GeneralizedLinearRegressionSuite
s"$link link and fitIntercept = $fitIntercept.")
val familyLink = FamilyAndLink(trainer)
- model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Vector)](dataset, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -572,7 +571,7 @@ class GeneralizedLinearRegressionSuite
s"gamma family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with gamma family, $link link and fitIntercept =
$fitIntercept.")
- }
+ }
idx += 1
}
@@ -649,8 +648,8 @@ class GeneralizedLinearRegressionSuite
s"and variancePower = $variancePower.")
val familyLink = FamilyAndLink(trainer)
- model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -661,7 +660,8 @@ class GeneralizedLinearRegressionSuite
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with tweedie family, linkPower = $linkPower, fitIntercept
= $fitIntercept " +
s"and variancePower = $variancePower.")
- }
+ }
+
idx += 1
}
}
@@ -724,8 +724,8 @@ class GeneralizedLinearRegressionSuite
s"fitIntercept = $fitIntercept and variancePower = $variancePower.")
val familyLink = FamilyAndLink(trainer)
- model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
- .foreach {
+ testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+ "features", "prediction", "linkPrediction") {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
@@ -736,7 +736,8 @@ class GeneralizedLinearRegressionSuite
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link
Prediction mismatch: " +
s"GLM with tweedie family, fitIntercept = $fitIntercept " +
s"and variancePower = $variancePower.")
- }
+ }
+
idx += 1
}
}
@@ -861,8 +862,8 @@ class GeneralizedLinearRegressionSuite
s" and fitIntercept = $fitIntercept.")
val familyLink = FamilyAndLink(trainer)
- model.transform(dataset).select("features", "offset", "prediction",
"linkPrediction")
- .collect().foreach {
+ testTransformer[(Double, Double, Double, Vector)](dataset, model,
+ "features", "offset", "prediction", "linkPrediction") {
case Row(features: DenseVector, offset: Double, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 180f5f7..18fbbce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -17,15 +17,12 @@
package org.apache.spark.ml.regression
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.{DataFrame, Row}
-class IsotonicRegressionSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -44,13 +41,11 @@ class IsotonicRegressionSuite
val model = ir.fit(dataset)
- val predictions = model
- .transform(dataset)
- .select("prediction").rdd.map { case Row(pred) =>
- pred
- }.collect()
-
- assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+ testTransformerByGlobalCheckFunc[(Double, Double, Double)](dataset, model,
+ "prediction") { case rows: Seq[Row] =>
+ val predictions = rows.map(_.getDouble(0))
+ assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+ }
assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8))
assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
@@ -64,13 +59,11 @@ class IsotonicRegressionSuite
val model = ir.fit(dataset)
val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
- val predictions = model
- .transform(features)
- .select("prediction").rdd.map {
- case Row(pred) => pred
- }.collect()
-
- assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+ testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+ "prediction") { case rows: Seq[Row] =>
+ val predictions = rows.map(_.getDouble(0))
+ assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+ }
}
test("params validation") {
@@ -157,13 +150,11 @@ class IsotonicRegressionSuite
val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0))
- val predictions = model
- .transform(features)
- .select("prediction").rdd.map {
- case Row(pred) => pred
- }.collect()
-
- assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+ testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+ "prediction") { case rows: Seq[Row] =>
+ val predictions = rows.map(_.getDouble(0))
+ assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+ }
}
test("read/write") {
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 9bb2895..d42cb17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.ml.regression
import scala.util.Random
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
@@ -363,8 +362,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 relTol 1E-3)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-3)
- model1.transform(datasetWithDenseFeature).select("features", "prediction")
- .collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -416,8 +415,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 absTol 1E-2)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
- model1.transform(datasetWithDenseFeature).select("features", "prediction")
- .collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -467,7 +466,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 relTol 1E-2)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
- model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -518,7 +518,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 absTol 1E-2)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
- model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -570,8 +571,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 relTol 1E-2)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
- model1.transform(datasetWithDenseFeature).select("features", "prediction")
- .collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -624,8 +625,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model2.intercept ~== interceptR2 absTol 1E-2)
assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
- model1.transform(datasetWithDenseFeature).select("features", "prediction")
- .collect().foreach {
+ testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+ "features", "prediction") {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 7a5426e..17678aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -53,12 +53,12 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
}
}
- def testTransformerOnStreamData[A : Encoder](
+ private[util] def testTransformerOnStreamData[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
- (checkFunction: Row => Unit): Unit = {
+ (globalCheckFunction: Seq[Row] => Unit): Unit = {
val columnNames = dataframe.schema.fieldNames
val stream = MemoryStream[A]
@@ -70,22 +70,43 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
.select(firstResultCol, otherResultCols: _*)
testStream(streamOutput) (
AddData(stream, data: _*),
- CheckAnswer(checkFunction)
+ CheckAnswer(globalCheckFunction)
)
}
+ private[util] def testTransformerOnDF(
+ dataframe: DataFrame,
+ transformer: Transformer,
+ firstResultCol: String,
+ otherResultCols: String*)
+ (globalCheckFunction: Seq[Row] => Unit): Unit = {
+ val dfOutput = transformer.transform(dataframe)
+ val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect()
+ globalCheckFunction(outputs)
+ }
+
def testTransformer[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit): Unit = {
- testTransformerOnStreamData(dataframe, transformer, firstResultCol,
- otherResultCols: _*)(checkFunction)
+ testTransformerByGlobalCheckFunc(
+ dataframe,
+ transformer,
+ firstResultCol,
+ otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) }
+ }
- val dfOutput = transformer.transform(dataframe)
- dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
- checkFunction(row)
- }
+ def testTransformerByGlobalCheckFunc[A : Encoder](
+ dataframe: DataFrame,
+ transformer: Transformer,
+ firstResultCol: String,
+ otherResultCols: String*)
+ (globalCheckFunction: Seq[Row] => Unit): Unit = {
+ testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+ otherResultCols: _*)(globalCheckFunction)
+ testTransformerOnDF(dataframe, transformer, firstResultCol,
+ otherResultCols: _*)(globalCheckFunction)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
index 56217ec..20c5b53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.util
-import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row
@@ -32,10 +31,13 @@ class MLTestSuite extends MLTest {
val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
.setInputCol("label").setOutputCol("indexed")
val indexerModel = indexer.fit(data)
- testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+ testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id === indexed.toInt)
}
+ testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
+ assert(rows.map(_.getDouble(1)).max === 5.0)
+ }
intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
@@ -43,5 +45,11 @@ class MLTestSuite extends MLTest {
assert(id != indexed.toInt)
}
}
+ intercept[Exception] {
+ testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+ rows: Seq[Row] =>
+ assert(rows.map(_.getDouble(1)).max === 1.0)
+ }
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ea17afb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index fb9ebc8..4b7f0fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -137,8 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
- def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
- CheckAnswerRowsByFunc(checkFunction, false)
+ def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+ CheckAnswerRowsByFunc(globalCheckFunction, false)
}
/**
@@ -161,8 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
- def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
- CheckAnswerRowsByFunc(checkFunction, true)
+ def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+ CheckAnswerRowsByFunc(globalCheckFunction, true)
}
case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -177,9 +177,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}
- case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
- extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+ case class CheckAnswerRowsByFunc(
+ globalCheckFunction: Seq[Row] => Unit,
+ lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
+ override def toString: String = s"$operatorName"
private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
}
@@ -639,14 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
error => failTest(error)
}
- case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+ case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
- sparkAnswer.foreach { row =>
- try {
- checkFunction(row)
- } catch {
- case e: Throwable => failTest(e.toString)
- }
+ try {
+ globalCheckFunction(sparkAnswer)
+ } catch {
+ case e: Throwable => failTest(e.toString)
}
}
pos += 1