diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
index a2ed42d7a59e9740a364739fbb9334aa00f7544e..b402c71ed212028b32e120c257e04bfcd29fd6e9 100644
--- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -322,7 +322,7 @@ object KMeans {
     val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
     val runs = if (args.length >= 5) args(4).toInt else 1
     val sc = new SparkContext(master, "KMeans")
-    val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble))
+    val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache()
     val model = KMeans.train(data, k, iters, runs)
     val cost = model.computeCost(data)
     println("Cluster centers:")