From f0f563a3c43fc9683e6920890cce44611c0c5f4b Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Sun, 30 Aug 2015 23:20:03 -0700
Subject: [PATCH] [SPARK-10354] [MLLIB] fix some apparent memory issues in
 k-means|| initialization

* do not cache first cost RDD
* change following cost RDD cache level to MEMORY_AND_DISK
* remove Vector wrapper to save an object per instance (see the sketch below)
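
A minimal, self-contained sketch of the new pattern, not taken verbatim from
KMeans.scala: sc is an existing SparkContext and runs is a small constant
chosen here only for illustration. It shows per-point costs kept as plain
Array[Double], the derived cost RDD persisted with MEMORY_AND_DISK, and the
axpy-on-Vector accumulation replaced by a while-loop sum inside aggregate:

    import org.apache.spark.storage.StorageLevel

    val runs = 3
    // Derived cost RDD: persist with MEMORY_AND_DISK so large cost arrays can
    // spill to disk instead of being evicted and recomputed.
    val costs = sc.parallelize(Seq.fill(1000)(Array.fill(runs)(1.0)))
      .persist(StorageLevel.MEMORY_AND_DISK)

    // Summing plain Array[Double] values replaces axpy on a Vector wrapper.
    val sumCosts = costs.aggregate(new Array[Double](runs))(
      seqOp = (s, v) => {
        var r = 0
        while (r < runs) { s(r) += v(r); r += 1 }
        s
      },
      combOp = (s0, s1) => {
        var r = 0
        while (r < runs) { s0(r) += s1(r); r += 1 }
        s0
      })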

Further improvements will be addressed in SPARK-10329

cc: yu-iskw HuJiayin

Author: Xiangrui Meng <meng@databricks.com>

Closes #8526 from mengxr/SPARK-10354.
---
 .../spark/mllib/clustering/KMeans.scala       | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 46920fffe6..7168aac32c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -369,7 +369,7 @@ class KMeans private (
   : Array[Array[VectorWithNorm]] = {
     // Initialize empty centers and point costs.
     val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
-    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+    var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
 
     // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
@@ -394,21 +394,28 @@ class KMeans private (
       val bcNewCenters = data.context.broadcast(newCenters)
       val preCosts = costs
       costs = data.zip(preCosts).map { case (point, cost) =>
-        Vectors.dense(
           Array.tabulate(runs) { r =>
             math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
-          })
-      }.cache()
+          }
+        }.persist(StorageLevel.MEMORY_AND_DISK)
       val sumCosts = costs
-        .aggregate(Vectors.zeros(runs))(
+        .aggregate(new Array[Double](runs))(
           seqOp = (s, v) => {
             // s += v
-            axpy(1.0, v, s)
+            var r = 0
+            while (r < runs) {
+              s(r) += v(r)
+              r += 1
+            }
             s
           },
           combOp = (s0, s1) => {
             // s0 += s1
-            axpy(1.0, s1, s0)
+            var r = 0
+            while (r < runs) {
+              s0(r) += s1(r)
+              r += 1
+            }
             s0
           }
         )
-- 
GitLab