Github user holdenk commented on a diff in the pull request:
https://github.com/apache/spark/pull/19876#discussion_r177202362
--- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
@@ -710,15 +711,58 @@ class LinearRegressionModel private[ml] (
}
/**
- * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+ * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
* This also does not save the [[parent]] currently.
*/
@Since("1.6.0")
- override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
+ override def write: GeneralMLWriter = new GeneralMLWriter(this)
+}
+
+/** A writer for LinearRegression that handles the "internal" (or default) format */
+private class InternalLinearRegressionModelWriter
+ extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "internal"
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+ private case class Data(intercept: Double, coefficients: Vector, scale: Double)
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[LinearRegressionModel]
+ val sc = sparkSession.sparkContext
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: intercept, coefficients, scale
+ val data = Data(instance.intercept, instance.coefficients, instance.scale)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+}
+
+/** A writer for LinearRegression that handles the "pmml" format */
+private class PMMLLinearRegressionModelWriter
+ extends MLWriterFormat with MLFormatRegister {
--- End diff --
I've included this in https://github.com/apache/spark/pull/20907
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]