Repository: spark
Updated Branches:
  refs/heads/master baf5cac0f -> fdcee028a


[SPARK-21542][ML][PYTHON] Python persistence helper functions

## What changes were proposed in this pull request?

Added DefaultParamsWritable, DefaultParamsReadable, DefaultParamsWriter, and 
DefaultParamsReader to Python to support Python-only persistence of 
JSON-serializable parameters.

## How was this patch tested?

Instantiated an estimator with JSON-serializable parameters (e.g., 
LogisticRegression), saved it using the added helper functions, loaded it 
back, and compared it to the original instance to make sure the two are the same. 
This test was both run in the Python REPL and implemented in the unit tests.

Note to reviewers: there are a few extra comments that I left in the code for 
clarity; I will remove them before the code is merged to master.

Author: Ajay Saini <ajays...@gmail.com>

Closes #18742 from ajaysaini725/PythonPersistenceHelperFunctions.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fdcee028
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fdcee028
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fdcee028

Branch: refs/heads/master
Commit: fdcee028afa7a7ac0f8bd8f59ee4933d7caea064
Parents: baf5cac
Author: Ajay Saini <ajays...@gmail.com>
Authored: Mon Aug 7 17:03:20 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Aug 7 17:03:20 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/util/ReadWrite.scala    |  37 ++-
 python/pyspark/ml/param/__init__.py             |  11 +
 python/pyspark/ml/pipeline.py                   |  10 -
 python/pyspark/ml/tests.py                      |  34 +++
 python/pyspark/ml/util.py                       | 302 +++++++++++++++++--
 5 files changed, 342 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fdcee028/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index b54e258..65f142c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -96,21 +96,7 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   @Since("1.6.0")
   @throws[IOException]("If the input path already exists but overwrite is not 
enabled.")
   def save(path: String): Unit = {
-    val hadoopConf = sc.hadoopConfiguration
-    val outputPath = new Path(path)
-    val fs = outputPath.getFileSystem(hadoopConf)
-    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, 
fs.getWorkingDirectory)
-    if (fs.exists(qualifiedOutputPath)) {
-      if (shouldOverwrite) {
-        logInfo(s"Path $path already exists. It will be overwritten.")
-        // TODO: Revert back to the original content if save is not successful.
-        fs.delete(qualifiedOutputPath, true)
-      } else {
-        throw new IOException(s"Path $path already exists. To overwrite it, " +
-          s"please use write.overwrite().save(path) for Scala and use " +
-          s"write().overwrite().save(path) for Java and Python.")
-      }
-    }
+    new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc)
     saveImpl(path)
   }
 
@@ -471,3 +457,24 @@ private[ml] object MetaAlgorithmReadWrite {
     List((instance.uid, instance)) ++ subStageMaps
   }
 }
