diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 23141aaf42b49566c8955725cd5a6d683cf92b69..68a7b3b6763af21cc3b501bebad25fd78e08176d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -43,18 +43,17 @@ import org.apache.spark.util.random.XORShiftRandom
 class KMeans private (
     private var k: Int,
     private var maxIterations: Int,
-    private var runs: Int,
     private var initializationMode: String,
     private var initializationSteps: Int,
     private var epsilon: Double,
     private var seed: Long) extends Serializable with Logging {
 
   /**
-   * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
+   * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20,
    * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}.
    */
   @Since("0.8.0")
-  def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
+  def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
 
   /**
    * Number of clusters to create (k).
@@ -112,15 +111,17 @@ class KMeans private (
    * This function has no effect since Spark 2.0.0.
    */
   @Since("1.4.0")
+  @deprecated("This has no effect and always returns 1", "2.1.0")
   def getRuns: Int = {
     logWarning("Getting number of runs has no effect since Spark 2.0.0.")
-    runs
+    1
   }
 
   /**
    * This function has no effect since Spark 2.0.0.
    */
   @Since("0.8.0")
+  @deprecated("This has no effect", "2.1.0")
   def setRuns(runs: Int): this.type = {
     logWarning("Setting number of runs has no effect since Spark 2.0.0.")
     this
@@ -239,17 +240,9 @@ class KMeans private (
 
     val initStartTime = System.nanoTime()
 
-    // Only one run is allowed when initialModel is given
-    val numRuns = if (initialModel.nonEmpty) {
-      if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
-      1
-    } else {
-      runs
-    }
-
     val centers = initialModel match {
       case Some(kMeansCenters) =>
-        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
+        kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
       case None =>
         if (initializationMode == KMeans.RANDOM) {
           initRandom(data)
@@ -258,89 +251,62 @@ class KMeans private (
         }
     }
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
-    logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
-      " seconds.")
+    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")
 
-    val active = Array.fill(numRuns)(true)
-    val costs = Array.fill(numRuns)(0.0)
-
-    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
+    var converged = false
+    var cost = 0.0
     var iteration = 0
 
     val iterationStartTime = System.nanoTime()
 
-    instr.foreach(_.logNumFeatures(centers(0)(0).vector.size))
+    instr.foreach(_.logNumFeatures(centers.head.vector.size))
 
-    // Execute iterations of Lloyd's algorithm until all runs have converged
-    while (iteration < maxIterations && !activeRuns.isEmpty) {
-      type WeightedPoint = (Vector, Long)
-      def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
-        axpy(1.0, x._1, y._1)
-        (y._1, x._2 + y._2)
-      }
-
-      val activeCenters = activeRuns.map(r => centers(r)).toArray
-      val costAccums = activeRuns.map(_ => sc.doubleAccumulator)
-
-      val bcActiveCenters = sc.broadcast(activeCenters)
+    // Execute iterations of Lloyd's algorithm until converged
+    while (iteration < maxIterations && !converged) {
+      val costAccum = sc.doubleAccumulator
+      val bcCenters = sc.broadcast(centers)
 
       // Find the sum and count of points mapping to each center
       val totalContribs = data.mapPartitions { points =>
-        val thisActiveCenters = bcActiveCenters.value
-        val runs = thisActiveCenters.length
-        val k = thisActiveCenters(0).length
-        val dims = thisActiveCenters(0)(0).vector.size
+        val thisCenters = bcCenters.value
+        val dims = thisCenters.head.vector.size
 
-        val sums = Array.fill(runs, k)(Vectors.zeros(dims))
-        val counts = Array.fill(runs, k)(0L)
+        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
+        val counts = Array.fill(thisCenters.length)(0L)
 
         points.foreach { point =>
-          (0 until runs).foreach { i =>
-            val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
-            costAccums(i).add(cost)
-            val sum = sums(i)(bestCenter)
-            axpy(1.0, point.vector, sum)
-            counts(i)(bestCenter) += 1
-          }
+          val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
+          costAccum.add(cost)
+          val sum = sums(bestCenter)
+          axpy(1.0, point.vector, sum)
+          counts(bestCenter) += 1
         }
 
-        val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
-          ((i, j), (sums(i)(j), counts(i)(j)))
-        }
-        contribs.iterator
-      }.reduceByKey(mergeContribs).collectAsMap()
-
-      bcActiveCenters.destroy(blocking = false)
-
-      // Update the cluster centers and costs for each active run
-      for ((run, i) <- activeRuns.zipWithIndex) {
-        var changed = false
-        var j = 0
-        while (j < k) {
-          val (sum, count) = totalContribs((i, j))
-          if (count != 0) {
-            scal(1.0 / count, sum)
-            val newCenter = new VectorWithNorm(sum)
-            if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
-              changed = true
-            }
-            centers(run)(j) = newCenter
-          }
-          j += 1
-        }
-        if (!changed) {
-          active(run) = false
-          logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
+        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
+      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
+        axpy(1.0, sum2, sum1)
+        (sum1, count1 + count2)
+      }.collectAsMap()
+
+      bcCenters.destroy(blocking = false)
+
+      // Update the cluster centers and costs
+      converged = true
+      totalContribs.foreach { case (j, (sum, count)) =>
+        scal(1.0 / count, sum)
+        val newCenter = new VectorWithNorm(sum)
+        if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
+          converged = false
         }
-        costs(run) = costAccums(i).value
+        centers(j) = newCenter
       }
 
-      activeRuns = activeRuns.filter(active(_))
+      cost = costAccum.value
       iteration += 1
     }
 
     val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
-    logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
+    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")
 
     if (iteration == maxIterations) {
       logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
@@ -348,59 +314,43 @@ class KMeans private (
       logInfo(s"KMeans converged in $iteration iterations.")
     }
 
-    val (minCost, bestRun) = costs.zipWithIndex.min
+    logInfo(s"The cost is $cost.")
 
-    logInfo(s"The cost for the best run is $minCost.")
-
-    new KMeansModel(centers(bestRun).map(_.vector))
+    new KMeansModel(centers.map(_.vector))
   }
 
   /**
-   * Initialize `runs` sets of cluster centers at random.
+   * Initialize a set of cluster centers at random.
    */
-  private def initRandom(data: RDD[VectorWithNorm])
-  : Array[Array[VectorWithNorm]] = {
-    // Sample all the cluster centers in one pass to avoid repeated scans
-    val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
-    Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
-      new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
-    }.toArray)
+  private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
+    data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense)
   }
 
   /**
-   * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+   * Initialize a set of cluster centers using the k-means|| algorithm by Bahmani et al.
    * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
-   * to find with dissimilar cluster centers by starting with a random center and then doing
+   * to find dissimilar cluster centers by starting with a random center and then doing
    * passes where more centers are chosen with probability proportional to their squared distance
    * to the current cluster set. It results in a provable approximation to an optimal clustering.
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private def initKMeansParallel(data: RDD[VectorWithNorm])
-  : Array[Array[VectorWithNorm]] = {
+  private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
     // Initialize empty centers and point costs.
-    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
-    var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
+    var costs = data.map(_ => Double.PositiveInfinity)
 
-    // Initialize each run's first center to a random point.
+    // Initialize the first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
-    val sample = data.takeSample(true, runs, seed).toSeq
+    val sample = data.takeSample(false, 1, seed)
     // Could be empty if data is empty; fail with a better message early:
-    require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
-    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
-
-    /** Merges new centers to centers. */
-    def mergeNewCenters(): Unit = {
-      var r = 0
-      while (r < runs) {
-        centers(r) ++= newCenters(r)
-        newCenters(r).clear()
-        r += 1
-      }
-    }
+    require(sample.nonEmpty, s"No samples available from $data")
+
+    val centers = ArrayBuffer[VectorWithNorm]()
+    var newCenters = Seq(sample.head.toDense)
+    centers ++= newCenters
 
