Commit bce354c1 authored by WeichenXu, committed by Sean Owen

[SPARK-16696][ML][MLLIB] Destroy KMeans bcNewCenters when the loop finishes, and release unused broadcasts/RDDs at the proper time

## What changes were proposed in this pull request?

Release broadcast variables that are no longer needed in `KMeans` and `Word2Vec`, using `destroy(false)` so their memory is reclaimed promptly.

Also change several existing `destroy()` calls to `destroy(false)`, so destruction runs asynchronously instead of blocking the caller; a background sketch of the `unpersist`/`destroy` semantics follows.
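
For background (not part of the patch), a minimal sketch of the distinction, assuming a local-mode job and hypothetical names: `unpersist` only removes executor-side copies, while `destroy` also releases the driver-side copy, and the non-blocking variants return immediately.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastCleanupSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-cleanup").setMaster("local[2]"))
    val bc = sc.broadcast(Array.fill(1 << 20)(1.0))

    // Use the broadcast in a job.
    val n = sc.parallelize(1 to 100).map(i => i * bc.value(0)).count()
    println(s"count = $n")

    // unpersist: drops the cached copies on executors; the value can still
    // be re-sent if a later task reads bc.value. blocking = false returns
    // immediately and lets cleanup proceed asynchronously.
    bc.unpersist(blocking = false)

    // destroy: additionally frees the driver-side copy; the broadcast must
    // never be used again. Note the destroy(blocking: Boolean) overload the
    // patch calls is private[spark]; user code only has the no-arg destroy().
    bc.destroy()

    sc.stop()
  }
}
```

Switching the MLlib call sites to the non-blocking form avoids stalling the training loop on cleanup.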

Update `bcNewCenters` in `KMeans` so it is destroyed at the correct time: a list stores every `bcNewCenters` broadcast generated across the loop iterations, and their release is delayed until the end of the loop (see the sketch below).
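
A minimal sketch of that deferred-release pattern, with hypothetical names (`pending` mirrors the `bcNewCentersList` the patch adds; the real loop lives in k-means|| initialization):

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast

object DeferredDestroySketch {
  // Each pass broadcasts fresh centers, but the lineage of `costs` keeps
  // referencing every earlier broadcast until an action materializes it,
  // so destruction has to wait until after the loop.
  def run(sc: SparkContext, steps: Int): Unit = {
    val centers = Array(1.0, 2.0, 3.0)            // stand-in for cluster centers
    val pending = ArrayBuffer.empty[Broadcast[_]] // mirrors bcNewCentersList
    var costs = sc.parallelize(Seq.fill(100)(Double.PositiveInfinity))
    var step = 0
    while (step < steps) {
      val bcCenters = sc.broadcast(centers)
      pending += bcCenters                        // remember it; don't destroy yet
      costs = costs.map(c => math.min(c, bcCenters.value.sum))
      step += 1
    }
    costs.count()                // force evaluation (the real code also persists costs)
    pending.foreach(_.destroy()) // only now is every broadcast safe to release
  }
}
```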

Fix the TODO in `BisectingKMeans.run` ("unpersist old indices") by implementing the pattern "persist the current step's RDD, and unpersist the previous one" in the loop iteration, sketched below.
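
Abstractly, that pattern looks like the following sketch, where `next` is a hypothetical stand-in for the real `updateAssignments(assignments, divisibleIndices, newClusterCenters).keys` step:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object PersistLagSketch {
  // The previous step's RDD stays persisted for one extra iteration because
  // the new RDD is derived from it; only the RDD from two steps back is
  // certainly unreferenced and safe to drop.
  def iterate(initial: RDD[Long], steps: Int)(next: RDD[Long] => RDD[Long]): RDD[Long] = {
    var preIndices: RDD[Long] = null
    var indices: RDD[Long] = initial
    var step = 0
    while (step < steps) {
      if (preIndices != null) preIndices.unpersist() // drop the two-steps-back RDD
      preIndices = indices
      indices = next(indices).persist(StorageLevel.MEMORY_AND_DISK)
      indices.count()                                // materialize before the lag advances
      step += 1
    }
    if (preIndices != null) preIndices.unpersist()
    indices                                          // still persisted; caller releases it
  }
}
```

The one-iteration lag matters: unpersisting the immediately previous RDD too eagerly would force Spark to recompute it through the new RDD's lineage.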

## How was this patch tested?

Existing tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #14333 from WeichenXu123/broadvar_unpersist_to_destroy.
parent 0dc4310b
mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -165,6 +165,8 @@ class BisectingKMeans private (
     val random = new Random(seed)
     var numLeafClustersNeeded = k - 1
     var level = 1
+    var preIndices: RDD[Long] = null
+    var indices: RDD[Long] = null
     while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) {
       // Divisible clusters are sufficiently large and have non-trivial cost.
       var divisibleClusters = activeClusters.filter { case (_, summary) =>
@@ -194,8 +196,9 @@
           newClusters = summarize(d, newAssignments)
           newClusterCenters = newClusters.mapValues(_.center).map(identity)
         }
-        // TODO: Unpersist old indices.
-        val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
+        if (preIndices != null) preIndices.unpersist()
+        preIndices = indices
+        indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
           .persist(StorageLevel.MEMORY_AND_DISK)
         assignments = indices.zip(vectors)
         inactiveClusters ++= activeClusters
@@ -208,6 +211,7 @@
       }
       level += 1
     }
+    if(indices != null) indices.unpersist()
     val clusters = activeClusters ++ inactiveClusters
     val root = buildTree(clusters)
     new BisectingKMeansModel(root)
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
 import org.apache.spark.ml.util.Instrumentation
@@ -309,7 +310,7 @@ class KMeans private (
         contribs.iterator
       }.reduceByKey(mergeContribs).collectAsMap()

-      bcActiveCenters.unpersist(blocking = false)
+      bcActiveCenters.destroy(blocking = false)

       // Update the cluster centers and costs for each active run
       for ((run, i) <- activeRuns.zipWithIndex) {
@@ -402,8 +403,10 @@
     // to their squared distance from that run's centers. Note that only distances between points
     // and new centers are computed in each iteration.
     var step = 0
+    var bcNewCentersList = ArrayBuffer[Broadcast[_]]()
     while (step < initializationSteps) {
       val bcNewCenters = data.context.broadcast(newCenters)
+      bcNewCentersList += bcNewCenters
       val preCosts = costs
       costs = data.zip(preCosts).map { case (point, cost) =>
         Array.tabulate(runs) { r =>
@@ -453,6 +456,7 @@
     mergeNewCenters()
     costs.unpersist(blocking = false)
+    bcNewCentersList.foreach(_.destroy(false))

     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
@@ -464,7 +468,7 @@
       }
     }.reduceByKey(_ + _).collectAsMap()

-    bcCenters.unpersist(blocking = false)
+    bcCenters.destroy(blocking = false)

     val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -430,13 +430,13 @@ class Word2Vec extends Serializable with Logging {
         }
         i += 1
       }
-      bcSyn0Global.unpersist(false)
-      bcSyn1Global.unpersist(false)
+      bcSyn0Global.destroy(false)
+      bcSyn1Global.destroy(false)
     }
     newSentences.unpersist()
-    expTable.destroy()
-    bcVocab.destroy()
-    bcVocabHash.destroy()
+    expTable.destroy(false)
+    bcVocab.destroy(false)
+    bcVocabHash.destroy(false)

     val wordArray = vocab.map(_.word)
     new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)