diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6af90d7287ff8c63f4ad4fcce4ca169ee05b6969..33babda69bbb973174e022493a3d5a43f6cad3aa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -27,6 +27,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
 import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -217,26 +218,28 @@ class LocalLDAModel private[clustering] (
   // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
   /**
    * Calculates a lower bound on the log likelihood of the entire corpus.
+   *
+   * See Equation (16) in original Online LDA paper.
+   *
    * @param documents test corpus to use for calculating log likelihood
    * @return variational lower bound on the log likelihood of the entire corpus
    */
-  def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents,
+  def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents,
     docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k,
     vocabSize)
 
   /**
-   * Calculate an upper bound bound on perplexity. See Equation (16) in original Online
-   * LDA paper.
+   * Calculate an upper bound bound on perplexity.  (Lower is better.)
+   * See Equation (16) in original Online LDA paper.
+   *
    * @param documents test corpus to use for calculating perplexity
-   * @return variational upper bound on log perplexity per word
+   * @return Variational upper bound on log perplexity per token.
    */
   def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
-    val corpusWords = documents
+    val corpusTokenCount = documents
       .map { case (_, termCounts) => termCounts.toArray.sum }
       .sum()
-    val perWordBound = -logLikelihood(documents) / corpusWords
-
-    perWordBound
+    -logLikelihood(documents) / corpusTokenCount
   }
 
   /**
@@ -244,17 +247,20 @@ class LocalLDAModel private[clustering] (
    *    log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
    * This bound is derived by decomposing the LDA model to:
    *    log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
-   * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper.
+   * and noting that the KL-divergence D(q|p) >= 0.
+   *
+   * See Equation (16) in original Online LDA paper, as well as Appendix A.3 in the JMLR version of
+   * the original LDA paper.
    * @param documents a subset of the test corpus
    * @param alpha document-topic Dirichlet prior parameters
-   * @param eta topic-word Dirichlet prior parameters
+   * @param eta topic-word Dirichlet prior parameter
    * @param lambda parameters for variational q(beta | lambda) topic-word distributions
    * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
    *                   topic mixture distributions
    * @param k number of topics
    * @param vocabSize number of unique terms in the entire test corpus
    */
-  private def bound(
+  private def logLikelihoodBound(
       documents: RDD[(Long, Vector)],
       alpha: Vector,
       eta: Double,
@@ -266,33 +272,38 @@ class LocalLDAModel private[clustering] (
     // transpose because dirichletExpectation normalizes by row and we need to normalize
     // by topic (columns of lambda)
     val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
+    val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+
+    // Sum bound components for each document:
+    //  component for prob(tokens) + component for prob(document-topic distribution)
+    val corpusPart =
+      documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
+        val localElogbeta = ElogbetaBc.value
+        var docBound = 0.0D
+        val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
+          termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+        val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
+
+        // E[log p(doc | theta, beta)]
+        termCounts.foreachActive { case (idx, count) =>
+          docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t)
+        }
+        // E[log p(theta | alpha) - log q(theta | gamma)]
+        docBound += sum((brzAlpha - gammad) :* Elogthetad)
+        docBound += sum(lgamma(gammad) - lgamma(brzAlpha))
+        docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
 
-    var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
-      var docScore = 0.0D
-      val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
-        termCounts, exp(Elogbeta), brzAlpha, gammaShape, k)
-      val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
-
-      // E[log p(doc | theta, beta)]
-      termCounts.foreachActive { case (idx, count) =>
-        docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t)
-      }
-      // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
-      docScore += sum((brzAlpha - gammad) :* Elogthetad)
-      docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
-      docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
-
-      docScore
-    }.sum()
-
-    // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
-    score += sum((eta - lambda) :* Elogbeta)
-    score += sum(lgamma(lambda) - lgamma(eta))
+        docBound
+      }.sum()
 
+    // Bound component for prob(topic-term distributions):
+    //   E[log p(beta | eta) - log q(beta | lambda)]
     val sumEta = eta * vocabSize
-    score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
+    val topicsPart = sum((eta - lambda) :* Elogbeta) +
+      sum(lgamma(lambda) - lgamma(eta)) +
+      sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
 
-    score
+    corpusPart + topicsPart
   }
 
   /**
@@ -310,6 +321,7 @@ class LocalLDAModel private[clustering] (
     // Double transpose because dirichletExpectation normalizes by row and we need to normalize
     // by topic (columns of lambda)
     val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+    val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta)
     val docConcentrationBrz = this.docConcentration.toBreeze
     val gammaShape = this.gammaShape
     val k = this.k
@@ -320,7 +332,7 @@ class LocalLDAModel private[clustering] (
       } else {
         val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
           termCounts,
-          expElogbeta,
+          expElogbetaBc.value,
           docConcentrationBrz,
           gammaShape,
           k)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index b0e14cb8296a6aae903166bba96962748e267ea8..afba2866c7040a4a7b588249d2a6f51207963fc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -419,6 +419,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     val k = this.k
     val vocabSize = this.vocabSize
     val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
+    val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
     val alpha = this.alpha.toBreeze
     val gammaShape = this.gammaShape
 
@@ -433,13 +434,14 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
           case v: SparseVector => v.indices.toList
         }
         val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
-          termCounts, expElogbeta, alpha, gammaShape, k)
+          termCounts, expElogbetaBc.value, alpha, gammaShape, k)
         stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
         gammaPart = gammad :: gammaPart
       }
       Iterator((stat, gammaPart))
     }
     val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+    expElogbetaBc.unpersist()
     val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
       stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
     val batchResult = statsSum :* expElogbeta.t
@@ -540,21 +542,22 @@ private[clustering] object OnlineLDAOptimizer {
     val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
     val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K
 
-    val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
-    var meanchange = 1D
+    val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
+    var meanGammaChange = 1D
     val ctsVector = new BDV[Double](cts)                                         // ids
 
     // Iterate between gamma and phi until convergence
-    while (meanchange > 1e-3) {
+    while (meanGammaChange > 1e-3) {
       val lastgamma = gammad.copy
       //        K                  K * ids               ids
-      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
       expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
-      phinorm := expElogbetad * expElogthetad :+ 1e-100
-      meanchange = sum(abs(gammad - lastgamma)) / k
+      // TODO: Keep more values in log space, and only exponentiate when needed.
+      phiNorm := expElogbetad * expElogthetad :+ 1e-100
+      meanGammaChange = sum(abs(gammad - lastgamma)) / k
     }
 
-    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
     (gammad, sstatsd)
   }
 }