Repository: spark Updated Branches: refs/heads/master d9f3e0168 -> c94d06264
[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel Use `_py2java` and `_java2py` to convert Python model to/from Java model. yinxusen Author: Xiangrui Meng <[email protected]> Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits: 570ba81 [Xiangrui Meng] fix python style b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c94d0626 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c94d0626 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c94d0626 Branch: refs/heads/master Commit: c94d0626471e209ab7ebfc588f9a2992946b7ed5 Parents: d9f3e01 Author: Xiangrui Meng <[email protected]> Authored: Tue Mar 17 12:14:40 2015 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Tue Mar 17 12:14:40 2015 -0700 ---------------------------------------------------------------------- .../spark/mllib/clustering/KMeansModel.scala | 5 ++++ python/pyspark/mllib/clustering.py | 28 +++++++++++++++++--- python/pyspark/mllib/common.py | 4 +-- 3 files changed, 32 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 707da53..e4e411a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import scala.collection.JavaConverters._ + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -34,6 +36,9 @@ import org.apache.spark.sql.Row */ class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable { + /** A Java-friendly constructor that takes an Iterable of Vectors. */ + def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) + /** Total number of clusters. */ def k: Int = clusterCenters.length http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/python/pyspark/mllib/clustering.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 949db57..464f49a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -19,14 +19,16 @@ from numpy import array from pyspark import RDD from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc -from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py +from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.util import Saveable, Loader, inherit_doc __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] -class KMeansModel(object): +@inherit_doc +class KMeansModel(Saveable, Loader): """A clustering model derived from the k-means method. @@ -55,6 +57,16 @@ class KMeansModel(object): True >>> type(model.clusterCenters) <type 'list'> + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = KMeansModel.load(sc, path) + >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) + True + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def __init__(self, centers): @@ -77,6 +89,16 @@ class KMeansModel(object): best_distance = distance return best + def save(self, sc, path): + java_centers = _py2java(sc, map(_convert_to_vector, self.centers)) + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path) + return KMeansModel(_java2py(sc, java_model.clusterCenters())) + class KMeans(object): http://git-wip-us.apache.org/repos/asf/spark/blob/c94d0626/python/pyspark/mllib/common.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 621591c..a539d2f 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -70,8 +70,8 @@ def _py2java(sc, obj): obj = _to_java_object_rdd(obj) elif isinstance(obj, SparkContext): obj = obj._jsc - elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)): - obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, list): + obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, basestring)): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
