Skip to content
Snippets Groups Projects
Commit c8f667d7 authored by Kai Jiang, committed by Xiangrui Meng
Browse files

[SPARK-13037][ML][PYSPARK] PySpark ml.recommendation support export/import

Adds export/import (save/load) support to PySpark ml.recommendation (ALS and ALSModel).

Author: Kai Jiang <jiangkai@gmail.com>

Closes #11044 from vectorijk/spark-13037.
parent 574571c8
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@
#
from pyspark import since
from pyspark.ml.util import keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
......@@ -26,7 +26,8 @@ __all__ = ['ALS', 'ALSModel']
@inherit_doc
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed):
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
MLWritable, MLReadable):
"""
Alternating Least Squares (ALS) matrix factorization.
......@@ -81,6 +82,27 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327)
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> als_path = path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
>>> als.getMaxIter()
5
>>> model_path = path + "/als_model"
>>> model.save(model_path)
>>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank
True
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.4.0
"""
......@@ -274,7 +296,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
return self.getOrDefault(self.nonnegative)
class ALSModel(JavaModel):
class ALSModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by ALS.
......@@ -308,9 +330,10 @@ class ALSModel(JavaModel):
if __name__ == "__main__":
import doctest
import pyspark.ml.recommendation
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
globs = pyspark.ml.recommendation.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.recommendation tests")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment