Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/19876#discussion_r156388871
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
---
@@ -994,6 +998,38 @@ class LinearRegressionSuite
LinearRegressionSuite.allParamSettings, checkModelData)
}
+ test("pmml export") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ def checkModel(pmml: PMML): Unit = {
+ val dd = pmml.getDataDictionary
+ assert(dd.getNumberOfFields === 3)
+ val fields = dd.getDataFields.asScala
+ assert(fields(0).getName().toString === "field_0")
+ assert(fields(0).getOpType() == OpType.CONTINUOUS)
+ val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+ val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+ val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+ assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+ assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+ }
+ testPMMLWrite(sc, model, checkModel)
+ }
+
+ test("unsupported export format") {
--- End diff --
It would be great to have a test verifying that this works with third-party
writer implementations — specifically, that something like
`model.write.format("org.apache.spark.ml.MyDummyWriter").save(path)` works.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]