Repository: spark
Updated Branches:
  refs/heads/master fd8d283ee -> aedbbaa3d


[SPARK-6053][MLLIB] support save/load in PySpark's ALS

A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:

f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into 
SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aedbbaa3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aedbbaa3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aedbbaa3

Branch: refs/heads/master
Commit: aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2
Parents: fd8d283
Author: Xiangrui Meng <[email protected]>
Authored: Sun Mar 1 16:26:57 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Sun Mar 1 16:26:57 2015 -0800

----------------------------------------------------------------------
 docs/mllib-collaborative-filtering.md           |  8 ++-
 .../apache/spark/mllib/util/modelSaveLoad.scala |  2 +-
 python/pyspark/mllib/recommendation.py          | 20 ++++++-
 python/pyspark/mllib/util.py                    | 58 ++++++++++++++++++++
 4 files changed, 82 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/aedbbaa3/docs/mllib-collaborative-filtering.md
----------------------------------------------------------------------
diff --git a/docs/mllib-collaborative-filtering.md 
b/docs/mllib-collaborative-filtering.md
index 27aa4d3..7614028 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -200,10 +200,8 @@ In the following example we load rating data. Each row 
consists of a user, a pro
 We use the default ALS.train() method which assumes ratings are explicit. We 
evaluate the
 recommendation by measuring the Mean Squared Error of rating prediction.
 
-Note that the Python API does not yet support model save/load but will in the 
future.
-
 {% highlight python %}
-from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
 
 # Load and parse the data
 data = sc.textFile("data/mllib/als/test.data")
@@ -220,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: 
((r[0], r[1]), r[2]))
 ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
 MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: 
x + y) / ratesAndPreds.count()
 print("Mean Squared Error = " + str(MSE))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
 {% endhighlight %}
 
 If the rating matrix is derived from other source of information (i.e., it is 
inferred from other

http://git-wip-us.apache.org/repos/asf/spark/blob/aedbbaa3/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 4458340..526d055 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -48,7 +48,7 @@ trait Saveable {
    *
    * @param sc  Spark context used to save model data.
    * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be 
created if needed.
+   *              If the directory already exists, this method throws an 
exception.
    */
   def save(sc: SparkContext, path: String): Unit
 

http://git-wip-us.apache.org/repos/asf/spark/blob/aedbbaa3/python/pyspark/mllib/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/recommendation.py 
b/python/pyspark/mllib/recommendation.py
index 0d99e6d..03d7d01 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,8 @@ from collections import namedtuple
 
 from pyspark import SparkContext
 from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.util import Saveable, JavaLoader
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", 
"rating"])):
         return Rating, (int(self.user), int(self.product), float(self.rating))
 
 
-class MatrixFactorizationModel(JavaModelWrapper):
+@inherit_doc
+class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
 
     """A matrix factorisation model trained by regularized alternating
     least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
     >>> model.predict(2,2)
     0.43...
+
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = MatrixFactorizationModel.load(sc, path)
+    >>> sameModel.predict(2,2)
+    0.43...
+    >>> try:
+    ...     os.removedirs(path)
+    ... except:
+    ...     pass
     """
     def predict(self, user, product):
         return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper):
     def productFeatures(self):
         return self.call("getProductFeatures")
 
+    def save(self, sc, path):
+        self.call("save", sc._jsc.sc(), path)  # passes sc._jsc.sc(), not the Python sc
+
 
 class ALS(object):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/aedbbaa3/python/pyspark/mllib/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4ed978b..17d43ea 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -168,6 +168,64 @@ class MLUtils(object):
         return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
 
 
+class Saveable(object):
+    """
+    Mixin for models and transformers which may be saved as files.
+    """
+
+    def save(self, sc, path):
+        """
+        Save this model to the given path.
+
+        This saves:
+         * human-readable (JSON) model metadata to path/metadata/
+         * Parquet formatted data to path/data/
+
+        The model may be loaded using :py:meth:`Loader.load`.
+
+        :param sc: Spark context used to save model data.
+        :param path: Path specifying the directory in which to save
+                     this model. If the directory already exists,
+                     this method throws an exception.
+        """
+        raise NotImplementedError
+
+
+class Loader(object):
+    """
+    Mixin for classes which can load saved models from files.
+    """
+
+    @classmethod
+    def load(cls, sc, path):
+        """
+        Load a model from the given path. The model should have been
+        saved using :py:meth:`Saveable.save`. Subclasses must override this.
+
+        :param sc: Spark context used for loading model files.
+        :param path: Path specifying the directory the model was saved to.
+        :return: model instance
+        :raises NotImplementedError: always, in this base mixin.
+        """
+        raise NotImplementedError
+
+
+class JavaLoader(Loader):
+    """
+    Mixin for classes which can load saved models using their Scala
+    implementation.
+    """
+
+    @classmethod
+    def load(cls, sc, path):
+        java_package = cls.__module__.replace("pyspark", "org.apache.spark")
+        java_class = ".".join([java_package, cls.__name__])  # dotted JVM-side name
+        java_obj = sc._jvm  # resolve the dotted name starting from sc._jvm
+        for name in java_class.split("."):
+            java_obj = getattr(java_obj, name)  # one attribute lookup per name segment
+        return cls(java_obj.load(sc._jsc.sc(), path))  # wrap the loaded object in cls
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to