Repository: spark
Updated Branches:
  refs/heads/master f96b85ab4 -> 8f4aaba0e
[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD.

Author: FlytxtRnD <[email protected]>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f4aaba0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f4aaba0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f4aaba0

Branch: refs/heads/master
Commit: 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2
Parents: f96b85a
Author: FlytxtRnD <[email protected]>
Authored: Fri May 15 10:43:18 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri May 15 10:43:18 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/clustering.py | 6 ++++++
 1 file changed, 6 insertions(+)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/8f4aaba0/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333d..b55583f 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
     def predictSoft(self, x):
         """
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
             membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
                                               _convert_to_vector(self._weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
 
 class GaussianMixture(object):

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
