Skip to content
Snippets Groups Projects
Commit aedbbaa3 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-6053][MLLIB] support save/load in PySpark's ALS

A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:

f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
parent fd8d283e
No related branches found
No related tags found
No related merge requests found
......@@ -200,10 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
recommendation by measuring the Mean Squared Error of rating prediction.
Note that the Python API does not yet support model save/load but will in the future.
{% highlight python %}
from pyspark.mllib.recommendation import ALS, Rating
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
# Load and parse the data
data = sc.textFile("data/mllib/als/test.data")
......@@ -220,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
print("Mean Squared Error = " + str(MSE))
# Save and load model
model.save(sc, "myModelPath")
sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
{% endhighlight %}
If the rating matrix is derived from another source of information (i.e., it is inferred from other
......
......@@ -48,7 +48,7 @@ trait Saveable {
*
* @param sc Spark context used to save model data.
* @param path Path specifying the directory in which to save this model.
* This directory and any intermediate directory will be created if needed.
* If the directory already exists, this method throws an exception.
*/
def save(sc: SparkContext, path: String): Unit
......
......@@ -19,7 +19,8 @@ from collections import namedtuple
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
from pyspark.mllib.util import Saveable, JavaLoader
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
......@@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
return Rating, (int(self.user), int(self.product), float(self.rating))
class MatrixFactorizationModel(JavaModelWrapper):
@inherit_doc
class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
"""A matrix factorisation model trained by regularized alternating
least-squares.
......@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
0.43...
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = MatrixFactorizationModel.load(sc, path)
>>> sameModel.predict(2,2)
0.43...
>>> try:
... os.removedirs(path)
... except:
... pass
"""
def predict(self, user, product):
    """
    Return the wrapped Java model's prediction for one user/product pair.

    Both identifiers are coerced with ``int`` before being handed to
    the Java model.
    """
    java_model = self._java_model
    user_id = int(user)
    product_id = int(product)
    return java_model.predict(user_id, product_id)
......@@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper):
def productFeatures(self):
    """
    Return whatever the wrapped model's ``getProductFeatures`` call yields.
    """
    java_method = "getProductFeatures"
    return self.call(java_method)
def save(self, sc, path):
    """
    Save this model by delegating to the wrapped model's ``save`` call.

    :param sc: Spark context; its underlying ``sc._jsc.sc()`` object
               is what gets passed to the wrapped model.
    :param path: Path forwarded unchanged to the wrapped ``save`` call.
    """
    save_target = self.call
    jvm_context = sc._jsc.sc()
    save_target("save", jvm_context, path)
class ALS(object):
......
......@@ -168,6 +168,64 @@ class MLUtils(object):
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
class Saveable(object):
    """
    Mixin for models and transformers which may be saved as files.
    """

    def save(self, sc, path):
        """
        Save this model to the given path.

        This saves:
         * human-readable (JSON) model metadata to path/metadata/
         * Parquet formatted data to path/data/

        The model may be loaded using :py:meth:`Loader.load`.

        :param sc: Spark context used to save model data.
        :param path: Path specifying the directory in which to save
                     this model. If the directory already exists,
                     this method throws an exception.
        :raises NotImplementedError: always in this base mixin; classes
                     mixing it in are expected to override this method.
        """
        raise NotImplementedError
class Loader(object):
    """
    Mixin for classes which can load saved models from files.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model from the given path. The model should have been
        saved using :py:meth:`Saveable.save`.

        :param sc: Spark context used for loading model files.
        :param path: Path specifying the directory to which the model
                     was saved.
        :return: model instance
        :raises NotImplementedError: always in this base mixin; classes
                     mixing it in are expected to override this method.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it would surface as a TypeError instead.
        raise NotImplementedError
class JavaLoader(Loader):
    """
    Mixin for classes which can load saved models using its Scala
    implementation.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model through the JVM object that mirrors this class.

        The JVM-side name is derived from the Python one: "pyspark" in
        the module name is replaced by "org.apache.spark" and the class
        name is appended.

        :param sc: Spark context; ``sc._jvm`` is used to look up the
                   JVM object and ``sc._jsc.sc()`` is passed to its
                   ``load``.
        :param path: Path forwarded to the JVM-side ``load``.
        :return: ``cls`` instantiated with the loaded JVM-side model.
        """
        scala_package = cls.__module__.replace("pyspark", "org.apache.spark")
        qualified_name = "%s.%s" % (scala_package, cls.__name__)

        # Walk the JVM view one attribute at a time, package by package,
        # ending on the object that exposes ``load``.
        target = sc._jvm
        for part in qualified_name.split("."):
            target = getattr(target, part)

        loaded_model = target.load(sc._jsc.sc(), path)
        return cls(loaded_model)
def _test():
import doctest
from pyspark.context import SparkContext
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment