From 9ea201cf6482c9c62c9428759d238063db62d66e Mon Sep 17 00:00:00 2001
From: Anthony Truchet <a.truchet@criteo.com>
Date: Wed, 8 Mar 2017 11:44:25 +0000
Subject: [PATCH] [SPARK-16440][MLLIB] Ensure broadcasted variables are
 destroyed even in case of exception

## What changes were proposed in this pull request?

Ensure broadcasted variable are destroyed even in case of exception
## How was this patch tested?

Word2VecSuite was run locally

Author: Anthony Truchet <a.truchet@criteo.com>

Closes #14299 from AnthonyTruchet/SPARK-16440.
---
 .../apache/spark/mllib/feature/Word2Vec.scala  | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2364d43aaa..531c8b0791 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging {
     val expTable = sc.broadcast(createExpTable())
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
+    try {
+      doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
+    } finally {
+      expTable.destroy(blocking = false)
+      bcVocab.destroy(blocking = false)
+      bcVocabHash.destroy(blocking = false)
+    }
+  }
+
+  private def doFit[S <: Iterable[String]](
+    dataset: RDD[S], sc: SparkContext,
+    expTable: Broadcast[Array[Float]],
+    bcVocab: Broadcast[Array[VocabWord]],
+    bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
     // each partition is a collection of sentences,
     // will be translated into arrays of Index integer
     val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging {
       bcSyn1Global.destroy(false)
     }
     newSentences.unpersist()
-    expTable.destroy(false)
-    bcVocab.destroy(false)
-    bcVocabHash.destroy(false)
 
     val wordArray = vocab.map(_.word)
     new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
-- 
GitLab