From 48866f789712b0cdbaf76054d1014c6df032fff1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Fri, 20 Mar 2015 14:44:21 -0400
Subject: [PATCH] [SPARK-6095] [MLLIB] Support model save/load in Python's
 linear models

For Python's linear models, the weights and intercept are stored on the Python side.
This PR implements save/load functions for Python's linear models by reusing the
Scala implementation, so they behave the same as their Scala counterparts.
This also enables model import/export across languages.
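
A minimal usage sketch (illustrative only: the toy data and the
"/tmp/lr-model" path are made up, and an active SparkContext named sc
is assumed):

    from pyspark.mllib.classification import LogisticRegressionWithSGD
    from pyspark.mllib.classification import LogisticRegressionModel
    from pyspark.mllib.regression import LabeledPoint

    # Train a tiny model, save it, then load it back from the same path.
    data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0])])
    model = LogisticRegressionWithSGD.train(data)
    model.save(sc, "/tmp/lr-model")  # written via the Scala implementation
    sameModel = LogisticRegressionModel.load(sc, "/tmp/lr-model")

Since the files on disk are produced by the Scala implementation, the
same directory can also be loaded from Scala.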

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5016 from yanboliang/spark-6095 and squashes the following commits:

d9bb824 [Yanbo Liang] fix python style
b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
---
 python/pyspark/mllib/classification.py | 58 +++++++++++++++++-
 python/pyspark/mllib/regression.py     | 84 +++++++++++++++++++++++++-
 python/pyspark/mllib/util.py           |  6 +-
 3 files changed, 145 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e476517370..b66159c5bf 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,7 +21,7 @@ import numpy
 from numpy import array
 
 from pyspark import RDD
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
 
@@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
     1
     >>> lrm.predict(SparseVector(2, {0: 1.0}))
     0
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> lrm.save(sc, path)
+    >>> sameModel = LogisticRegressionModel.load(sc, path)
+    >>> sameModel.predict(array([0.0, 1.0]))
+    1
+    >>> sameModel.predict(SparseVector(2, {0: 1.0}))
+    0
+    >>> try:
+    ...    os.removedirs(path)
+    ... except OSError:
+    ...    pass
     """
     def __init__(self, weights, intercept):
         super(LogisticRegressionModel, self).__init__(weights, intercept)
@@ -124,6 +136,22 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
         else:
             return 1 if prob > self._threshold else 0
 
+    def save(self, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
+            _py2java(sc, self._coeff), self.intercept)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load(
+            sc._jsc.sc(), path)
+        weights = _java2py(sc, java_model.weights())
+        intercept = java_model.intercept()
+        threshold = java_model.getThreshold().get()
+        model = LogisticRegressionModel(weights, intercept)
+        model.setThreshold(threshold)
+        return model
+
 
 class LogisticRegressionWithSGD(object):
 
@@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel):
     1
     >>> svm.predict(SparseVector(2, {0: -1.0}))
     0
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> svm.save(sc, path)
+    >>> sameModel = SVMModel.load(sc, path)
+    >>> sameModel.predict(SparseVector(2, {1: 1.0}))
+    1
+    >>> sameModel.predict(SparseVector(2, {0: -1.0}))
+    0
+    >>> try:
+    ...    os.removedirs(path)
+    ... except OSError:
+    ...    pass
     """
     def __init__(self, weights, intercept):
         super(SVMModel, self).__init__(weights, intercept)
@@ -263,6 +303,22 @@ class SVMModel(LinearBinaryClassificationModel):
         else:
             return 1 if margin > self._threshold else 0
 
+    def save(self, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel(
+            _py2java(sc, self._coeff), self.intercept)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(
+            sc._jsc.sc(), path)
+        weights = _java2py(sc, java_model.weights())
+        intercept = java_model.intercept()
+        threshold = java_model.getThreshold().get()
+        model = SVMModel(weights, intercept)
+        model.setThreshold(threshold)
+        return model
+
 
 class SVMWithSGD(object):
 
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c21ad5787..015a786011 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,8 +18,9 @@
 import numpy as np
 from numpy import array
 
-from pyspark.mllib.common import callMLlibFunc, inherit_doc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.util import Saveable, Loader
 
 __all__ = ['LabeledPoint', 'LinearModel',
            'LinearRegressionModel', 'LinearRegressionWithSGD',
@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> lrm.save(sc, path)
+    >>> sameModel = LinearRegressionModel.load(sc, path)
+    >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+    True
+    >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+    True
+    >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+    True
+    >>> try:
+    ...    os.removedirs(path)
+    ... except OSError:
+    ...    pass
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
     ...     LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
     """
+    def save(self, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
+            _py2java(sc, self._coeff), self.intercept)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load(
+            sc._jsc.sc(), path)
+        weights = _java2py(sc, java_model.weights())
+        intercept = java_model.intercept()
+        model = LinearRegressionModel(weights, intercept)
+        return model
 
 
 # train_func should take two parameters, namely data and initial_weights, and
@@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> lrm.save(sc, path)
+    >>> sameModel = LassoModel.load(sc, path)
+    >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+    True
+    >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+    True
+    >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+    True
+    >>> try:
+    ...    os.removedirs(path)
+    ... except OSError:
+    ...    pass
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
     ...     LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase):
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
     """
+    def save(self, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
+            _py2java(sc, self._coeff), self.intercept)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(
+            sc._jsc.sc(), path)
+        weights = _java2py(sc, java_model.weights())
+        intercept = java_model.intercept()
+        model = LassoModel(weights, intercept)
+        return model
 
 
 class LassoWithSGD(object):
@@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> lrm.save(sc, path)
+    >>> sameModel = RidgeRegressionModel.load(sc, path)
+    >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+    True
+    >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+    True
+    >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+    True
+    >>> try:
+    ...    os.removedirs(path)
+    ... except OSError:
+    ...    pass
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
     ...     LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
     """
+    def save(self, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
+            _py2java(sc, self._coeff), self.intercept)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load(
+            sc._jsc.sc(), path)
+        weights = _java2py(sc, java_model.weights())
+        intercept = java_model.intercept()
+        model = RidgeRegressionModel(weights, intercept)
+        return model
 
 
 class RidgeRegressionWithSGD(object):
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index e877c720ac..c5c3468eb9 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -20,7 +20,6 @@ import warnings
 
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
 from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
-from pyspark.mllib.regression import LabeledPoint
 
 
 class MLUtils(object):
@@ -50,6 +49,7 @@ class MLUtils(object):
     @staticmethod
     def _convert_labeled_point_to_libsvm(p):
         """Converts a LabeledPoint to a string in LIBSVM format."""
+        from pyspark.mllib.regression import LabeledPoint
         assert isinstance(p, LabeledPoint)
         items = [str(p.label)]
         v = _convert_to_vector(p.features)
@@ -92,6 +92,7 @@ class MLUtils(object):
 
         >>> from tempfile import NamedTemporaryFile
         >>> from pyspark.mllib.util import MLUtils
+        >>> from pyspark.mllib.regression import LabeledPoint
         >>> tempFile = NamedTemporaryFile(delete=True)
         >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
         >>> tempFile.flush()
@@ -110,6 +111,7 @@ class MLUtils(object):
         >>> print examples[2]
         (-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
         """
+        from pyspark.mllib.regression import LabeledPoint
         if multiclass is not None:
             warnings.warn("deprecated", DeprecationWarning)
 
@@ -130,6 +132,7 @@ class MLUtils(object):
 
         >>> from tempfile import NamedTemporaryFile
         >>> from fileinput import input
+        >>> from pyspark.mllib.regression import LabeledPoint
         >>> from glob import glob
         >>> from pyspark.mllib.util import MLUtils
         >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
@@ -156,6 +159,7 @@ class MLUtils(object):
 
         >>> from tempfile import NamedTemporaryFile
         >>> from pyspark.mllib.util import MLUtils
+        >>> from pyspark.mllib.regression import LabeledPoint
         >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
                         LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
         >>> tempFile = NamedTemporaryFile(delete=True)
-- 
GitLab