diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 059b52ef20a986fd5fcf53038d731e0b82e5d395..ece28848aa02c2fc55cecf547233f7f9bf3c333a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] (
   override protected def formatVersion = "1.0"
 
   override def save(sc: SparkContext, path: String): Unit = {
-    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
+      gammaShape)
   }
   // TODO
   // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     // as a Row in data.
     case class Data(topic: Vector, index: Int)
 
-    // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
-    // model.predict()
-    def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+    def save(
+        sc: SparkContext,
+        path: String,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double): Unit = {
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
       val k = topicsMatrix.numCols
       val metadata = compact(render
         (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
-          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~
+          ("docConcentration" -> docConcentration.toArray.toSeq) ~
+          ("topicConcentration" -> topicConcentration) ~
+          ("gammaShape" -> gammaShape)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
@@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
     }
 
-    def load(sc: SparkContext, path: String): LocalLDAModel = {
+    def load(
+        sc: SparkContext,
+        path: String,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double): LocalLDAModel = {
       val dataPath = Loader.dataPath(path)
       val sqlContext = SQLContext.getOrCreate(sc)
       val dataFrame = sqlContext.read.parquet(dataPath)
@@ -348,8 +361,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       val topicsMat = Matrices.fromBreeze(brzTopics)
 
-      // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
-      new LocalLDAModel(topicsMat,
-        Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D)
+      new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape)
     }
   }
 
@@ -358,11 +369,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
     val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+    val docConcentration =
+      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+    val gammaShape = (metadata \ "gammaShape").extract[Double]
     val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        SaveLoadV1_0.load(sc, path)
+        SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape)
       case _ => throw new Exception(
         s"LocalLDAModel.load did not recognize model with (className, format version):" +
           s"($loadedClassName, $loadedVersion).  Supported:\n" +
@@ -565,7 +580,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
 
     val thisFormatVersion = "1.0"
 
-    val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+    val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel"
 
     // Store globalTopicTotals as a Vector.
     case class Data(globalTopicTotals: Vector)
@@ -591,7 +606,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
       import sqlContext.implicits._
 
       val metadata = compact(render
-        (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
+        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
           ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
           ("docConcentration" -> docConcentration.toArray.toSeq) ~
           ("topicConcentration" -> topicConcentration) ~
@@ -660,7 +675,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     val topicConcentration = (metadata \ "topicConcentration").extract[Double]
     val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
     val gammaShape = (metadata \ "gammaShape").extract[Double]
-    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
       case (className, "1.0") if className == classNameV1_0 => {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index aa36336ebbee68142e05cf67f53680bb74c8e242..b91c7cefed22e607139e0561b4178e12aa422dc1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
   test("model save/load") {
     // Test for LocalLDAModel.
     val localModel = new LocalLDAModel(tinyTopics,
-      Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
+      Vectors.dense(Array.fill(tinyTopics.numCols)(0.01)), 0.5D, 10D)
     val tempDir1 = Utils.createTempDir()
     val path1 = tempDir1.toURI.toString
 
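
A note on the fixture above: docConcentration is the Dirichlet prior over a document's distribution across the k topics, and topicsMatrix is vocabSize x k with topics as columns (the save() code above reads k from numCols and vocabSize from numRows), so the vector needs numCols entries, not numRows. A tiny illustration, e.g. in the REPL, with made-up weights:

    import org.apache.spark.mllib.linalg.Matrices

    // vocabSize x k: 5 terms (rows), 3 topics (columns), uniform placeholder values.
    val tinyTopics = Matrices.dense(5, 3, Array.fill(15)(0.2))
    val k = tinyTopics.numCols         // 3: the correct docConcentration length
    val vocabSize = tinyTopics.numRows // 5: the wrong one
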
@@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
       assert(samelocalModel.k === localModel.k)
       assert(samelocalModel.vocabSize === localModel.vocabSize)
+      assert(samelocalModel.docConcentration === localModel.docConcentration)
+      assert(samelocalModel.topicConcentration === localModel.topicConcentration)
+      assert(samelocalModel.gammaShape === localModel.gammaShape)
 
       val sameDistributedModel = DistributedLDAModel.load(sc, path2)
       assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
@@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
       assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
       assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+      assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
       assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
 
       val graph = distributedModel.graph
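
End to end, the behavior the new assertions pin down looks like the following sketch. It assumes a local SparkContext and lives in the org.apache.spark.mllib.clustering package, since the LocalLDAModel constructor and gammaShape are package-private; the matrix values and dimensions are illustrative:

    package org.apache.spark.mllib.clustering

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.util.Utils

    object LdaSaveLoadSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("lda-save-load"))
        // 5-term vocabulary, 3 topics; topics are columns of the matrix.
        val topics = Matrices.dense(5, 3, Array.fill(15)(0.2))
        val model = new LocalLDAModel(topics, Vectors.dense(Array.fill(3)(0.01)), 0.5, 10.0)
        val path = Utils.createTempDir().toURI.toString
        model.save(sc, path)
        val loaded = LocalLDAModel.load(sc, path)
        // With this patch the hyperparameters survive the round trip instead of
        // being reset to hard-coded defaults on load.
        assert(loaded.docConcentration == model.docConcentration)
        assert(loaded.topicConcentration == model.topicConcentration)
        assert(loaded.gammaShape == model.gammaShape)
        sc.stop()
      }
    }
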