Repository: spark
Updated Branches:
  refs/heads/master f96b85ab4 -> 8f4aaba0e


[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

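In the Python API for the Gaussian Mixture Model, the predict() and
predictSoft() methods should raise an error when the input argument is not
an RDD. Previously, both methods silently returned None for any non-RDD
input; with this change they raise a TypeError that names the offending type.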
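The check itself is a plain isinstance guard with an else branch that raises. A minimal sketch of the same pattern, runnable without Spark, assuming a hypothetical FakeRDD stand-in for pyspark.RDD and a hypothetical predict_soft() function that mirrors only the guard (the real predictSoft() calls into the JVM through callMLlibFunc):

```python
# Sketch of the input validation added by this commit.
# FakeRDD is a HYPOTHETICAL stand-in for pyspark.RDD so the example runs
# without Spark; it is not part of the PySpark API.


class FakeRDD(object):
    """Hypothetical stand-in for pyspark.RDD holding a local list."""

    def __init__(self, data):
        self.data = list(data)

    def map(self, f):
        # Apply f to every element, returning a new FakeRDD (like RDD.map).
        return FakeRDD(f(item) for item in self.data)


def predict_soft(x):
    """Hypothetical mirror of the guard in predictSoft()."""
    if isinstance(x, FakeRDD):
        # Placeholder memberships; the real method computes them in the JVM.
        return x.map(lambda point: [0.5, 0.5])
    else:
        # Same error type and message format as the patch.
        raise TypeError("x should be represented by an RDD, "
                        "but got %s." % type(x))


try:
    predict_soft([1.0, 2.0])
except TypeError as err:
    # Under Python 3 this prints:
    # x should be represented by an RDD, but got <class 'list'>.
    print(err)
```

Raising TypeError (rather than returning None) makes a wrong argument fail at the call site instead of surfacing later as an AttributeError on None.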

Author: FlytxtRnD <[email protected]>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f4aaba0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f4aaba0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f4aaba0

Branch: refs/heads/master
Commit: 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2
Parents: f96b85a
Author: FlytxtRnD <[email protected]>
Authored: Fri May 15 10:43:18 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri May 15 10:43:18 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/clustering.py | 6 ++++++
 1 file changed, 6 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f4aaba0/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333d..b55583f 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
     def predictSoft(self, x):
         """
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
             membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
                                               _convert_to_vector(self._weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
 
 class GaussianMixture(object):
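In the first hunk, predict() turns the soft memberships returned by predictSoft() into hard cluster labels with z.index(max(z)). A minimal plain-Python sketch of that step, using made-up membership values and a hypothetical helper name, hard_labels, that is not part of PySpark:

```python
import array


def hard_labels(membership_rows):
    """Hypothetical helper: index of the largest membership in each row.

    Mirrors the expression predict() maps over predictSoft()'s output.
    """
    return [z.index(max(z)) for z in membership_rows]


# predictSoft() yields array.array('d', ...) rows; these values are made up.
rows = [array.array('d', [0.1, 0.7, 0.2]),
        array.array('d', [0.6, 0.3, 0.1])]
print(hard_labels(rows))
```

Because index() returns the first match, a tie between components resolves to the lowest component index.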

