diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 8a92f6911c24b08902a4d12c45fa4404d169853a..58ad99d46e23b63f46f18f2ee127f199216166d2 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -20,6 +20,7 @@ import array as pyarray
 
 if sys.version > '3':
     xrange = range
+    basestring = str
 
 from math import exp, log
 
@@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper):
     Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
 
     >>> from pyspark.mllib.linalg import Vectors
-    >>> from numpy.testing import assert_almost_equal
+    >>> from numpy.testing import assert_almost_equal, assert_equal
     >>> data = [
     ...     [1, Vectors.dense([0.0, 1.0])],
     ...     [2, SparseVector(2, {0: 1.0})],
@@ -591,6 +592,20 @@ class LDAModel(JavaModelWrapper):
     >>> topics = model.topicsMatrix()
     >>> topics_expect = array([[0.5,  0.5], [0.5, 0.5]])
     >>> assert_almost_equal(topics, topics_expect, 1)
+
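+    >>> # Round-trip the model through save() and load() to verify persistence.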
+    >>> import tempfile
+    >>> from shutil import rmtree
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = LDAModel.load(sc, path)
+    >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
+    >>> sameModel.vocabSize() == model.vocabSize()
+    True
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
     """
 
     def topicsMatrix(self):
@@ -601,6 +615,37 @@
         """Vocabulary size (number of terms or terms in the vocabulary)"""
         return self.call("vocabSize")
 
+    def save(self, sc, path):
+        """Save the LDAModel on to disk.
+
+        :param sc: SparkContext
+        :param path: str, path where the model should be saved.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        self._java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        """Load the LDAModel from disk.
+
+        :param sc: SparkContext
+        :param path: str, path where the model is stored.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
+            sc._jsc.sc(), path)
+        return cls(java_model)
+
 
 class LDA(object):