Repository: spark
Updated Branches:
  refs/heads/branch-1.4 d1f565100 -> dfdae5800


[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

In the Python API for the Gaussian Mixture Model, the predict() and predictSoft()
methods of GaussianMixtureModel should raise a TypeError when the input argument is not an RDD.
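The following minimal sketch illustrates the new behavior without needing a running Spark cluster. It assumes a hypothetical `ToyRDD` class that stands in for `pyspark.RDD` and a hypothetical `toy_membership` function that stands in for the fitted model's soft memberships. Only the `isinstance` guard, the argmax step, and the `TypeError` message mirror the patch. Before this change, a non-RDD argument fell through the `if` block and the methods silently returned `None`.

```python
import array


class ToyRDD:
    """Hypothetical stand-in for pyspark.RDD so the sketch runs without Spark."""

    def __init__(self, rows):
        self._rows = list(rows)

    def map(self, f):
        return ToyRDD(f(r) for r in self._rows)

    def collect(self):
        return list(self._rows)


def toy_membership(v):
    # Hypothetical two-component memberships: values near 0 favor component 0.
    p0 = 1.0 / (1.0 + abs(v))
    return [p0, 1.0 - p0]


def predict_soft(x, membership_fn):
    # Same guard as the patch: only an RDD (here ToyRDD) is accepted.
    if isinstance(x, ToyRDD):
        return x.map(lambda v: array.array('d', membership_fn(v)))
    else:
        raise TypeError("x should be represented by an RDD, "
                        "but got %s." % type(x))


def predict(x, membership_fn):
    if isinstance(x, ToyRDD):
        # Hard label = index of the largest soft membership.
        return predict_soft(x, membership_fn).map(lambda z: z.index(max(z)))
    else:
        raise TypeError("x should be represented by an RDD, "
                        "but got %s." % type(x))


labels = predict(ToyRDD([0.0, 3.0]), toy_membership).collect()

try:
    predict([0.0, 3.0], toy_membership)  # a plain list is rejected
except TypeError as err:
    error_message = str(err)
```

With the guard in place, passing a plain Python list fails fast with a descriptive `TypeError` instead of returning `None` and failing later in unrelated code.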

Author: FlytxtRnD <[email protected]>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD

(cherry picked from commit 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2)
Signed-off-by: Joseph K. Bradley <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dfdae580
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dfdae580
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dfdae580

Branch: refs/heads/branch-1.4
Commit: dfdae5800c6a4f9b8e941138f61b784b24b0b00b
Parents: d1f5651
Author: FlytxtRnD <[email protected]>
Authored: Fri May 15 10:43:18 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri May 15 10:43:26 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/clustering.py | 6 ++++++
 1 file changed, 6 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dfdae580/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333d..b55583f 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
     def predictSoft(self, x):
         """
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
             membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
                                               _convert_to_vector(self._weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
 
 class GaussianMixture(object):
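In the first hunk, predict() derives hard cluster labels from predictSoft(), whose rows are array.array('d', ...) objects, by mapping each row z to z.index(max(z)). The per-row step can be sketched in plain Python. The helper name hard_label is hypothetical and not part of pyspark.

```python
import array


def hard_label(memberships):
    # Index of the largest membership probability, as in predict().
    # list.index / array.index return the first match, so ties go to
    # the lowest component index.
    return memberships.index(max(memberships))


row = array.array('d', [0.1, 0.7, 0.2])
label = hard_label(row)
```

Because index() returns the first occurrence, a row with equal memberships (for example [0.5, 0.5]) is assigned to component 0.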

