From 4d6d8192c807006ff89488a1d38bc6f7d41de5cf Mon Sep 17 00:00:00 2001
From: dardelet <guillaumegorp@gmail.com>
Date: Tue, 4 Jul 2017 17:58:44 +0100
Subject: [PATCH] [SPARK-21268][MLLIB] Move center calculations to a
 distributed map in KMeans

## What changes were proposed in this pull request?

The scal() call and the creation of the newCenter vector are done on the driver, after a collectAsMap operation, although they could be done in the distributed RDD.
This PR moves that code before the collectAsMap for better efficiency.
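
For illustration, a minimal, self-contained sketch (not the actual KMeans code) of the pattern: the per-cluster sums are divided by their counts inside a distributed `mapValues`, so only the finished centers are collected to the driver. The `SparkContext` handle `sc`, the `averageCenters` helper name, and the plain `Array[Double]` vectors are assumptions made for this example.

```scala
import org.apache.spark.SparkContext

def averageCenters(sc: SparkContext): scala.collection.Map[Int, Array[Double]] = {
  // (clusterId, (sum of assigned points, count of assigned points))
  val contribs = sc.parallelize(Seq(
    (0, (Array(2.0, 4.0), 2L)),
    (1, (Array(9.0, 3.0), 3L))
  ))
  contribs
    .mapValues { case (sum, count) =>
      // averaging happens on the executors, before anything is shipped to the driver
      sum.map(_ / count)
    }
    .collectAsMap() // the driver only receives the finished centers
}
```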

## How was this patch tested?

This was tested manually by running the KMeansExample and verifying that the new code ran without error and gave the same output as before.

Author: dardelet <guillaumegorp@gmail.com>
Author: Guillaume Dardelet <dardelet@users.noreply.github.com>

Closes #18491 from dardelet/move-center-calculation-to-distributed-map-kmean.
---
 .../org/apache/spark/mllib/clustering/KMeans.scala    | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index fa72b72e2d..98e50c5b45 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -272,8 +272,8 @@ class KMeans private (
       val costAccum = sc.doubleAccumulator
       val bcCenters = sc.broadcast(centers)
 
-      // Find the sum and count of points mapping to each center
-      val totalContribs = data.mapPartitions { points =>
+      // Find the new centers
+      val newCenters = data.mapPartitions { points =>
         val thisCenters = bcCenters.value
         val dims = thisCenters.head.vector.size
 
@@ -292,15 +292,16 @@ class KMeans private (
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
+      }.mapValues { case (sum, count) =>
+        scal(1.0 / count, sum)
+        new VectorWithNorm(sum)
       }.collectAsMap()
 
       bcCenters.destroy(blocking = false)
 
       // Update the cluster centers and costs
       converged = true
-      totalContribs.foreach { case (j, (sum, count)) =>
-        scal(1.0 / count, sum)
-        val newCenter = new VectorWithNorm(sum)
+      newCenters.foreach { case (j, newCenter) =>
         if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
           converged = false
         }
-- 
GitLab