Skip to content
Snippets Groups Projects
Commit 4d2864b2 authored by Kai Jiang's avatar Kai Jiang Committed by Xiangrui Meng
Browse files

[SPARK-7106][MLLIB][PYSPARK] Support model save/load in Python's FPGrowth

## What changes were proposed in this pull request?

Python API supports mode save/load in FPGrowth
JIRA: [https://issues.apache.org/jira/browse/SPARK-7106](https://issues.apache.org/jira/browse/SPARK-7106)
## How was the this patch tested?

The patch is tested with Python doctest.

Author: Kai Jiang <jiangkai@gmail.com>

Closes #11321 from vectorijk/spark-7106.
parent 13ce10e9
No related branches found
No related tags found
No related merge requests found
......@@ -21,14 +21,15 @@ from collections import namedtuple
from pyspark import SparkContext, since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc
__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']
@inherit_doc
@ignore_unicode_prefix
class FPGrowthModel(JavaModelWrapper):
class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
"""
.. note:: Experimental
......@@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper):
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> sorted(model.freqItemsets().collect())
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
>>> model_path = temp_path + "/fpm"
>>> model.save(sc, model_path)
>>> sameModel = FPGrowthModel.load(sc, model_path)
>>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
True
.. versionadded:: 1.4.0
"""
......@@ -51,6 +57,16 @@ class FPGrowthModel(JavaModelWrapper):
"""
return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
@classmethod
@since("2.0.0")
def load(cls, sc, path):
"""
Load a model from the given path.
"""
model = cls._load_java(sc, path)
wrapper = sc._jvm.FPGrowthModelWrapper(model)
return FPGrowthModel(wrapper)
class FPGrowth(object):
"""
......@@ -170,8 +186,19 @@ def _test():
import pyspark.mllib.fpm
globs = pyspark.mllib.fpm.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
import tempfile
temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
finally:
from shutil import rmtree
try:
rmtree(temp_path)
except OSError:
pass
if failure_count:
exit(-1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment