Skip to content
Snippets Groups Projects
Commit 48866f78 authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-6095] [MLLIB] Support model save/load in Python's linear models

For Python's linear models, weights and intercept are stored in Python.
This PR implements save/load functions for Python's linear models, which do the same thing as their Scala counterparts.
It also makes model import/export work across languages.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5016 from yanboliang/spark-6095 and squashes the following commits:

d9bb824 [Yanbo Liang] fix python style
b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
parent a7456459
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ import numpy ...@@ -21,7 +21,7 @@ import numpy
from numpy import array from numpy import array
from pyspark import RDD from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
...@@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ...@@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
1 1
>>> lrm.predict(SparseVector(2, {0: 1.0})) >>> lrm.predict(SparseVector(2, {0: 1.0}))
0 0
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = LogisticRegressionModel.load(sc, path)
>>> sameModel.predict(array([0.0, 1.0]))
1
>>> sameModel.predict(SparseVector(2, {0: 1.0}))
0
>>> try:
... os.removedirs(path)
... except:
... pass
""" """
def __init__(self, weights, intercept): def __init__(self, weights, intercept):
super(LogisticRegressionModel, self).__init__(weights, intercept) super(LogisticRegressionModel, self).__init__(weights, intercept)
...@@ -124,6 +136,22 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ...@@ -124,6 +136,22 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
else: else:
return 1 if prob > self._threshold else 0 return 1 if prob > self._threshold else 0
def save(self, sc, path):
    """
    Save this model to the given path by delegating to the Scala
    implementation.

    A JVM ``LogisticRegressionModel`` is built from this model's
    coefficients and intercept, the Python-side threshold is mirrored
    onto it, and the JVM model's own ``save`` is invoked.

    :param sc: SparkContext used to reach the JVM.
    :param path: Destination path passed through to the JVM ``save``.
    """
    java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
        _py2java(sc, self._coeff), self.intercept)
    # Mirror the Python-side threshold on the JVM model before saving;
    # otherwise the freshly constructed JVM model keeps its own default and
    # load() would restore that instead of the value configured here.
    # NOTE(review): assumes None is the "threshold cleared" marker on the
    # Python side -- confirm against LinearBinaryClassificationModel.
    if self._threshold is None:
        java_model.clearThreshold()
    else:
        # float() so py4j resolves the Double parameter even for an int value.
        java_model.setThreshold(float(self._threshold))
    java_model.save(sc._jsc.sc(), path)

@classmethod
def load(cls, sc, path):
    """
    Load a model previously written by ``save`` (from Python or Scala).

    :param sc: SparkContext used to reach the JVM.
    :param path: Path passed through to the JVM ``load``.
    :return: A ``LogisticRegressionModel`` carrying the loaded weights,
             intercept and threshold.
    """
    java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load(
        sc._jsc.sc(), path)
    weights = _java2py(sc, java_model.weights())
    intercept = java_model.intercept()
    # getThreshold() hands back a scala.Option; calling get() on an empty
    # Option throws, so check it before unwrapping.
    threshold = java_model.getThreshold()
    model = LogisticRegressionModel(weights, intercept)
    if threshold.isDefined():
        model.setThreshold(threshold.get())
    else:
        # NOTE(review): relies on the Python model exposing clearThreshold()
        # alongside setThreshold() -- confirm in the base class.
        model.clearThreshold()
    return model
