Repository: spark
Updated Branches:
  refs/heads/branch-2.0 a6edec2c5 -> ea0cf93d3


[SPARK-16177][ML] model loading backward compatibility for ml.regression

## What changes were proposed in this pull request?
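JIRA: https://issues.apache.org/jira/browse/SPARK-16177

Add model loading backward compatibility for ml.regression. When loading saved AFTSurvivalRegressionModel and LinearRegressionModel data, the `coefficients` column is now passed through `MLUtils.convertVectorColumnsToML` before the model fields are read.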

## How was this patch tested?

Existing unit tests, plus a manual test of loading models saved with Spark 1.6.

Author: Yuhao Yang <hhb...@gmail.com>

Closes #13879 from hhbyyh/regreComp.

(cherry picked from commit 14bc5a7f36bed19cd714a4c725a83feaccac3468)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ea0cf93d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ea0cf93d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ea0cf93d

Branch: refs/heads/branch-2.0
Commit: ea0cf93d3969845e9df8305c0ce54326cdfb2bbd
Parents: a6edec2
Author: Yuhao Yang <hhb...@gmail.com>
Authored: Thu Jun 23 20:43:19 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Jun 23 20:43:29 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/regression/AFTSurvivalRegression.scala  | 9 +++++----
 .../org/apache/spark/ml/regression/LinearRegression.scala   | 8 +++++---
 2 files changed, 10 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ea0cf93d/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2dbac49..7c51845 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
-        .select("coefficients", "intercept", "scale").head()
-      val coefficients = data.getAs[Vector](0)
-      val intercept = data.getDouble(1)
-      val scale = data.getDouble(2)
+      val Row(coefficients: Vector, intercept: Double, scale: Double) =
+        MLUtils.convertVectorColumnsToML(data, "coefficients")
+          .select("coefficients", "intercept", "scale")
+          .head()
       val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, 
intercept, scale)
 
       DefaultParamsReader.getAndSetParams(model, metadata)

http://git-wip-us.apache.org/repos/asf/spark/blob/ea0cf93d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 2723f74..0a4d98c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
-        .select("intercept", "coefficients").head()
-      val intercept = data.getDouble(0)
-      val coefficients = data.getAs[Vector](1)
+      val Row(intercept: Double, coefficients: Vector) =
+        MLUtils.convertVectorColumnsToML(data, "coefficients")
+          .select("intercept", "coefficients")
+          .head()
       val model = new LinearRegressionModel(metadata.uid, coefficients, 
intercept)
 
       DefaultParamsReader.getAndSetParams(model, metadata)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to