Repository: spark
Updated Branches:
  refs/heads/master d9f3e0168 -> c94d06264


[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel

Use `_py2java` and `_java2py` to convert Python model to/from Java model. 
yinxusen

Author: Xiangrui Meng <[email protected]>

Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits:

570ba81 [Xiangrui Meng] fix python style
b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c94d0626
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c94d0626
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c94d0626

Branch: refs/heads/master
Commit: c94d0626471e209ab7ebfc588f9a2992946b7ed5
Parents: d9f3e01
Author: Xiangrui Meng <[email protected]>
Authored: Tue Mar 17 12:14:40 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Mar 17 12:14:40 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/KMeansModel.scala    |  5 ++++
 python/pyspark/mllib/clustering.py              | 28 +++++++++++++++++---
 python/pyspark/mllib/common.py                  |  4 +--
 3 files changed, 32 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 707da53..e4e411a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.clustering
 
+import scala.collection.JavaConverters._
+
 import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
@@ -34,6 +36,9 @@ import org.apache.spark.sql.Row
  */
 class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
+  /** A Java-friendly constructor that takes an Iterable of Vectors. */
+  def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
+
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 949db57..464f49a 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -19,14 +19,16 @@ from numpy import array
 
 from pyspark import RDD
 from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc
-from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.stat.distribution import MultivariateGaussian
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
 __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
 
 
-class KMeansModel(object):
+@inherit_doc
+class KMeansModel(Saveable, Loader):
 
     """A clustering model derived from the k-means method.
 
@@ -55,6 +57,16 @@ class KMeansModel(object):
     True
     >>> type(model.clusterCenters)
     <type 'list'>
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = KMeansModel.load(sc, path)
+    >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
+    True
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
     """
 
     def __init__(self, centers):
@@ -77,6 +89,16 @@ class KMeansModel(object):
                 best_distance = distance
         return best
 
+    def save(self, sc, path):
+        java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path)
+        return KMeansModel(_java2py(sc, java_model.clusterCenters()))
+
 
 class KMeans(object):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/python/pyspark/mllib/common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 621591c..a539d2f 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -70,8 +70,8 @@ def _py2java(sc, obj):
         obj = _to_java_object_rdd(obj)
     elif isinstance(obj, SparkContext):
         obj = obj._jsc
-    elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
-        obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+    elif isinstance(obj, list):
+        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
     elif isinstance(obj, JavaObject):
         pass
     elif isinstance(obj, (int, long, float, bool, basestring)):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to