-    // On each step, sample 2 * k points on average for each run with probability proportional
-    // to their squared distance from that run's centers. Note that only distances between points
+    // On each step, sample 2 * k points on average with probability proportional
+    // to their squared distance from the centers. Note that only distances between points
     // and new centers are computed in each iteration.
     var step = 0
     var bcNewCentersList = ArrayBuffer[Broadcast[_]]()
@@ -409,74 +359,39 @@ class KMeans private (
       bcNewCentersList += bcNewCenters
       val preCosts = costs
       costs = data.zip(preCosts).map { case (point, cost) =>
-          Array.tabulate(runs) { r =>
-            math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
-          }
-        }.persist(StorageLevel.MEMORY_AND_DISK)
-      val sumCosts = costs
-        .aggregate(new Array[Double](runs))(
-          seqOp = (s, v) => {
-            // s += v
-            var r = 0
-            while (r < runs) {
-              s(r) += v(r)
-              r += 1
-            }
-            s
-          },
-          combOp = (s0, s1) => {
-            // s0 += s1
-            var r = 0
-            while (r < runs) {
-              s0(r) += s1(r)
-              r += 1
-            }
-            s0
-          }
-        )
+        math.min(KMeans.pointCost(bcNewCenters.value, point), cost)
+      }.persist(StorageLevel.MEMORY_AND_DISK)
+      val sumCosts = costs.sum()
 
       bcNewCenters.unpersist(blocking = false)
       preCosts.unpersist(blocking = false)
 