class LogisticRegressionWithSGD(object): class LogisticRegressionWithSGD(object):
...@@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel): ...@@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel):
1 1
>>> svm.predict(SparseVector(2, {0: -1.0})) >>> svm.predict(SparseVector(2, {0: -1.0}))
0 0
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> svm.save(sc, path)
>>> sameModel = SVMModel.load(sc, path)
>>> sameModel.predict(SparseVector(2, {1: 1.0}))
1
>>> sameModel.predict(SparseVector(2, {0: -1.0}))
0
>>> try:
... os.removedirs(path)
... except:
... pass
""" """
def __init__(self, weights, intercept): def __init__(self, weights, intercept):
super(SVMModel, self).__init__(weights, intercept) super(SVMModel, self).__init__(weights, intercept)
...@@ -263,6 +303,22 @@ class SVMModel(LinearBinaryClassificationModel): ...@@ -263,6 +303,22 @@ class SVMModel(LinearBinaryClassificationModel):
else: else:
return 1 if margin > self._threshold else 0 return 1 if margin > self._threshold else 0
def save(self, sc, path):
    """
    Save this model to the given path by delegating to the Scala
    implementation.

    A JVM ``SVMModel`` is built from this model's coefficients and
    intercept, the Python-side threshold is mirrored onto it, and the JVM
    model's own ``save`` is invoked.

    :param sc: SparkContext used to reach the JVM.
    :param path: Destination path passed through to the JVM ``save``.
    """
    java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel(
        _py2java(sc, self._coeff), self.intercept)
    # Mirror the Python-side threshold on the JVM model before saving;
    # otherwise the freshly constructed JVM model keeps its own default and
    # load() would restore that instead of the value configured here.
    # NOTE(review): assumes None is the "threshold cleared" marker on the
    # Python side -- confirm against LinearBinaryClassificationModel.
    if self._threshold is None:
        java_model.clearThreshold()
    else:
        # float() so py4j resolves the Double parameter even for an int value.
        java_model.setThreshold(float(self._threshold))
    java_model.save(sc._jsc.sc(), path)

@classmethod
def load(cls, sc, path):
    """
    Load a model previously written by ``save`` (from Python or Scala).

    :param sc: SparkContext used to reach the JVM.
    :param path: Path passed through to the JVM ``load``.
    :return: An ``SVMModel`` carrying the loaded weights, intercept and
             threshold.
    """
    java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(
        sc._jsc.sc(), path)
    weights = _java2py(sc, java_model.weights())
    intercept = java_model.intercept()
    # getThreshold() hands back a scala.Option; calling get() on an empty
    # Option throws, so check it before unwrapping.
    threshold = java_model.getThreshold()
    model = SVMModel(weights, intercept)
    if threshold.isDefined():
        model.setThreshold(threshold.get())
    else:
        # NOTE(review): relies on the Python model exposing clearThreshold()
        # alongside setThreshold() -- confirm in the base class.
        model.clearThreshold()
    return model
class SVMWithSGD(object): class SVMWithSGD(object):
......
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
import numpy as np import numpy as np
from numpy import array from numpy import array
from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.util import Saveable, Loader
__all__ = ['LabeledPoint', 'LinearModel', __all__ = ['LabeledPoint', 'LinearModel',
'LinearRegressionModel', 'LinearRegressionWithSGD', 'LinearRegressionModel', 'LinearRegressionWithSGD',
...@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase): ...@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
True True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = LinearRegressionModel.load(sc, path)
>>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
True
>>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
>>> try:
... os.removedirs(path)
... except:
... pass
>>> data = [ >>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
...@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase): ...@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
""" """
def save(self, sc, path):
    """
    Save this model to the given path via the Scala implementation.

    Builds a JVM ``LinearRegressionModel`` from this model's coefficients
    and intercept, then calls that object's ``save``.

    :param sc: SparkContext used to reach the JVM.
    :param path: Destination path passed through to the JVM ``save``.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel
    jvm_model = jvm_class(_py2java(sc, self._coeff), self.intercept)
    jvm_model.save(sc._jsc.sc(), path)
@classmethod
def load(cls, sc, path):
    """
    Load a model from the given path via the Scala implementation.

    :param sc: SparkContext used to reach the JVM.
    :param path: Path passed through to the JVM ``load``.
    :return: A ``LinearRegressionModel`` built from the loaded weights and
             intercept.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel
    loaded = jvm_class.load(sc._jsc.sc(), path)
    # Convert the JVM weight vector to its Python counterpart before
    # reading the (plain numeric) intercept.
    coefficients = _java2py(sc, loaded.weights())
    return LinearRegressionModel(coefficients, loaded.intercept())
# train_func should take two parameters, namely data and initial_weights, and # train_func should take two parameters, namely data and initial_weights, and
...@@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase): ...@@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase):
True True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = LassoModel.load(sc, path)
>>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
True
>>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
>>> try:
... os.removedirs(path)
... except:
... pass
>>> data = [ >>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
...@@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase): ...@@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
""" """
def save(self, sc, path):
    """
    Save this model to the given path via the Scala implementation.

    Builds a JVM ``LassoModel`` from this model's coefficients and
    intercept, then calls that object's ``save``.

    :param sc: SparkContext used to reach the JVM.
    :param path: Destination path passed through to the JVM ``save``.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.LassoModel
    jvm_model = jvm_class(_py2java(sc, self._coeff), self.intercept)
    jvm_model.save(sc._jsc.sc(), path)
@classmethod
def load(cls, sc, path):
    """
    Load a model from the given path via the Scala implementation.

    :param sc: SparkContext used to reach the JVM.
    :param path: Path passed through to the JVM ``load``.
    :return: A ``LassoModel`` built from the loaded weights and intercept.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.LassoModel
    loaded = jvm_class.load(sc._jsc.sc(), path)
    # Convert the JVM weight vector to its Python counterpart before
    # reading the (plain numeric) intercept.
    coefficients = _java2py(sc, loaded.weights())
    return LassoModel(coefficients, loaded.intercept())
