Repository: spark Updated Branches: refs/heads/branch-2.0 a6edec2c5 -> ea0cf93d3
[SPARK-16177][ML] model loading backward compatibility for ml.regression ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-16177 model loading backward compatibility for ml.regression ## How was this patch tested? existing unit tests and a manual test for loading 1.6 models. Author: Yuhao Yang <hhb...@gmail.com> Closes #13879 from hhbyyh/regreComp. (cherry picked from commit 14bc5a7f36bed19cd714a4c725a83feaccac3468) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ea0cf93d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ea0cf93d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ea0cf93d Branch: refs/heads/branch-2.0 Commit: ea0cf93d3969845e9df8305c0ce54326cdfb2bbd Parents: a6edec2 Author: Yuhao Yang <hhb...@gmail.com> Authored: Thu Jun 23 20:43:19 2016 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Thu Jun 23 20:43:29 2016 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/regression/AFTSurvivalRegression.scala | 9 +++++---- .../org/apache/spark/ml/regression/LinearRegression.scala | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ea0cf93d/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2dbac49..7c51845 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -33,6 +33,7 @@ import 
org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("coefficients", "intercept", "scale").head() - val coefficients = data.getAs[Vector](0) - val intercept = data.getDouble(1) - val scale = data.getDouble(2) + val Row(coefficients: Vector, intercept: Double, scale: Double) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("coefficients", "intercept", "scale") + .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) DefaultParamsReader.getAndSetParams(model, metadata) http://git-wip-us.apache.org/repos/asf/spark/blob/ea0cf93d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2723f74..0a4d98c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} 
import org.apache.spark.sql.functions._ @@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val Row(intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org