diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 3b95a9e6936e8ba6d11719746cb87077060f5317..707da537d238fafb97dacd1a822ba9e50cae0c3a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,15 +17,22 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.Row
 
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
+class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
 
   private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
     clusterCenters.map(new VectorWithNorm(_))
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    KMeansModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object KMeansModel extends Loader[KMeansModel] {
+  override def load(sc: SparkContext, path: String): KMeansModel = {
+    KMeansModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private case class Cluster(id: Int, point: Vector)
+
+  private object Cluster {
+    def apply(r: Row): Cluster = {
+      Cluster(r.getInt(0), r.getAs[Vector](1))
+    }
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      val dataFrame = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
+        Cluster(id, point)
+      }.toDF()
+      dataFrame.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): KMeansModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val k = (metadata \ "k").extract[Int]
+      val centroids = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[Cluster](centroids.schema)
+      val localCentroids = centroids.map(Cluster.apply).collect()
+      assert(k == localCentroids.size)
+      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
+    }
+  }
 }
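
Reviewer note (not part of the patch): a minimal usage sketch of the new persistence API, assuming an active SparkContext `sc` and a writable `path` string.

import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

// Build a toy model, persist it, and read it back.
val model = new KMeansModel(Array(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)))
// save() writes JSON metadata under Loader.metadataPath(path) and Parquet data under Loader.dataPath(path).
model.save(sc, path)
val sameModel = KMeansModel.load(sc, path)
assert(model.k == sameModel.k)
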
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index caee5917000aa7b29f96f0434eb0867aa512ac1e..7bf250eb5a38372a1f8092d7350c9b5088abf22e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,9 +21,10 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
       assert(predicts(0) != predicts(3))
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(true, false).foreach { isSparse =>
+      val model = KMeansSuite.createModel(10, 3, isSparse)
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = KMeansModel.load(sc, path)
+        KMeansSuite.checkEqual(model, sameModel)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
+}
+
+object KMeansSuite extends FunSuite {
+  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
+    val singlePoint =
+      if (isSparse) {
+        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
+      } else {
+        Vectors.dense(Array.fill[Double](dim)(0.0))
+      }
+    new KMeansModel(Array.fill[Vector](k)(singlePoint))
+  }
+
+  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
+    assert(a.k === b.k)
+    a.clusterCenters.zip(b.clusterCenters).foreach {
+      case (ca: SparseVector, cb: SparseVector) =>
+        assert(ca === cb)
+      case (ca: DenseVector, cb: DenseVector) =>
+        assert(ca === cb)
+      case _ =>
+        throw new AssertionError("checkEqual failed: cluster centers are not of the same type.")
+    }
+  }
 }
 
 class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
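
Reviewer note (not part of the patch): the metadata written by SaveLoadV1_0.save is a single compact JSON line built with json4s. A rough sketch of what it evaluates to for a 3-cluster model, mirroring the JsonDSL expression in save():

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

val metadata = compact(render(
  ("class" -> "org.apache.spark.mllib.clustering.KMeansModel") ~ ("version" -> "1.0") ~ ("k" -> 3)))
// metadata is roughly: {"class":"org.apache.spark.mllib.clustering.KMeansModel","version":"1.0","k":3}

load() asserts on the class name and format version before extracting k, so a change to thisFormatVersion would require adding a new SaveLoad object rather than editing this one.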