-      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
+      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointCosts) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        pointsWithCosts.flatMap { case (p, c) =>
-          val rs = (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
-          }
-          if (rs.nonEmpty) Some((p, rs)) else None
-        }
+        pointCosts.filter { case (_, c) => rand.nextDouble() < 2.0 * c * k / sumCosts }.map(_._1)
       }.collect()
-      mergeNewCenters()
-      chosen.foreach { case (p, rs) =>
-        rs.foreach(newCenters(_) += p.toDense)
-      }
+      newCenters = chosen.map(_.toDense)
+      centers ++= newCenters
       step += 1
     }
 
-    mergeNewCenters()
     costs.unpersist(blocking = false)
     bcNewCentersList.foreach(_.destroy(false))
 
-    // Finally, we might have a set of more than k candidate centers for each run; weigh each
-    // candidate by the number of points in the dataset mapping to it and run a local k-means++
-    // on the weighted centers to pick just k of them
-    val bcCenters = data.context.broadcast(centers)
-    val weightMap = data.flatMap { p =>
-      Iterator.tabulate(runs) { r =>
-        ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
-      }
-    }.reduceByKey(_ + _).collectAsMap()
+    if (centers.size == k) {
+      centers.toArray
+    } else {
+      // Finally, we might have a set of more or less than k candidate centers; weight each
+      // candidate by the number of points in the dataset mapping to it and run a local k-means++
+      // on the weighted centers to pick k of them
+      val bcCenters = data.context.broadcast(centers)
+      val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue()
 
-    bcCenters.destroy(blocking = false)
+      bcCenters.destroy(blocking = false)
 
-    val finalCenters = (0 until runs).par.map { r =>
-      val myCenters = centers(r).toArray
-      val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
-      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
+      val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
+      LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30)
     }
-
-    finalCenters.toArray
   }
 }
 
@@ -493,6 +408,52 @@ object KMeans {
   @Since("0.8.0")
   val K_MEANS_PARALLEL = "k-means||"
 
+  /**
+   * Trains a k-means model using the given set of parameters.
+   *
+   * @param data Training points as an `RDD` of `Vector` types.
+   * @param k Number of clusters to create.
+   * @param maxIterations Maximum number of iterations allowed.
+   * @param initializationMode The initialization algorithm. This can either be "random" or
+   *                           "k-means||". (default: "k-means||")
+   * @param seed Random seed for cluster initialization. Default is to generate seed based
+   *             on system time.
+   */
+  @Since("2.1.0")
+  def train(
+      data: RDD[Vector],
+      k: Int,
+      maxIterations: Int,
+      initializationMode: String,
+      seed: Long): KMeansModel = {
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .setInitializationMode(initializationMode)
+      .setSeed(seed)
+      .run(data)
+  }
+
+  /**
+   * Trains a k-means model using the given set of parameters.
+   *
+   * @param data Training points as an `RDD` of `Vector` types.
+   * @param k Number of clusters to create.
+   * @param maxIterations Maximum number of iterations allowed.
+   * @param initializationMode The initialization algorithm. This can either be "random" or
+   *                           "k-means||". (default: "k-means||")
+   */
+  @Since("2.1.0")
+  def train(
+      data: RDD[Vector],
+      k: Int,
+      maxIterations: Int,
+      initializationMode: String): KMeansModel = {
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .setInitializationMode(initializationMode)
+      .run(data)
+  }
+
   /**
    * Trains a k-means model using the given set of parameters.
    *
@@ -506,6 +467,7 @@ object KMeans {
    *             on system time.
    */
   @Since("1.3.0")
+  @deprecated("Use train method without 'runs'", "2.1.0")
   def train(
       data: RDD[Vector],
       k: Int,
@@ -531,6 +493,7 @@ object KMeans {
    *                           "k-means||". (default: "k-means||")
    */
   @Since("0.8.0")
+  @deprecated("Use train method without 'runs'", "2.1.0")
   def train(
       data: RDD[Vector],
       k: Int,
@@ -551,19 +514,24 @@ object KMeans {
       data: RDD[Vector],
       k: Int,
       maxIterations: Int): KMeansModel = {
-    train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .run(data)
   }
 
   /**
    * Trains a k-means model using specified parameters and the default values for unspecified.
    */
   @Since("0.8.0")
+  @deprecated("Use train method without 'runs'", "2.1.0")
   def train(
       data: RDD[Vector],
       k: Int,
       maxIterations: Int,
       runs: Int): KMeansModel = {
-    train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .run(data)
   }
 
   /**