class LassoWithSGD(object): class LassoWithSGD(object):
...@@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): ...@@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = RidgeRegressionModel.load(sc, path)
>>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
True
>>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
>>> try:
... os.removedirs(path)
... except:
... pass
>>> data = [ >>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
...@@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase): ...@@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True True
""" """
def save(self, sc, path):
    """
    Save this model to the given path via the Scala implementation.

    Builds a JVM ``RidgeRegressionModel`` from this model's coefficients
    and intercept, then calls that object's ``save``.

    :param sc: SparkContext used to reach the JVM.
    :param path: Destination path passed through to the JVM ``save``.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel
    jvm_model = jvm_class(_py2java(sc, self._coeff), self.intercept)
    jvm_model.save(sc._jsc.sc(), path)
@classmethod
def load(cls, sc, path):
    """
    Load a model from the given path via the Scala implementation.

    :param sc: SparkContext used to reach the JVM.
    :param path: Path passed through to the JVM ``load``.
    :return: A ``RidgeRegressionModel`` built from the loaded weights and
             intercept.
    """
    jvm_class = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel
    loaded = jvm_class.load(sc._jsc.sc(), path)
    # Convert the JVM weight vector to its Python counterpart before
    # reading the (plain numeric) intercept.
    coefficients = _java2py(sc, loaded.weights())
    return RidgeRegressionModel(coefficients, loaded.intercept())
class RidgeRegressionWithSGD(object): class RidgeRegressionWithSGD(object):
......
...@@ -20,7 +20,6 @@ import warnings ...@@ -20,7 +20,6 @@ import warnings
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
class MLUtils(object): class MLUtils(object):
...@@ -50,6 +49,7 @@ class MLUtils(object): ...@@ -50,6 +49,7 @@ class MLUtils(object):
@staticmethod @staticmethod
def _convert_labeled_point_to_libsvm(p): def _convert_labeled_point_to_libsvm(p):
"""Converts a LabeledPoint to a string in LIBSVM format.""" """Converts a LabeledPoint to a string in LIBSVM format."""
from pyspark.mllib.regression import LabeledPoint
assert isinstance(p, LabeledPoint) assert isinstance(p, LabeledPoint)
items = [str(p.label)] items = [str(p.label)]
v = _convert_to_vector(p.features) v = _convert_to_vector(p.features)
...@@ -92,6 +92,7 @@ class MLUtils(object): ...@@ -92,6 +92,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile >>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.util import MLUtils
>>> from pyspark.mllib.regression import LabeledPoint
>>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush() >>> tempFile.flush()
...@@ -110,6 +111,7 @@ class MLUtils(object): ...@@ -110,6 +111,7 @@ class MLUtils(object):
>>> print examples[2] >>> print examples[2]
(-1.0,(6,[1,3,5],[4.0,5.0,6.0])) (-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
""" """
from pyspark.mllib.regression import LabeledPoint
if multiclass is not None: if multiclass is not None:
warnings.warn("deprecated", DeprecationWarning) warnings.warn("deprecated", DeprecationWarning)
...@@ -130,6 +132,7 @@ class MLUtils(object): ...@@ -130,6 +132,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile >>> from tempfile import NamedTemporaryFile
>>> from fileinput import input >>> from fileinput import input
>>> from pyspark.mllib.regression import LabeledPoint
>>> from glob import glob >>> from glob import glob
>>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.util import MLUtils
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
...@@ -156,6 +159,7 @@ class MLUtils(object): ...@@ -156,6 +159,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile >>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.util import MLUtils
>>> from pyspark.mllib.regression import LabeledPoint
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
>>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile = NamedTemporaryFile(delete=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment