spark git commit: [SPARK-11612][ML] Pipeline and PipelineModel persistence
Repository: spark Updated Branches: refs/heads/branch-1.6 32a69e4c1 -> 505eceef3 [SPARK-11612][ML] Pipeline and PipelineModel persistence Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable. Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed. CC: mengxr Author: Joseph K. Bradley Closes #9674 from jkbradley/pipeline-io. (cherry picked from commit 1c5475f1401d2233f4c61f213d1e2c2ee9673067) Signed-off-by: Joseph K. Bradley Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/505eceef Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/505eceef Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/505eceef Branch: refs/heads/branch-1.6 Commit: 505eceef303e3291253b35164fbec7e4390e8252 Parents: 32a69e4 Author: Joseph K. Bradley Authored: Mon Nov 16 17:12:39 2015 -0800 Committer: Joseph K. Bradley Committed: Mon Nov 16 17:12:48 2015 -0800 -- .../scala/org/apache/spark/ml/Pipeline.scala| 175 ++- .../org/apache/spark/ml/util/ReadWrite.scala| 4 +- .../org/apache/spark/ml/PipelineSuite.scala | 120 - .../spark/ml/util/DefaultReadWriteTest.scala| 25 +-- 4 files changed, 306 insertions(+), 18 deletions(-) -- http://git-wip-us.apache.org/repos/asf/spark/blob/505eceef/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala -- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e5940..25f0c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,12 +22,19 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Reader +import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + override def write: Writer = new Pipeline.PipelineWriter(this) +} + +object Pipeline extends Readable[Pipeline] { + + override def read: Reader[Pipeline] = new PipelineReader + + override def load(path: String): Pipeline = read.load(path) + + private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + +SharedReadWrite.validateStages(instance.getStages) + +override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private[ml] class PipelineReader extends Reader[Pipeline] { + +/** Checked against metadata when loading model */ +private val className = "org.apache.spark.ml.Pipeline" + +override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) +} + } + + /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + +import org.json4s.JsonDSL._ + +/** Check that all stages are Writable */ +def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { +case stage: Writable => // good +case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + +s" because it contains a stage which does not implement Writable. Non-Writable stage:" + +s" ${other.uid} of type ${other.getClass}") + } +} + +/** + * Save metadata and stages for a [[P
spark git commit: [SPARK-11612][ML] Pipeline and PipelineModel persistence
Repository: spark Updated Branches: refs/heads/master bd10eb81c -> 1c5475f14 [SPARK-11612][ML] Pipeline and PipelineModel persistence Pipeline and PipelineModel extend Readable and Writable. Persistence succeeds only when all stages are Writable. Note: This PR reinstates tests for other read/write functionality. It should probably not get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed. CC: mengxr Author: Joseph K. Bradley Closes #9674 from jkbradley/pipeline-io. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c5475f1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c5475f1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c5475f1 Branch: refs/heads/master Commit: 1c5475f1401d2233f4c61f213d1e2c2ee9673067 Parents: bd10eb8 Author: Joseph K. Bradley Authored: Mon Nov 16 17:12:39 2015 -0800 Committer: Joseph K. Bradley Committed: Mon Nov 16 17:12:39 2015 -0800 -- .../scala/org/apache/spark/ml/Pipeline.scala| 175 ++- .../org/apache/spark/ml/util/ReadWrite.scala| 4 +- .../org/apache/spark/ml/PipelineSuite.scala | 120 - .../spark/ml/util/DefaultReadWriteTest.scala| 25 +-- 4 files changed, 306 insertions(+), 18 deletions(-) -- http://git-wip-us.apache.org/repos/asf/spark/blob/1c5475f1/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala -- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e5940..25f0c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,12 +22,19 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Reader +import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { def this() = this(Identifiable.randomUID("pipeline")) @@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + override def write: Writer = new Pipeline.PipelineWriter(this) +} + +object Pipeline extends Readable[Pipeline] { + + override def read: Reader[Pipeline] = new PipelineReader + + override def load(path: String): Pipeline = read.load(path) + + private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + +SharedReadWrite.validateStages(instance.getStages) + +override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private[ml] class PipelineReader extends Reader[Pipeline] { + +/** Checked against metadata when loading model */ +private val className = "org.apache.spark.ml.Pipeline" + +override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) +} + } + + /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + +import org.json4s.JsonDSL._ + +/** Check that all stages are Writable */ +def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { +case stage: Writable => // good +case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + +s" because it contains a stage which does not implement Writable. Non-Writable stage:" + +s" ${other.uid} of type ${other.getClass}") + } +} + +/** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UI