Github user holdenk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19876#discussion_r161391241
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
---
@@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
LinearRegressionSuite.allParamSettings, checkModelData)
}
+ test("pmml export") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ def checkModel(pmml: PMML): Unit = {
+ val dd = pmml.getDataDictionary
+ assert(dd.getNumberOfFields === 3)
+ val fields = dd.getDataFields.asScala
+ assert(fields(0).getName().toString === "field_0")
+ assert(fields(0).getOpType() == OpType.CONTINUOUS)
+ val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+ val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+ val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+ assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+ assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+ }
+ testPMMLWrite(sc, model, checkModel)
+ }
+
+ test("unsupported export format") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ intercept[SparkException] {
+ model.write.format("boop").save("boop")
+ }
+ intercept[SparkException] {
+ model.write.format("com.holdenkarau.boop").save("boop")
+ }
+ withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") {
+ intercept[SparkException] {
+ model.write.format("org.apache.spark.SparkContext").save("boop2")
+ }
+ }
+ }
+
+ test("dummy export format is called") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ withClue("Dummy writer doesn't write") {
+ intercept[Exception] {
--- End diff --
good point
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org