Repository: spark
Updated Branches:
  refs/heads/branch-0.9 8e5604b22 -> 0116dee7e


[SPARK-2433][MLLIB] fix NaiveBayesModel.predict

This is the same as https://github.com/apache/spark/pull/463, which I forgot to merge into branch-0.9.

Author: Xiangrui Meng <[email protected]>

Closes #1453 from mengxr/nb-transpose-0.9 and squashes the following commits:

bc53ce8 [Xiangrui Meng] fix NaiveBayes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0116dee7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0116dee7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0116dee7

Branch: refs/heads/branch-0.9
Commit: 0116dee7e041da408865dd667377afe222367348
Parents: 8e5604b
Author: Xiangrui Meng <[email protected]>
Authored: Wed Jul 16 20:12:09 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Jul 16 20:12:09 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/classification.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0116dee7/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 19b90df..f6c96e3 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -84,7 +84,7 @@ class NaiveBayesModel(object):
     - pi: vector of logs of class priors (dimension C)
     - theta: matrix of logs of class conditional probabilities (CxD)
 
-    >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3)
     >>> model = NaiveBayes.train(sc.parallelize(data))
     >>> model.predict(array([0.0, 1.0]))
     0
@@ -98,7 +98,7 @@ class NaiveBayesModel(object):
 
     def predict(self, x):
         """Return the most likely class for a data vector x"""
-        return numpy.argmax(self.pi + dot(x, self.theta))
+        return numpy.argmax(self.pi + dot(x, self.theta.transpose()))
 
 class NaiveBayes(object):
     @classmethod
