Repository: spark
Updated Branches:
  refs/heads/master 5e6ad24ff -> 25636d986


[Spark 6096][MLlib] Add Naive Bayes load save methods in Python

See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096).

Author: Xusen Yin <[email protected]>

Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits:

bd0fea5 [Xusen Yin] fix style problem, etc.
3fd41f2 [Xusen Yin] use hanging indent in Python style
e83803d [Xusen Yin] fix Python style
d6dbde5 [Xusen Yin] fix python call java error
a054bb3 [Xusen Yin] add save load for NaiveBayes python


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25636d98
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25636d98
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25636d98

Branch: refs/heads/master
Commit: 25636d9867c6bc901463b6b227cb444d701cfdd1
Parents: 5e6ad24
Author: Xusen Yin <[email protected]>
Authored: Fri Mar 20 14:53:59 2015 -0400
Committer: Xiangrui Meng <[email protected]>
Committed: Fri Mar 20 14:53:59 2015 -0400

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 11 +++++++
 python/pyspark/mllib/classification.py          | 31 +++++++++++++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25636d98/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 068449a..d60e82c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
+import java.lang.{Iterable => JIterable}
+
+import scala.collection.JavaConverters._
+
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
@@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] (
     val pi: Array[Double],
     val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
 
+  /** A Java-friendly constructor that takes three Iterable parameters. */
+  private[mllib] def this(
+      labels: JIterable[Double],
+      pi: JIterable[Double],
+      theta: JIterable[JIterable[Double]]) =
+    this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
+
   private val brzPi = new BDV[Double](pi)
   private val brzTheta = new BDM[Double](theta.length, theta(0).length)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/25636d98/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b66159c..6766f3e 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -24,6 +24,7 @@ from pyspark import RDD
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
 
 __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -359,7 +360,8 @@ class SVMWithSGD(object):
         return _regression_train_wrapper(train, SVMModel, data, initialWeights)
 
 
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
 
     """
     Model for Naive Bayes classifiers.
@@ -390,6 +392,16 @@ class NaiveBayesModel(object):
     0.0
     >>> model.predict(SparseVector(2, {0: 1.0}))
     1.0
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = NaiveBayesModel.load(sc, path)
+    >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+    True
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
     """
 
     def __init__(self, labels, pi, theta):
@@ -404,6 +416,23 @@ class NaiveBayesModel(object):
         x = _convert_to_vector(x)
         return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
 
+    def save(self, sc, path):
+        java_labels = _py2java(sc, self.labels.tolist())
+        java_pi = _py2java(sc, self.pi.tolist())
+        java_theta = _py2java(sc, self.theta.tolist())
+        java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+            java_labels, java_pi, java_theta)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+            sc._jsc.sc(), path)
+        py_labels = _java2py(sc, java_model.labels())
+        py_pi = _java2py(sc, java_model.pi())
+        py_theta = _java2py(sc, java_model.theta())
+        return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
 
 class NaiveBayes(object):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to