Commit 4d2864b2 authored by Kai Jiang's avatar Kai Jiang Committed by Xiangrui Meng

[SPARK-7106][MLLIB][PYSPARK] Support model save/load in Python's FPGrowth

## What changes were proposed in this pull request?

The Python API now supports model save/load in FPGrowth.
JIRA: [https://issues.apache.org/jira/browse/SPARK-7106](https://issues.apache.org/jira/browse/SPARK-7106)
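
For illustration, a minimal usage sketch of the new API, assuming an active SparkContext `sc`; the input transactions and save path below are made up for the example and are not part of the patch:

```python
from pyspark.mllib.fpm import FPGrowth, FPGrowthModel

# Assumes an active SparkContext `sc`; data and path are illustrative only.
transactions = [["a", "b", "c"], ["a", "b", "d", "e"], ["b", "c", "e"], ["a", "c", "e"]]
rdd = sc.parallelize(transactions, 2)

model = FPGrowth.train(rdd, minSupport=0.6, numPartitions=2)
model.save(sc, "/tmp/fpm_model")                      # added by this patch
sameModel = FPGrowthModel.load(sc, "/tmp/fpm_model")  # added by this patch

# The reloaded model produces the same frequent itemsets.
assert (sorted(model.freqItemsets().collect())
        == sorted(sameModel.freqItemsets().collect()))
```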
## How was this patch tested?

The patch is tested with Python doctest.
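
As a standalone illustration of the testing pattern (a minimal sketch, not the actual Spark test harness): the module doctests are run with a temporary directory exposed to the examples as `temp_path`, which is what lets the docstring call `model.save(sc, temp_path + "/fpm")`:

```python
# Sketch of the doctest-with-tempdir pattern used in _test(); assumes a local
# pyspark installation and no SparkContext already running.
import doctest
import tempfile
from shutil import rmtree

import pyspark.mllib.fpm
from pyspark import SparkContext

globs = pyspark.mllib.fpm.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest')
globs['temp_path'] = tempfile.mkdtemp()
try:
    (failure_count, test_count) = doctest.testmod(
        pyspark.mllib.fpm, globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
finally:
    try:
        rmtree(globs['temp_path'])  # remove any models saved by the doctests
    except OSError:
        pass
```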

Author: Kai Jiang <jiangkai@gmail.com>

Closes #11321 from vectorijk/spark-7106.
parent 13ce10e9
@@ -21,14 +21,15 @@ from collections import namedtuple
 
 from pyspark import SparkContext, since
 from pyspark.rdd import ignore_unicode_prefix
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc
 
 __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']
 
 
 @inherit_doc
 @ignore_unicode_prefix
-class FPGrowthModel(JavaModelWrapper):
+class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
 
     """
     .. note:: Experimental
@@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper):
     >>> model = FPGrowth.train(rdd, 0.6, 2)
     >>> sorted(model.freqItemsets().collect())
     [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
+    >>> model_path = temp_path + "/fpm"
+    >>> model.save(sc, model_path)
+    >>> sameModel = FPGrowthModel.load(sc, model_path)
+    >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -51,6 +57,16 @@ class FPGrowthModel(JavaModelWrapper):
         """
         return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
 
+    @classmethod
+    @since("2.0.0")
+    def load(cls, sc, path):
+        """
+        Load a model from the given path.
+        """
+        model = cls._load_java(sc, path)
+        wrapper = sc._jvm.FPGrowthModelWrapper(model)
+        return FPGrowthModel(wrapper)
+
 
 class FPGrowth(object):
     """
@@ -170,8 +186,19 @@ def _test():
     import pyspark.mllib.fpm
     globs = pyspark.mllib.fpm.__dict__.copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest')
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    import tempfile
+
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        globs['sc'].stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
 