diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index 7a2d77a4dad13e9f55a38db0b655542ebf55ef5e..5c9706cb8cb29c0cd65c16ad91bdfc867296026f 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -21,14 +21,15 @@ from collections import namedtuple
 
 from pyspark import SparkContext, since
 from pyspark.rdd import ignore_unicode_prefix
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc
 
 __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']
 
 
 @inherit_doc
 @ignore_unicode_prefix
-class FPGrowthModel(JavaModelWrapper):
+class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     """
     .. note:: Experimental
 
@@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper):
     >>> model = FPGrowth.train(rdd, 0.6, 2)
     >>> sorted(model.freqItemsets().collect())
     [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
+    >>> model_path = temp_path + "/fpm"
+    >>> model.save(sc, model_path)
+    >>> sameModel = FPGrowthModel.load(sc, model_path)
+    >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -51,6 +57,16 @@ class FPGrowthModel(JavaModelWrapper):
         """
         return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
 
+    @classmethod
+    @since("2.0.0")
+    def load(cls, sc, path):
+        """
+        Load a model from the given path.
+        """
+        model = cls._load_java(sc, path)
+        wrapper = sc._jvm.FPGrowthModelWrapper(model)
+        return FPGrowthModel(wrapper)
+
 
 class FPGrowth(object):
     """
@@ -170,8 +186,19 @@ def _test():
     import pyspark.mllib.fpm
     globs = pyspark.mllib.fpm.__dict__.copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest')
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    import tempfile
+
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        globs['sc'].stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)