GitHub user holdenk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19876#discussion_r158113236
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala ---
    @@ -994,6 +998,38 @@ class LinearRegressionSuite
           LinearRegressionSuite.allParamSettings, checkModelData)
       }
     
    +  test("pmml export") {
    +    val lr = new LinearRegression()
    +    val model = lr.fit(datasetWithWeight)
    +    def checkModel(pmml: PMML): Unit = {
    +      val dd = pmml.getDataDictionary
    +      assert(dd.getNumberOfFields === 3)
    +      val fields = dd.getDataFields.asScala
    +      assert(fields(0).getName().toString === "field_0")
    +      assert(fields(0).getOpType() == OpType.CONTINUOUS)
    +      val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
    +      val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
    +      val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
    +      assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
    +      assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
    +    }
    +    testPMMLWrite(sc, model, checkModel)
    +  }
    +
    +  test("unsupported export format") {
    --- End diff --
    
    Sure, I'll put a dummy writer in test so it doesn't clog up our class space.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to