Commit c919798f authored by Xusen Yin, committed by Patrick Wendell

fix bugs of dot in python

If `self.theta` is used without `transpose()`, a

*ValueError: matrices are not aligned*

is raised. The former test case did not cover this situation.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #463 from yinxusen/python-naive-bayes and squashes the following commits:

fcbe3bc [Xusen Yin] fix bugs of dot in python
parent 0f87e6ad
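
The shape mismatch described in the commit message can be reproduced with plain NumPy. The sketch below is illustrative only: it assumes `theta` is stored as a (classes x features) matrix and `pi` as a vector of length classes, uses made-up parameter values, and calls `numpy.dot` directly instead of PySpark's internal `_dot` helper. The names `predict_index` and `untransposed_fails` are hypothetical.

```python
import numpy as np

# Hypothetical model parameters (not taken from the commit): 2 classes, 3 features.
pi = np.log(np.array([0.5, 0.5]))                 # shape (2,)
theta = np.log(np.array([[0.60, 0.20, 0.20],
                         [0.10, 0.45, 0.45]]))    # shape (2, 3)
x = np.array([1.0, 0.0, 0.0])                     # shape (3,)


def predict_index(x, pi, theta):
    """Return the index of the most likely class, using the transposed theta."""
    # (3,) dot (3, 2) -> (2,), which lines up with pi.
    return int(np.argmax(pi + np.dot(x, theta.transpose())))


def untransposed_fails(x, theta):
    """Return True if np.dot(x, theta) raises ValueError for these shapes."""
    try:
        np.dot(x, theta)
    except ValueError:
        return True
    return False


# With 3 features and 2 classes, the untransposed product cannot be formed.
print(untransposed_fails(x, theta))
# With 2 features and 2 classes, theta is square, so no error is raised
# and the missing transpose goes unnoticed.
print(untransposed_fails(x[:2], theta[:, :2]))
print(predict_index(x, pi, theta))
```

Under these assumptions, a test with as many features as classes cannot trigger the error, which is presumably why the updated test data uses three features for two classes.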
@@ -154,7 +154,7 @@ class NaiveBayesModel(object):
     def predict(self, x):
         """Return the most likely class for a data vector x"""
-        return self.labels[numpy.argmax(self.pi + _dot(x, self.theta))]
+        return self.labels[numpy.argmax(self.pi + _dot(x, self.theta.transpose()))]
 class NaiveBayes(object):
     @classmethod
...
@@ -104,10 +104,10 @@ class ListTests(PySparkTestCase):
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         data = [
-            LabeledPoint(0.0, [1, 0]),
-            LabeledPoint(1.0, [0, 1]),
-            LabeledPoint(0.0, [2, 0]),
-            LabeledPoint(1.0, [0, 2])
+            LabeledPoint(0.0, [1, 0, 0]),
+            LabeledPoint(1.0, [0, 1, 1]),
+            LabeledPoint(0.0, [2, 0, 0]),
+            LabeledPoint(1.0, [0, 2, 1])
         ]
         rdd = self.sc.parallelize(data)
         features = [p.features.tolist() for p in data]
...