Repository: spark
Updated Branches:
  refs/heads/master b0bdfce9c -> 35db3b9fe
[SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages

## What changes were proposed in this pull request?

Implemented a Python-only persistence framework for pipelines containing stages that cannot be saved using Java.

## How was this patch tested?

Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded the pipeline. The loaded pipeline was compared against the original using _compare_pipelines() in tests.py.

Author: Ajay Saini <ajays...@gmail.com>

Closes #18888 from ajaysaini725/PythonPipelines.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35db3b9f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35db3b9f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35db3b9f

Branch: refs/heads/master
Commit: 35db3b9fe38dadfb8afb0b0857c09f83196398be
Parents: b0bdfce
Author: Ajay Saini <ajays...@gmail.com>
Authored: Fri Aug 11 23:57:08 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Aug 11 23:57:08 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/pipeline.py | 156 +++++++++++++++++++++++++++++++++++--
 python/pyspark/ml/tests.py   |  35 ++++++++-
 2 files changed, 183 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
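For orientation, the workflow this patch enables looks roughly like the sketch below. PythonOnlyShifter is a hypothetical stand-in (not part of this patch) for any Python-only UnaryTransformer that mixes in DefaultParamsReadable/DefaultParamsWritable, and an active SparkSession is assumed:

    from pyspark.ml import Pipeline
    from pyspark.ml.base import UnaryTransformer
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
    from pyspark.sql.types import DoubleType


    class PythonOnlyShifter(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
        """Hypothetical Python-only stage with no Java counterpart."""

        def createTransformFunc(self):
            return lambda x: x + 1.0  # shift the input column by 1.0

        def outputDataType(self):
            return DoubleType()

        def validateInputType(self, inputType):
            if inputType != DoubleType():
                raise TypeError("Bad input type: %s. Requires Double." % inputType)


    shifter = PythonOnlyShifter().setInputCol("value").setOutputCol("shifted")
    pipeline = Pipeline(stages=[shifter])

    # shifter is not JavaMLWritable, so Pipeline.write() now returns the new
    # PipelineWriter instead of routing through JavaMLWriter.
    pipeline.save("/tmp/python_only_pipeline")
    loaded = Pipeline.load("/tmp/python_only_pipeline")
    assert loaded.uid == pipeline.uid

Before this change, save() on such a pipeline could only go through JavaMLWriter and failed for stages without a Java counterpart; the diff below adds the Python-side writer/reader pair and the dispatch between the two.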
http://git-wip-us.apache.org/repos/asf/spark/blob/35db3b9f/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a8dc76b..0975302 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -16,6 +16,7 @@
 #

 import sys
+import os

 if sys.version > '3':
     basestring = str
@@ -23,7 +24,7 @@ if sys.version > '3':

 from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.common import inherit_doc
@@ -130,13 +131,16 @@ class Pipeline(Estimator, MLReadable, MLWritable):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineWriter(self)

     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineReader(cls)

     @classmethod
     def _from_java(cls, java_stage):
@@ -172,6 +176,76 @@ class Pipeline(Estimator, MLReadable, MLWritable):


 @inherit_doc
+class PipelineWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.getStages()
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineModelWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.stages
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineModelReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return PipelineModel(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
 class PipelineModel(Model, MLReadable, MLWritable):
     """
     Represents a compiled pipeline with transformers and fitted models.
@@ -204,13 +278,16 @@ class PipelineModel(Model, MLReadable, MLWritable):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineModelWriter(self)

     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineModelReader(cls)

     @classmethod
     def _from_java(cls, java_stage):
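The two load() methods above pick between the Java and Python code paths by looking for a 'language' marker in the params metadata; the writer side (PipelineSharedReadWrite.saveImpl, in the hunk below) records it via DefaultParamsWriter.saveMetadata(..., paramMap=jsonParams). A Java-written pipeline carries no such key, so loading falls through to JavaMLReader. As a rough sketch, not part of the commit and with invented uid/timestamp values, the persisted metadata for a Python-written pipeline has the following shape, and the dispatch reduces to a dictionary check:

    # Approximate shape of <path>/metadata for a Python-written Pipeline;
    # field names follow DefaultParamsWriter, the values here are invented.
    example_metadata = {
        "class": "pyspark.ml.pipeline.Pipeline",
        "timestamp": 1502521028000,
        "sparkVersion": "2.3.0",
        "uid": "Pipeline_4a1fb2c3d4e5",
        "paramMap": {
            "stageUids": ["PythonOnlyShifter_0123abcd", "Binarizer_4567ef89"],
            "language": "Python",  # absent when written from Java/Scala
        },
    }


    def written_from_python(metadata):
        # Mirrors the condition in PipelineReader.load / PipelineModelReader.load
        return metadata["paramMap"].get("language") == "Python"


    assert written_from_python(example_metadata)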
@@ -242,3 +319,72 @@ class PipelineModel(Model, MLReadable, MLWritable):
             JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)

         return _java_obj
+
+
+@inherit_doc
+class PipelineSharedReadWrite():
+    """
+    .. note:: DeveloperApi
+
+    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+    :py:class:`Pipeline` and :py:class:`PipelineModel`
+
+    .. versionadded:: 2.3.0
+    """
+
+    @staticmethod
+    def checkStagesForJava(stages):
+        return all(isinstance(stage, JavaMLWritable) for stage in stages)
+
+    @staticmethod
+    def validateStages(stages):
+        """
+        Check that all stages are Writable
+        """
+        for stage in stages:
+            if not isinstance(stage, MLWritable):
+                raise ValueError("Pipeline write will fail on this pipeline " +
+                                 "because stage %s of type %s is not MLWritable"
+                                 % (stage.uid, type(stage)))
+
+    @staticmethod
+    def saveImpl(instance, stages, sc, path):
+        """
+        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+        - save metadata to path/metadata
+        - save stages to stages/IDX_UID
+        """
+        stageUids = [stage.uid for stage in stages]
+        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
+        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        stagesDir = os.path.join(path, "stages")
+        for index, stage in enumerate(stages):
+            stage.write().save(PipelineSharedReadWrite
+                               .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+    @staticmethod
+    def load(metadata, sc, path):
+        """
+        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+        :return: (UID, list of stages)
+        """
+        stagesDir = os.path.join(path, "stages")
+        stageUids = metadata['paramMap']['stageUids']
+        stages = []
+        for index, stageUid in enumerate(stageUids):
+            stagePath = \
+                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stages.append(stage)
+        return (metadata['uid'], stages)
+
+    @staticmethod
+    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+        """
+        Get path for saving the given stage.
+        """
+        stageIdxDigits = len(str(numStages))
+        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
+        stagePath = os.path.join(stagesDir, stageDir)
+        return stagePath
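To make saveImpl's stages/IDX_UID layout concrete, here is a standalone restatement of the getStagePath scheme (not part of the commit; the uids are invented). Zero-padding the stage index to len(str(numStages)) digits keeps a lexicographic directory listing in stage order even when a pipeline has ten or more stages:

    import os


    def stage_path(stage_uid, stage_idx, num_stages, stages_dir):
        # Same index-padding rule as PipelineSharedReadWrite.getStagePath
        digits = len(str(num_stages))
        return os.path.join(stages_dir, str(stage_idx).zfill(digits) + "_" + stage_uid)


    uids = ["PythonOnlyShifter_0123abcd", "Binarizer_4567ef89"]
    for idx, uid in enumerate(uids):
        print(stage_path(uid, idx, len(uids), "/tmp/python_only_pipeline/stages"))
    # /tmp/python_only_pipeline/stages/0_PythonOnlyShifter_0123abcd
    # /tmp/python_only_pipeline/stages/1_Binarizer_4567ef89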
".format(inputType) + - "Requires Integer.") + "Requires Double.") class MockEstimator(Estimator, HasFake): @@ -1063,7 +1063,7 @@ class PersistenceTest(SparkSessionTestCase): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams): + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self._compare_params(m1, m2, p) @@ -1142,6 +1142,35 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + def test_onevsrest(self): temp_path = tempfile.mkdtemp() df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org