Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/19876#discussion_r160484001
--- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
---
@@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
LinearRegressionSuite.allParamSettings, checkModelData)
}
+ test("pmml export") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ def checkModel(pmml: PMML): Unit = {
+ val dd = pmml.getDataDictionary
+ assert(dd.getNumberOfFields === 3)
+ val fields = dd.getDataFields.asScala
+ assert(fields(0).getName().toString === "field_0")
+ assert(fields(0).getOpType() == OpType.CONTINUOUS)
+ val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+ val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+ val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+ assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+ assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+ }
+ testPMMLWrite(sc, model, checkModel)
+ }
+
+ test("unsupported export format") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ intercept[SparkException] {
--- End diff --
Don't this test and the one below it check the same thing? If so, I think we
could remove the first of the two.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]