Hi all,

We're trying to train a Gaussian Mixture Model (GMM) with a specified initial model. The 1.5.1 documentation says we should pass a GaussianMixtureModel object as the "initialModel" parameter of the GaussianMixture.train method. Before creating our own initial model (the plan is to use a KMeans result, for instance), we simply wanted to test this scenario. So we tried to initialize a second training using the GaussianMixtureModel output by a first training. But this trivial scenario throws an error. Could you please help us determine what's going on here?

Thanks a lot,
Guillaume
PS: We run PySpark 1.5.1 with Hadoop 2.6. Below are the trivial scenario code and the resulting error:

============================================
SOURCE CODE

from pyspark.mllib.clustering import GaussianMixture
from numpy import array
import sys
import os
import pyspark

### Local default options
K=2 # "k" (int) Set the number of Gaussians in the mixture model. Default: 2
convergenceTol=1e-3 # "convergenceTol" (double) Set the largest change in log-likelihood at which convergence is considered to have occurred.
maxIterations=100 # "maxIterations" (int) Set the maximum number of iterations to run. Default: 100
seed=None # "seed" (long) Set the random seed
initialModel=None

### Load and parse the sample data
data = sc.textFile("gmm_data.txt") # Data from the dummy set here: data/mllib/gmm_data.txt
parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')]))
print type(parsedData)
print type(parsedData.first())

### 1st training: Build the GMM
gmm = GaussianMixture.train(parsedData, K, convergenceTol, maxIterations, seed, initialModel)

# output parameters of model
for i in range(2):
    print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, "sigma = ", gmm.gaussians[i].sigma.toArray())

### 2nd training: Re-build a GMM using an initial model
initialModel = gmm
print type(initialModel)
gmm = GaussianMixture.train(parsedData, K, convergenceTol, maxIterations, seed, initialModel)

============================================
OUTPUT WITH ERROR:

<class 'pyspark.rdd.PipelinedRDD'>
<type 'numpy.ndarray'>
('weight = ', 0.51945003367044018, 'mu = ', DenseVector([-0.1045, 0.0429]), 'sigma = ', array([[ 4.90706817, -2.00676881], [-2.00676881, 1.01143891]]))
('weight = ', 0.48054996632955982, 'mu = ', DenseVector([0.0722, 0.0167]), 'sigma = ', array([[ 4.77975653, 1.87624558], [ 1.87624558, 0.91467242]]))
<class 'pyspark.mllib.clustering.GaussianMixtureModel'>
---------------------------------------------------------------------------
Py4JJavaError                             Traceback 
(most recent call last) <ipython-input-30-0008fe75eb61> in <module>() 33 initialModel = gmm 34 print type(initialModel) ---> 35 gmm = GaussianMixture.train(parsedData, K, convergenceTol, maxIterations, seed, initialModel) # /opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/clustering.pyc in train(cls, rdd, k, convergenceTol, maxIterations, seed, initialModel) 306 java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), 307 k, convergenceTol, maxIterations, seed, --> 308 initialModelWeights, initialModelMu, initialModelSigma) 309 return GaussianMixtureModel(java_model) 310 /opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc in callMLlibFunc(name, *args) 128 sc = SparkContext._active_spark_context 129 api = getattr(sc._jvm.PythonMLLibAPI(), name) --> 130 return callJavaFunc(sc, api, *args) 131 132 /opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc in callJavaFunc(sc, func, *args) 120 def callJavaFunc(sc, func, *args): 121 """ Call Java Function """ --> 122 args = [_py2java(sc, a) for a in args] 123 return _java2py(sc, func(*args)) 124 /opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/mllib/common.pyc in _py2java(sc, obj) 86 else: 87 data = bytearray(PickleSerializer().dumps(obj)) ---> 88 obj = sc._jvm.SerDe.loads(data) 89 return obj 90 /opt/spark/spark-1.5.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py in __call__(self, *args) 536 answer = self.gateway_client.send_command(command) 537 return_value = get_return_value(answer, self.gateway_client, --> 538 self.target_id, self.name) 539 540 for temp_arg in temp_args: /opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/sql/utils.pyc in deco(*a, **kw) 34 def deco(*a, **kw): 35 try: ---> 36 return f(*a, **kw) 37 except py4j.protocol.Py4JJavaError as e: 38 s = e.java_exception.toString() /opt/spark/spark-1.5.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 
298 raise Py4JJavaError( 299 'An error occurred while calling {0}{1}{2}.\n'. --> 300 format(target_id, '.', name), value) 301 else: 302 raise Py4JError( Py4JJavaError: An error occurred while calling z:org.apache.spark.mllib.api.python.SerDe.loads. : net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct) at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23) at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:701) at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:171) at net.razorvine.pickle.Unpickler.load(Unpickler.java:85) at net.razorvine.pickle.Unpickler.loads(Unpickler.java:98) at org.apache.spark.mllib.api.python.SerDe$.loads(PythonMLLibAPI.scala:1462) at org.apache.spark.mllib.api.python.SerDe.loads(PythonMLLibAPI.scala) at sun.reflect.GeneratedMethodAccessor31.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:379) at py4j.Gateway.invoke(Gateway.java:259) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:207) at java.lang.Thread.run(Thread.java:745) --------------------------------------------------------------------------- Guillaume Maze, Research Scientist at Ifremer LPO, UMR 6523: wwz.ifremer.fr/lpo SO Argo-France: wwz.ifremer.fr/lpo/SO-Argo +33 (0)2 98 22 43 39 http://www.guillaumemaze.org ---------------------------------------------------------------------------