diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 92e05815d6a3d3a64d56c165cde1802c3a389229..830510b1698d4a3ac71ff18a8c7782320ea341cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
     EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
     LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
-  extends Model[LDAModel] with LDAParams with Logging {
+  extends Model[LDAModel] with LDAParams with Logging with MLWritable {
 
   // NOTE to developers:
   //  This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,64 @@ class LocalLDAModel private[ml] (
 
   @Since("1.6.0")
   override def isDistributed: Boolean = false
+
+  @Since("1.6.0")
+  override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+  private[LocalLDAModel]
+  class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+    private case class Data(
+        vocabSize: Int,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val oldModel = instance.oldLocalModel
+      val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
+        oldModel.topicConcentration, oldModel.gammaShape)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+    private val className = classOf[LocalLDAModel].getName
+
+    override def load(path: String): LocalLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
+          "gammaShape")
+        .head()
+      val vocabSize = data.getAs[Int](0)
+      val topicsMatrix = data.getAs[Matrix](1)
+      val docConcentration = data.getAs[Vector](2)
+      val topicConcentration = data.getAs[Double](3)
+      val gammaShape = data.getAs[Double](4)
+      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+        gammaShape)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LocalLDAModel = super.load(path)
 }
 
 
@@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] (
    */
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
+
+  @Since("1.6.0")
+  override def write: MLWriter = new DistributedLDAModel.DistributedLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+  private[DistributedLDAModel]
+  class DistributedLDAModelWriter(instance: DistributedLDAModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val modelPath = new Path(path, "oldModel").toString
+      instance.oldDistributedModel.save(sc, modelPath)
+    }
+  }
+
+  private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+    private val className = classOf[DistributedLDAModel].getName
+
+    override def load(path: String): DistributedLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val modelPath = new Path(path, "oldModel").toString
+      val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+      val model = new DistributedLDAModel(
+        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): DistributedLDAModel = super.load(path)
 }
 
 
@@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] (
 @Since("1.6.0")
 @Experimental
 class LDA @Since("1.6.0") (
-    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+    @Since("1.6.0") override val uid: String)
+  extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +794,7 @@ class LDA @Since("1.6.0") (
 }
 
 
-private[clustering] object LDA {
+private[clustering] object LDA extends DefaultParamsReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
   def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +805,7 @@ private[clustering] object LDA {
         (docId, features)
       }
   }
+
+  @Since("1.6.0")
+  override def load(path: String): LDA = super.load(path)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index cd520f09bd4664d102c61ecc61bf074cd076154a..7384d065a2ea8ffc113b95933617431e68a7ea9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable {
  * @param topics Inferred topics (vocabSize x k matrix).
  */
 @Since("1.3.0")
-class LocalLDAModel private[clustering] (
+class LocalLDAModel private[spark] (
     @Since("1.3.0") val topics: Matrix,
     @Since("1.5.0") override val docConcentration: Vector,
     @Since("1.5.0") override val topicConcentration: Double,
-    override protected[clustering] val gammaShape: Double = 100)
+    override protected[spark] val gammaShape: Double = 100)
   extends LDAModel with Serializable {
 
   @Since("1.3.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index b634d31cc34f025b0fa46c79cf9fcc27b4c7ef8b..97dbfd9a4314a5db98895fbf2ac17be4f732aef5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -18,9 +18,10 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 
@@ -39,10 +40,24 @@ object LDASuite {
     }.map(v => new TestRow(v))
     sql.createDataFrame(rdd)
   }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 2,
+    "checkpointInterval" -> 30,
+    "learningOffset" -> 1023.0,
+    "learningDecay" -> 0.52,
+    "subsamplingRate" -> 0.051
+  )
 }
 
 
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   val k: Int = 5
   val vocabSize: Int = 30
@@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     val lp = model.logPrior
     assert(lp <= 0.0 && lp != Double.NegativeInfinity)
   }
+
+  test("read/write LocalLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+  }
+
+  test("read/write DistributedLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
+  }
 }