Repository: spark Updated Branches: refs/heads/branch-1.6 c3135d021 -> 175681914
[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None If the initial model passed to GMM is not None, training fails with `net.razorvine.pickle.PickleException`. This is fixed by converting `initialModel.weights` to a `list`. Author: zero323 <[email protected]> Closes #9986 from zero323/SPARK-12006. (cherry picked from commit fcd013cf70e7890aa25a8fe3cb6c8b36bf0e1f04) Signed-off-by: Joseph K. Bradley <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/17568191 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/17568191 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/17568191 Branch: refs/heads/branch-1.6 Commit: 175681914af953b7ce1b2971fef83a2445de1f94 Parents: c3135d0 Author: zero323 <[email protected]> Authored: Wed Jan 6 11:58:33 2016 -0800 Committer: Joseph K. Bradley <[email protected]> Committed: Wed Jan 6 11:58:42 2016 -0800 ---------------------------------------------------------------------- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/17568191/python/pyspark/mllib/clustering.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c9e6f1d..48daa87 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -346,7 +346,7 @@ class GaussianMixture(object): if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = initialModel.weights + initialModelWeights = list(initialModel.weights) initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i 
in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), http://git-wip-us.apache.org/repos/asf/spark/blob/17568191/python/pyspark/mllib/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f8e8e0e..e9e7a90 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -474,6 +474,18 @@ class ListTests(MLlibTestCase): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
