Skip to content
Snippets Groups Projects
Commit 9ea201cf authored by Anthony Truchet's avatar Anthony Truchet Committed by Sean Owen
Browse files

[SPARK-16440][MLLIB] Ensure broadcasted variables are destroyed even in case of exception

## What changes were proposed in this pull request?

Ensure broadcasted variables are destroyed even in the case of an exception.
## How was this patch tested?

Word2VecSuite was run locally and passed.

Author: Anthony Truchet <a.truchet@criteo.com>

Closes #14299 from AnthonyTruchet/SPARK-16440.
parent 3f9f9180
No related branches found
No related tags found
No related merge requests found
...@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._ ...@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.mllib.util.{Loader, Saveable}
...@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging { ...@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable()) val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab) val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash) val bcVocabHash = sc.broadcast(vocabHash)
try {
doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
} finally {
expTable.destroy(blocking = false)
bcVocab.destroy(blocking = false)
bcVocabHash.destroy(blocking = false)
}
}
private def doFit[S <: Iterable[String]](
dataset: RDD[S], sc: SparkContext,
expTable: Broadcast[Array[Float]],
bcVocab: Broadcast[Array[VocabWord]],
bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
// each partition is a collection of sentences, // each partition is a collection of sentences,
// will be translated into arrays of Index integer // will be translated into arrays of Index integer
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
...@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging { ...@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging {
bcSyn1Global.destroy(false) bcSyn1Global.destroy(false)
} }
newSentences.unpersist() newSentences.unpersist()
expTable.destroy(false)
bcVocab.destroy(false)
bcVocabHash.destroy(false)
val wordArray = vocab.map(_.word) val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment