Posted to commits@spark.apache.org by me...@apache.org on 2014/07/17 05:12:22 UTC

git commit: [SPARK-2433][MLLIB] fix NaiveBayesModel.predict

Repository: spark
Updated Branches:
  refs/heads/branch-0.9 8e5604b22 -> 0116dee7e


[SPARK-2433][MLLIB] fix NaiveBayesModel.predict

This is the same fix as https://github.com/apache/spark/pull/463, which I forgot to merge into branch-0.9.

Author: Xiangrui Meng <me...@databricks.com>

Closes #1453 from mengxr/nb-transpose-0.9 and squashes the following commits:

bc53ce8 [Xiangrui Meng] fix NaiveBayes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0116dee7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0116dee7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0116dee7

Branch: refs/heads/branch-0.9
Commit: 0116dee7e041da408865dd667377afe222367348
Parents: 8e5604b
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Jul 16 20:12:09 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jul 16 20:12:09 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/classification.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0116dee7/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 19b90df..f6c96e3 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -84,7 +84,7 @@ class NaiveBayesModel(object):
     - pi: vector of logs of class priors (dimension C)
     - theta: matrix of logs of class conditional probabilities (CxD)
 
-    >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3)
     >>> model = NaiveBayes.train(sc.parallelize(data))
     >>> model.predict(array([0.0, 1.0]))
     0
@@ -98,7 +98,7 @@ class NaiveBayesModel(object):
 
     def predict(self, x):
         """Return the most likely class for a data vector x"""
-        return numpy.argmax(self.pi + dot(x, self.theta))
+        return numpy.argmax(self.pi + dot(x, self.theta.transpose()))
 
 class NaiveBayes(object):
     @classmethod
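Why the transpose matters: per the docstring, theta is C x D (classes by features) and x has length D, so `dot(x, self.theta)` contracts x against the class axis. The old doctest appears to have had two labels (0, 0, 1) and two features, which makes theta square, so the expression ran without error but read the wrong entries. The new doctest has three labels (0, 1, 2), which makes theta 3 x 2, and the old expression would then raise. (This assumes, as the doctest suggests, that the first element of each row is the label.) The sketch below reproduces both cases in plain NumPy. The `pi`/`theta` values and the `predict_old`/`predict_fixed` helpers are hypothetical illustrations, not output of `NaiveBayes.train`.

```python
import numpy as np

# Sketch of the scoring rule in NaiveBayesModel.predict, using plain NumPy.
# Shapes follow the docstring: pi is (C,), theta is (C, D), x is (D,).
# Parameter values are made up for illustration (hypothetical), not trained.

def predict_old(pi, theta, x):
    # Pre-fix expression: dot(x, theta) contracts x against theta's first
    # axis (classes), which is only shape-valid when D == C.
    return int(np.argmax(pi + np.dot(x, theta)))

def predict_fixed(pi, theta, x):
    # Post-fix expression: theta.transpose() is (D, C), so the dot product
    # is a length-C vector of per-class log-likelihood terms.
    return int(np.argmax(pi + np.dot(x, theta.transpose())))

x = np.array([0.0, 1.0])

# Square case (C == D == 2): the old code runs but reads the wrong entries.
pi2 = np.log(np.array([0.5, 0.5]))
theta2 = np.log(np.array([[0.9, 0.1],
                          [0.6, 0.4]]))
print(predict_old(pi2, theta2, x))    # 0
print(predict_fixed(pi2, theta2, x))  # 1

# Non-square case (C == 3, D == 2): the old code raises a shape error.
pi3 = np.log(np.array([0.5, 0.25, 0.25]))
theta3 = np.log(np.array([[0.1, 0.9],
                          [0.8, 0.2],
                          [0.5, 0.5]]))
print(predict_fixed(pi3, theta3, x))  # 0
try:
    predict_old(pi3, theta3, x)
except ValueError:
    print("old predict: shapes not aligned")
```

`np.dot(self.theta, x)` would be an equivalent way to write the fixed expression; the commit keeps the original argument order and transposes theta instead.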