Skip to content
Snippets Groups Projects
Commit 8fbf72b7 authored by Davies Liu's avatar Davies Liu Committed by Xiangrui Meng
Browse files

[SPARK-4435] [MLlib] [PySpark] improve classification

This PR add setThrehold() and clearThreshold() for LogisticRegressionModel and SVMModel, also support RDD of vector in LogisticRegressionModel.predict(), SVNModel.predict() and NaiveBayes.predict()

Author: Davies Liu <davies@databricks.com>

Closes #3305 from davies/setThreshold and squashes the following commits:

d0b835f [Davies Liu] Merge branch 'master' of github.com:apache/spark into setThreshold
e4acd76 [Davies Liu] address comments
2231a5f [Davies Liu] bugfix
7bd9009 [Davies Liu] address comments
0b0a8a7 [Davies Liu] address comments
c1e5573 [Davies Liu] improve classification
parent cedc3b5a
No related branches found
No related tags found
No related merge requests found
......@@ -64,7 +64,7 @@ class LogisticRegressionModel (
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
val score = 1.0 / (1.0 + math.exp(-margin))
threshold match {
case Some(t) => if (score < t) 0.0 else 1.0
case Some(t) => if (score > t) 1.0 else 0.0
case None => score
}
}
......
......@@ -65,7 +65,7 @@ class SVMModel (
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match {
case Some(t) => if (margin < t) 0.0 else 1.0
case Some(t) => if (margin > t) 1.0 else 0.0
case None => margin
}
}
......
......@@ -20,6 +20,7 @@ from math import exp
import numpy
from numpy import array
from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
......@@ -29,39 +30,88 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
class LogisticRegressionModel(LinearModel):
class LinearBinaryClassificationModel(LinearModel):
"""
Represents a linear binary classification model that predicts to whether an
example is positive (1.0) or negative (0.0).
"""
def __init__(self, weights, intercept):
super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
self._threshold = None
def setThreshold(self, value):
"""
:: Experimental ::
Sets the threshold that separates positive predictions from negative
predictions. An example with prediction score greater than or equal
to this threshold is identified as an positive, and negative otherwise.
"""
self._threshold = value
def clearThreshold(self):
"""
:: Experimental ::
Clears the threshold so that `predict` will output raw prediction scores.
"""
self._threshold = None
def predict(self, test):
"""
Predict values for a single data point or an RDD of points using
the model trained.
"""
raise NotImplementedError
class LogisticRegressionModel(LinearBinaryClassificationModel):
"""A linear binary classification model derived from logistic regression.
>>> data = [
... LabeledPoint(0.0, [0.0]),
... LabeledPoint(1.0, [1.0]),
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... LabeledPoint(0.0, [0.0, 1.0]),
... LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
>>> lrm.predict(array([1.0])) > 0
True
>>> lrm.predict(array([0.0])) <= 0
True
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
0
>>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
[1, 0]
>>> lrm.clearThreshold()
>>> lrm.predict([0.0, 1.0])
0.123...
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data))
>>> lrm.predict(array([0.0, 1.0])) > 0
True
>>> lrm.predict(array([0.0, 0.0])) <= 0
True
>>> lrm.predict(SparseVector(2, {1: 1.0})) > 0
True
>>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0
True
>>> lrm.predict(array([0.0, 1.0]))
1
>>> lrm.predict(array([1.0, 0.0]))
0
>>> lrm.predict(SparseVector(2, {1: 1.0}))
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
"""
def __init__(self, weights, intercept):
super(LogisticRegressionModel, self).__init__(weights, intercept)
self._threshold = 0.5
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
the model trained.
"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
......@@ -69,7 +119,10 @@ class LogisticRegressionModel(LinearModel):
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
return 1 if prob > 0.5 else 0
if self._threshold is None:
return prob
else:
return 1 if prob > self._threshold else 0
class LogisticRegressionWithSGD(object):
......@@ -111,7 +164,7 @@ class LogisticRegressionWithSGD(object):
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
class SVMModel(LinearModel):
class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
......@@ -122,8 +175,14 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, [3.0])
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(data))
>>> svm.predict(array([1.0])) > 0
True
>>> svm.predict([1.0])
1
>>> svm.predict(sc.parallelize([[1.0]])).collect()
[1]
>>> svm.clearThreshold()
>>> svm.predict(array([1.0]))
1.25...
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: -1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
......@@ -131,16 +190,29 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(sparse_data))
>>> svm.predict(SparseVector(2, {1: 1.0})) > 0
True
>>> svm.predict(SparseVector(2, {0: -1.0})) <= 0
True
>>> svm.predict(SparseVector(2, {1: 1.0}))
1
>>> svm.predict(SparseVector(2, {0: -1.0}))
0
"""
def __init__(self, weights, intercept):
super(SVMModel, self).__init__(weights, intercept)
self._threshold = 0.0
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
the model trained.
"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
margin = self.weights.dot(x) + self.intercept
return 1 if margin >= 0 else 0
if self._threshold is None:
return margin
else:
return 1 if margin > self._threshold else 0
class SVMWithSGD(object):
......@@ -201,6 +273,8 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(array([1.0, 0.0]))
1.0
>>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
[1.0]
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
... LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
......@@ -219,7 +293,9 @@ class NaiveBayesModel(object):
self.theta = theta
def predict(self, x):
"""Return the most likely class for a data vector x"""
"""Return the most likely class for a data vector or an RDD of vectors"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
......@@ -250,7 +326,8 @@ class NaiveBayes(object):
def _test():
import doctest
from pyspark import SparkContext
globs = globals().copy()
import pyspark.mllib.classification
globs = pyspark.mllib.classification.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment