diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2364d43aaa0e2e1f91e6a63ed16d69cb275f6eeb..531c8b07910fc8ebce7a7359b32eea4e30b485df 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -314,6 +315,24 @@ class Word2Vec extends Serializable with Logging {
     val expTable = sc.broadcast(createExpTable())
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
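+    // Destroy the broadcast variables in a finally block so they are
+    // released even if training fails.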
+    try {
+      doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
+    } finally {
+      expTable.destroy(blocking = false)
+      bcVocab.destroy(blocking = false)
+      bcVocabHash.destroy(blocking = false)
+    }
+  }
+
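+  // Performs the actual training; the broadcast variables are created and
+  // destroyed by the caller, fit().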
+  private def doFit[S <: Iterable[String]](
+      dataset: RDD[S], sc: SparkContext,
+      expTable: Broadcast[Array[Float]],
+      bcVocab: Broadcast[Array[VocabWord]],
+      bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
     // each partition is a collection of sentences,
     // will be translated into arrays of Index integer
     val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
@@ -435,9 +454,6 @@ class Word2Vec extends Serializable with Logging {
       bcSyn1Global.destroy(false)
     }
     newSentences.unpersist()
-    expTable.destroy(false)
-    bcVocab.destroy(false)
-    bcVocabHash.destroy(false)
 
     val wordArray = vocab.map(_.word)
     new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)