+
+private[ml] class FileSystemOverwrite extends Logging {
+
+  def handleOverwrite(path: String, shouldOverwrite: Boolean, sc: 
SparkContext): Unit = {
+    val hadoopConf = sc.hadoopConfiguration
+    val outputPath = new Path(path)
+    val fs = outputPath.getFileSystem(hadoopConf)
+    val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, 
fs.getWorkingDirectory)
+    if (fs.exists(qualifiedOutputPath)) {
+      if (shouldOverwrite) {
+        logInfo(s"Path $path already exists. It will be overwritten.")
+        // TODO: Revert back to the original content if save is not successful.
+        fs.delete(qualifiedOutputPath, true)
+      } else {
+        throw new IOException(s"Path $path already exists. To overwrite it, " +
+          s"please use write.overwrite().save(path) for Scala and use " +
+          s"write().overwrite().save(path) for Java and Python.")
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fdcee028/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py 
b/python/pyspark/ml/param/__init__.py
index 4583ae8..1334207 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -384,6 +384,17 @@ class Params(Identifiable):
         that._defaultParamMap = {}
         return self._copyValues(that, extra)
 
+    def set(self, param, value):
+        """
+        Sets a parameter in the embedded param map.
+        """
+        self._shouldOwn(param)
+        try:
+            value = param.typeConverter(value)
+        except ValueError as e:
+            raise ValueError('Invalid param value given for param "%s". %s' % 
(param.name, e))
+        self._paramMap[param] = value
+
     def _shouldOwn(self, param):
         """
         Validates that the input param belongs to this Params instance.

http://git-wip-us.apache.org/repos/asf/spark/blob/fdcee028/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 4aac6a4..a8dc76b 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -132,11 +132,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         """Returns an MLWriter instance for this ML instance."""
         return JavaMLWriter(self)
 
-    @since("2.0.0")
-    def save(self, path):
-        """Save this ML instance to the given path, a shortcut of 
`write().save(path)`."""
-        self.write().save(path)
-
     @classmethod
     @since("2.0.0")
     def read(cls):
@@ -211,11 +206,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
         """Returns an MLWriter instance for this ML instance."""
         return JavaMLWriter(self)
 
-    @since("2.0.0")
-    def save(self, path):
-        """Save this ML instance to the given path, a shortcut of 
`write().save(path)`."""
-        self.write().save(path)
-
     @classmethod
     @since("2.0.0")
     def read(cls):

http://git-wip-us.apache.org/repos/asf/spark/blob/fdcee028/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3bd4d37..6aecc7f 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -62,6 +62,7 @@ from pyspark.ml.regression import DecisionTreeRegressor, 
GeneralizedLinearRegres
     LinearRegression
 from pyspark.ml.stat import ChiSquareTest
 from pyspark.ml.tuning import *
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
@@ -377,6 +378,12 @@ class ParamTests(PySparkTestCase):
         with self.assertRaises(KeyError):
             testParams.getInputCol()
 
+        otherParam = Param(Params._dummy(), "otherParam", "Parameter used to 
test that " +
+                           "set raises an error for a non-member parameter.",
+                           typeConverter=TypeConverters.toString)
+        with self.assertRaises(ValueError):
+            testParams.set(otherParam, "value")
+
         # Since the default is normally random, set it to a known number for 
debug str
         testParams._setDefault(seed=41)
         testParams.setSeed(43)
@@ -1189,6 +1196,33 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def test_default_read_write(self):
+        temp_path = tempfile.mkdtemp()
+
+        lr = LogisticRegression()
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+        writer = DefaultParamsWriter(lr)
+
+        savePath = temp_path + "/lr"
+        writer.save(savePath)
+
+        reader = DefaultParamsReadable.read()
+        lr2 = reader.load(savePath)
+
+        self.assertEqual(lr.uid, lr2.uid)
+        self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())
+
+        # test overwrite
+        lr.setThreshold(.8)
+        writer.overwrite().save(savePath)
+
+        reader = DefaultParamsReadable.read()
+        lr3 = reader.load(savePath)
+
+        self.assertEqual(lr.uid, lr3.uid)
+        self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+
 
 class LDATest(SparkSessionTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fdcee028/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7863edd..6777291 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,16 +15,21 @@
 # limitations under the License.
 #
 
+import json
 import sys
+import os
+import time
 import uuid
 import warnings
 
 if sys.version > '3':
     basestring = str
     unicode = str
+    long = int
 
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
+from pyspark.sql import SparkSession
 
 
 def _jvm():
@@ -61,33 +66,86 @@ class Identifiable(object):
 
 
 @inherit_doc
-class MLWriter(object):
+class BaseReadWrite(object):
     """
-    Utility class that can save ML instances.
+    Base class for MLWriter and MLReader. Stores information about the 
SparkContext
+    and SparkSession.
 
-    .. versionadded:: 2.0.0
+    .. versionadded:: 2.3.0
     """
 
-    def save(self, path):
-        """Save the ML instance to the input path."""
-        raise NotImplementedError("MLWriter is not yet implemented for type: 
%s" % type(self))
-
-    def overwrite(self):
-        """Overwrites if the output path already exists."""
-        raise NotImplementedError("MLWriter is not yet implemented for type: 
%s" % type(self))
+    def __init__(self):
+        self._sparkSession = None
 
     def context(self, sqlContext):
         """
-        Sets the SQL context to use for saving.
+        Sets the Spark SQLContext to use for saving/loading.
 
         .. note:: Deprecated in 2.1 and will be removed in 3.0, use session 
instead.
         """
-        raise NotImplementedError("MLWriter is not yet implemented for type: 
%s" % type(self))
+        raise NotImplementedError("Read/Write is not yet implemented for type: 
%s" % type(self))
 
     def session(self, sparkSession):
-        """Sets the Spark Session to use for saving."""
+        """
+        Sets the Spark Session to use for saving/loading.
+        """
+        self._sparkSession = sparkSession
+        return self
+
+    @property
+    def sparkSession(self):
+        """
+        Returns the user-specified Spark Session or the default.
+        """
+        if self._sparkSession is None:
+            self._sparkSession = SparkSession.builder.getOrCreate()
+        return self._sparkSession
+
+    @property
+    def sc(self):
+        """
+        Returns the underlying `SparkContext`.
+        """
+        return self.sparkSession.sparkContext
+
+
+@inherit_doc
+class MLWriter(BaseReadWrite):
+    """
+    Utility class that can save ML instances.
+
+    .. versionadded:: 2.0.0
+    """
+
+    def __init__(self):
+        super(MLWriter, self).__init__()
+        self.shouldOverwrite = False
+
+    def _handleOverwrite(self, path):
+        from pyspark.ml.wrapper import JavaWrapper
+
+        _java_obj = 
JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
+        wrapper = JavaWrapper(_java_obj)
+        wrapper._call_java("handleOverwrite", path, True, self.sc._jsc.sc())
+
+    def save(self, path):
+        """Save the ML instance to the input path."""
+        if self.shouldOverwrite:
+            self._handleOverwrite(path)
+        self.saveImpl(path)
+
+    def saveImpl(self, path):
+        """
+        save() handles overwriting and then calls this method.  Subclasses 
should override this
+        method to implement the actual saving of the instance.
+        """
         raise NotImplementedError("MLWriter is not yet implemented for type: 
%s" % type(self))
 
+    def overwrite(self):
+        """Overwrites if the output path already exists."""
+        self.shouldOverwrite = True
+        return self
+
 
 @inherit_doc
 class JavaMLWriter(MLWriter):
@@ -140,7 +198,7 @@ class MLWritable(object):
         raise NotImplementedError("MLWritable is not yet implemented for type: 
%r" % type(self))
 
     def save(self, path):
-        """Save this ML instance to the given path, a shortcut of 
`write().save(path)`."""
+        """Save this ML instance to the given path, a shortcut of 
'write().save(path)'."""
         self.write().save(path)
 
 
@@ -156,29 +214,20 @@ class JavaMLWritable(MLWritable):
 
 
 @inherit_doc
-class MLReader(object):
+class MLReader(BaseReadWrite):
     """
     Utility class that can load ML instances.
 
     .. versionadded:: 2.0.0
     """
 
+    def __init__(self):
+        super(MLReader, self).__init__()
+
     def load(self, path):
         """Load the ML instance from the input path."""
         raise NotImplementedError("MLReader is not yet implemented for type: 
%s" % type(self))
 
-    def context(self, sqlContext):
-        """
-        Sets the SQL context to use for loading.
-
-        .. note:: Deprecated in 2.1 and will be removed in 3.0, use session 
instead.
-        """
-        raise NotImplementedError("MLReader is not yet implemented for type: 
%s" % type(self))
-
-    def session(self, sparkSession):
-        """Sets the Spark Session to use for loading."""
-        raise NotImplementedError("MLReader is not yet implemented for type: 
%s" % type(self))
-
 
 @inherit_doc
 class JavaMLReader(MLReader):
@@ -187,6 +236,7 @@ class JavaMLReader(MLReader):
     """
 
     def __init__(self, clazz):
+        super(JavaMLReader, self).__init__()
         self._clazz = clazz
         self._jread = self._load_java_obj(clazz).read()
 
@@ -283,3 +333,201 @@ class JavaPredictionModel():
         Returns the number of features the model was trained on. If unknown, 
returns -1
         """
         return self._call_java("numFeatures")
+
+
+@inherit_doc
+class DefaultParamsWritable(MLWritable):
+    """
+    .. note:: DeveloperApi
+
+    Helper trait for making simple :py:class:`Params` types writable.  If a 
:py:class:`Params`
+    class stores all data as :py:class:`Param` values, then extending this 
trait will provide
+    a default implementation of writing saved instances of the class.
+    This only handles simple :py:class:`Param` types; e.g., it will not handle
+    :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the 
counterpart to this trait.
+
+    .. versionadded:: 2.3.0
+    """
+
+    def write(self):
+        """Returns a DefaultParamsWriter instance for this class."""
+        from pyspark.ml.param import Params
+
+        if isinstance(self, Params):
+            return DefaultParamsWriter(self)
+        else:
+            raise TypeError("Cannot use DefautParamsWritable with type %s 
because it does not " +
+                            " extend Params.", type(self))
+
+
+@inherit_doc
+class DefaultParamsWriter(MLWriter):
+    """
+    .. note:: DeveloperApi
+
+    Specialization of :py:class:`MLWriter` for :py:class:`Params` types
+
+    Class for writing Estimators and Transformers whose parameters are 
JSON-serializable.
+
+    .. versionadded:: 2.3.0
+    """
+
+    def __init__(self, instance):
+        super(DefaultParamsWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
+
+    @staticmethod
+    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
+        """
+        Saves metadata + Params to: path + "/metadata"
+        - class
+        - timestamp
+        - sparkVersion
+        - uid
+        - paramMap
+        - (optionally, extra metadata)
+        :param extraMetadata:  Extra metadata to be saved at same level as 
uid, paramMap, etc.
+        :param paramMap:  If given, this is saved in the "paramMap" field.
+        """
+        metadataPath = os.path.join(path, "metadata")
+        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
+                                                                 sc,
+                                                                 extraMetadata,
+                                                                 paramMap)
+        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
+
+    @staticmethod
+    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
+        """
+        Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts 
the JSON to save.
+        This is useful for ensemble models which need to save metadata for 
many sub-models.
+
+        .. note:: :py:meth:`DefaultParamsWriter.saveMetadata` for details on 
what this includes.
+        """
+        uid = instance.uid
+        cls = instance.__module__ + '.' + instance.__class__.__name__
+        params = instance.extractParamMap()
+        jsonParams = {}
+        if paramMap is not None:
+            jsonParams = paramMap
+        else:
+            for p in params:
+                jsonParams[p.name] = params[p]
+        basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 
1000)),
+                         "sparkVersion": sc.version, "uid": uid, "paramMap": 
jsonParams}
+        if extraMetadata is not None:
+            basicMetadata.update(extraMetadata)
+        return json.dumps(basicMetadata, separators=[',',  ':'])
+
+
+@inherit_doc
+class DefaultParamsReadable(MLReadable):
+    """
+    .. note:: DeveloperApi
+
+    Helper trait for making simple :py:class:`Params` types readable.
+    If a :py:class:`Params` class stores all data as :py:class:`Param` values,
+    then extending this trait will provide a default implementation of reading 
saved
+    instances of the class. This only handles simple :py:class:`Param` types;
+    e.g., it will not handle :py:class:`Dataset`. See 
:py:class:`DefaultParamsWritable`,
+    the counterpart to this trait.
+
+    .. versionadded:: 2.3.0
+    """
+
+    @classmethod
+    def read(cls):
+        """Returns a DefaultParamsReader instance for this class."""
+        return DefaultParamsReader(cls)
+
+
+@inherit_doc
+class DefaultParamsReader(MLReader):
+    """
+    .. note:: DeveloperApi
+
+    Specialization of :py:class:`MLReader` for :py:class:`Params` types
+
+    Default :py:class:`MLReader` implementation for transformers and 
estimators that
+    contain basic (json-serializable) params and no data. This will not handle
+    more complex params or types with data (e.g., models with coefficients).
+
+    .. versionadded:: 2.3.0
+    """
+
+    def __init__(self, cls):
+        super(DefaultParamsReader, self).__init__()
+        self.cls = cls
+
+    @staticmethod
+    def __get_class(clazz):
+        """
+        Loads Python class from its name.
+        """
+        parts = clazz.split('.')
+        module = ".".join(parts[:-1])
+        m = __import__(module)
+        for comp in parts[1:]:
+            m = getattr(m, comp)
+        return m
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        py_type = DefaultParamsReader.__get_class(metadata['class'])
+        instance = py_type()
+        instance._resetUid(metadata['uid'])
+        DefaultParamsReader.getAndSetParams(instance, metadata)
+        return instance
+
+    @staticmethod
+    def loadMetadata(path, sc, expectedClassName=""):
+        """
+        Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
+
+        :param expectedClassName:  If non empty, this is checked against the 
loaded metadata.
+        """
+        metadataPath = os.path.join(path, "metadata")
+        metadataStr = sc.textFile(metadataPath, 1).first()
+        loadedVals = DefaultParamsReader._parseMetaData(metadataStr, 
expectedClassName)
+        return loadedVals
+
+    @staticmethod
+    def _parseMetaData(metadataStr, expectedClassName=""):
+        """
+        Parse metadata JSON string produced by 
:py:meth`DefaultParamsWriter._get_metadata_to_save`.
+        This is a helper function for 
:py:meth:`DefaultParamsReader.loadMetadata`.
+
+        :param metadataStr:  JSON string of metadata
+        :param expectedClassName:  If non empty, this is checked against the 
loaded metadata.
+        """
+        metadata = json.loads(metadataStr)
+        className = metadata['class']
+        if len(expectedClassName) > 0:
+            assert className == expectedClassName, "Error loading metadata: 
Expected " + \
+                "class name {} but found class name 
{}".format(expectedClassName, className)
+        return metadata
+
+    @staticmethod
+    def getAndSetParams(instance, metadata):
+        """
+        Extract Params from metadata, and set them in the instance.
+        """
+        for paramName in metadata['paramMap']:
+            param = instance.getParam(paramName)
+            paramValue = metadata['paramMap'][paramName]
+            instance.set(param, paramValue)
+
+    @staticmethod
+    def loadParamsInstance(path, sc):
+        """
+        Load a :py:class:`Params` instance from the given path, and return it.
+        This assumes the instance inherits from :py:class:`MLReadable`.
+        """
+        metadata = DefaultParamsReader.loadMetadata(path, sc)
+        pythonClassName = metadata['class'].replace("org.apache.spark", 
"pyspark")
+        py_type = DefaultParamsReader.__get_class(pythonClassName)
+        instance = py_type.load(path)
+        return instance


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to