From 2cc212d56a1d50fe68d5816f71b27803de1f6389 Mon Sep 17 00:00:00 2001
From: Feynman Liang <fliang@databricks.com>
Date: Wed, 29 Jul 2015 16:20:20 -0700
Subject: [PATCH] [SPARK-6793] [MLLIB] OnlineLDAOptimizer LDA perplexity

Implements `logPerplexity` in `OnlineLDAOptimizer`. Also refactors the inference code into a companion object to enable future reuse (e.g. a `predict` method).
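
A minimal usage sketch (assumes an active SparkContext `sc`; the corpus and parameter values are illustrative, not part of this patch):

    import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    // Toy corpus of (docId, termCounts) pairs over a 6-term vocabulary
    val docs = sc.parallelize(Seq(
      (0L, Vectors.sparse(6, Array(0, 1), Array(1.0, 1.0))),
      (1L, Vectors.sparse(6, Array(3, 5), Array(1.0, 1.0)))))

    // OnlineLDAOptimizer yields a LocalLDAModel, which now exposes logPerplexity
    val model = new LDA()
      .setK(2)
      .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(1.0))
      .run(docs)
      .asInstanceOf[LocalLDAModel]

    // Per-word log variational bound on the (here, training) documents
    val bound = model.logPerplexity(docs)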

Author: Feynman Liang <fliang@databricks.com>

Closes #7705 from feynmanliang/SPARK-6793-perplexity and squashes the following commits:

6da2c99 [Feynman Liang] Remove get* from LDAModel public API
8381da6 [Feynman Liang] Code review comments
17f7000 [Feynman Liang] Documentation typo fixes
2f452a4 [Feynman Liang] Remove auxiliary DistributedLDAModel constructor
a275914 [Feynman Liang] Prevent empty counts calls to variationalInference
06d02d9 [Feynman Liang] Remove deprecated LocalLDAModel constructor
afecb46 [Feynman Liang] Fix regression bug in sstats accumulator
5a327a0 [Feynman Liang] Code review quick fixes
998c03e [Feynman Liang] Fix style
1cbb67d [Feynman Liang] Fix access modifier bug
4362daa [Feynman Liang] Organize imports
4f171f7 [Feynman Liang] Fix indentation
2f049ce [Feynman Liang] Fix failing save/load tests
7415e96 [Feynman Liang] Pick changes from big PR
11e7c33 [Feynman Liang] Merge remote-tracking branch 'apache/master' into SPARK-6793-perplexity
f8adc48 [Feynman Liang] Add logPerplexity, refactor variationalBound into a method
cd521d6 [Feynman Liang] Refactor methods into companion class
7f62a55 [Feynman Liang] --amend
c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing
aead650 [Feynman Liang] Range slice, in-place update, reduce transposes
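
Note on semantics: like gensim's `log_perplexity`, the value returned by the new `logPerplexity` is the per-word variational lower bound on the log likelihood (in nats), not the perplexity itself; a conventional perplexity estimate is the exponential of its negation. With the hypothetical `model` and `docs` from the sketch above:

    // a per-word bound of -3.69 (as in the new gensim-checked test) gives perplexity exp(3.69) ~= 40
    val perplexity = math.exp(-model.logPerplexity(docs))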
---
 .../spark/mllib/clustering/LDAModel.scala     | 200 ++++++++++++++----
 .../spark/mllib/clustering/LDAOptimizer.scala | 138 +++++++-----
 .../spark/mllib/clustering/LDAUtils.scala     |  55 +++++
 .../spark/mllib/clustering/JavaLDASuite.java  |   6 +-
 .../spark/mllib/clustering/LDASuite.scala     |  53 ++++-
 5 files changed, 348 insertions(+), 104 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
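
For reviewers, a self-contained Breeze sketch of the fixed-point iteration performed by the extracted `OnlineLDAOptimizer.variationalTopicInference` (names and the constant initialization of `gammad` are illustrative; the patch initializes it by sampling from Gamma(gammaShape, 1/gammaShape)):

    import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
    import breeze.numerics.{abs, digamma, exp}

    // E-step for one document: iterate gamma (with phi kept implicit) until the
    // mean absolute change in gamma falls below 1e-3.
    //   expElogbetad: exp(E[log beta]) restricted to the document's term ids (ids x K)
    //   cts: the corresponding term counts (ids)
    //   alpha: document-topic Dirichlet prior (K)
    def inferGamma(expElogbetad: BDM[Double], cts: BDV[Double], alpha: BDV[Double]): BDV[Double] = {
      val k = alpha.length
      val gammad = BDV.fill(k)(1.0) // illustrative init; the patch samples Gamma(gammaShape, 1/gammaShape)
      var meanchange = 1.0
      while (meanchange > 1e-3) {
        val lastgamma = gammad.copy
        val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))  // K
        val phinorm = (expElogbetad * expElogthetad) :+ 1e-100           // ids
        gammad := (expElogthetad :* (expElogbetad.t * (cts :/ phinorm))) :+ alpha
        meanchange = sum(abs(gammad - lastgamma)) / k
      }
      gammad
    }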

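A companion sketch of the online update now factored into `updateLambda` and `rho`: each mini-batch's sufficient statistics are scaled up to the corpus size and blended into lambda with a learning rate that decays across iterations (parameter names mirror the optimizer's getters; the example values come from the gensim check in the new test):

    import breeze.linalg.{DenseMatrix => BDM}

    // rho_t = (tau0 + t)^(-kappa) decays with the iteration count t; lambda is updated in place.
    def updateLambda(lambda: BDM[Double], stat: BDM[Double], eta: Double, iteration: Int,
        tau0: Double, kappa: Double, corpusSize: Long, batchSize: Int): Unit = {
      val rho = math.pow(tau0 + iteration, -kappa)
      lambda := (1 - rho) * lambda + rho * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
    }

    // e.g. with tau0 = 1024 and kappa = 0.51, rho at iteration 1 is (1024 + 1)^(-0.51) ~= 0.029
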
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 31c1d520fd..059b52ef20 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
-
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{exp, lgamma}
 import org.apache.hadoop.fs.Path
-
 import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
@@ -28,14 +27,13 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
-import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
-import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
-import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
+import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.util.BoundedPriorityQueue
 
-
 /**
  * :: Experimental ::
  *
@@ -53,6 +51,31 @@ abstract class LDAModel private[clustering] extends Saveable {
   /** Vocabulary size (number of unique terms in the vocabulary) */
   def vocabSize: Int
 
+  /**
+   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+   * distributions over topics ("theta").
+   *
+   * This is the parameter to a Dirichlet distribution.
+   */
+  def docConcentration: Vector
+
+  /**
+   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+   * distributions over terms.
+   *
+   * This is the parameter to a symmetric Dirichlet distribution.
+   *
+   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+   */
+  def topicConcentration: Double
+
+  /**
+   * Shape parameter for random initialization of variational parameter gamma.
+   * Used for variational inference for perplexity and other test-time computations.
+   */
+  protected def gammaShape: Double
+
   /**
    * Inferred topics, where each topic is represented by a distribution over terms.
    * This is a matrix of size vocabSize x k, where each column is a topic.
@@ -168,7 +191,10 @@ abstract class LDAModel private[clustering] extends Saveable {
  */
 @Experimental
 class LocalLDAModel private[clustering] (
-    private val topics: Matrix) extends LDAModel with Serializable {
+    val topics: Matrix,
+    override val docConcentration: Vector,
+    override val topicConcentration: Double,
+    override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
 
   override def k: Int = topics.numCols
 
@@ -197,8 +223,82 @@ class LocalLDAModel private[clustering] (
   // TODO:
   // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
 
+  /**
+   * Calculate the log variational bound on perplexity. See Equation (16) in the original
+   * Online LDA paper (Hoffman et al., 2010).
+   * @param documents test corpus on which to compute the variational bound
+   * @return the variational bound on the log likelihood of `documents`, normalized per word
+   */
+  def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
+    val corpusWords = documents
+      .map { case (_, termCounts) => termCounts.toArray.sum }
+      .sum()
+    val batchVariationalBound = bound(documents, docConcentration,
+      topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize)
+    val perWordBound = batchVariationalBound / corpusWords
+
+    perWordBound
+  }
+
+  /**
+   * Estimate the variational likelihood bound from `documents`:
+   *    log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
+   * This bound is derived by decomposing the log likelihood as:
+   *    log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q||p)
+   * and noting that the KL divergence D(q||p) >= 0. See Equation (16) in the original Online LDA paper.
+   * @param documents a subset of the test corpus
+   * @param alpha document-topic Dirichlet prior parameters
+   * @param eta topic-word Dirichlet prior parameters
+   * @param lambda parameters for variational q(beta | lambda) topic-word distributions
+   * @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
+   *                   topic mixture distributions
+   * @param k number of topics
+   * @param vocabSize number of unique terms in the vocabulary
+   */
+  private def bound(
+      documents: RDD[(Long, Vector)],
+      alpha: Vector,
+      eta: Double,
+      lambda: BDM[Double],
+      gammaShape: Double,
+      k: Int,
+      vocabSize: Long): Double = {
+    val brzAlpha = alpha.toBreeze.toDenseVector
+    // transpose because dirichletExpectation normalizes by row and we need to normalize
+    // by topic (columns of lambda)
+    val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
+
+    var score = documents.filter(_._2.numActives > 0).map { case (id: Long, termCounts: Vector) =>
+      var docScore = 0.0D
+      val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
+        termCounts, exp(Elogbeta), brzAlpha, gammaShape, k)
+      val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
+
+      // E[log p(doc | theta, beta)]
+      termCounts.foreachActive { case (idx, count) =>
+        docScore += LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t)
+      }
+      // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
+      docScore += sum((brzAlpha - gammad) :* Elogthetad)
+      docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
+      docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
+
+      docScore
+    }.sum()
+
+    // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
+    score += sum((eta - lambda) :* Elogbeta)
+    score += sum(lgamma(lambda) - lgamma(eta))
+
+    val sumEta = eta * vocabSize
+    score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
+
+    score
+  }
+
 }
 
+
 @Experimental
 object LocalLDAModel extends Loader[LocalLDAModel] {
 
@@ -212,6 +312,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
     // as a Row in data.
     case class Data(topic: Vector, index: Int)
 
+    // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in
+    // model.predict()
     def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
@@ -219,7 +321,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       val k = topicsMatrix.numCols
       val metadata = compact(render
         (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
-         ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+          ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
@@ -243,7 +345,11 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       topics.foreach { case Row(vec: Vector, ind: Int) =>
         brzTopics(::, ind) := vec.toBreeze
       }
-      new LocalLDAModel(Matrices.fromBreeze(brzTopics))
+      val topicsMat = Matrices.fromBreeze(brzTopics)
+
+      // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940
+      new LocalLDAModel(topicsMat,
+        Vectors.dense(Array.fill(topicsMat.numCols)(1.0 / topicsMat.numCols)), 1D, 100D)
     }
   }
 
@@ -259,8 +365,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
         SaveLoadV1_0.load(sc, path)
       case _ => throw new Exception(
         s"LocalLDAModel.load did not recognize model with (className, format version):" +
-        s"($loadedClassName, $loadedVersion).  Supported:\n" +
-        s"  ($classNameV1_0, 1.0)")
+          s"($loadedClassName, $loadedVersion).  Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
     }
 
     val topicsMatrix = model.topicsMatrix
@@ -268,7 +374,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
     require(expectedVocabSize == topicsMatrix.numRows,
       s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
-      s"but got ${topicsMatrix.numRows}")
+        s"but got ${topicsMatrix.numRows}")
     model
   }
 }
@@ -282,28 +388,25 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
  * than the [[LocalLDAModel]].
  */
 @Experimental
-class DistributedLDAModel private (
+class DistributedLDAModel private[clustering] (
     private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
     private[clustering] val globalTopicTotals: LDA.TopicCounts,
     val k: Int,
     val vocabSize: Int,
-    private[clustering] val docConcentration: Double,
-    private[clustering] val topicConcentration: Double,
+    override val docConcentration: Vector,
+    override val topicConcentration: Double,
+    override protected[clustering] val gammaShape: Double,
     private[spark] val iterationTimes: Array[Double]) extends LDAModel {
 
   import LDA._
 
-  private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
-    this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
-      state.topicConcentration, iterationTimes)
-  }
-
   /**
    * Convert model to a local model.
    * The local model stores the inferred topics but not the topic distributions for training
    * documents.
    */
-  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
+  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+    gammaShape)
 
   /**
    * Inferred topics, where each topic is represented by a distribution over terms.
@@ -375,8 +478,9 @@ class DistributedLDAModel private (
    *    hyperparameters.
    */
   lazy val logLikelihood: Double = {
-    val eta = topicConcentration
-    val alpha = docConcentration
+    // TODO: generalize this for asymmetric (non-scalar) alpha
+    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
+    val eta = this.topicConcentration
     assert(eta > 1.0)
     assert(alpha > 1.0)
     val N_k = globalTopicTotals
@@ -400,8 +504,9 @@ class DistributedLDAModel private (
    *  log P(topics, topic distributions for docs | alpha, eta)
    */
   lazy val logPrior: Double = {
-    val eta = topicConcentration
-    val alpha = docConcentration
+    // TODO: generalize this for asymmetric (non-scalar) alpha
+    val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object
+    val eta = this.topicConcentration
     // Term vertices: Compute phi_{wk}.  Use to compute prior log probability.
     // Doc vertex: Compute theta_{kj}.  Use to compute prior log probability.
     val N_k = globalTopicTotals
@@ -412,12 +517,12 @@ class DistributedLDAModel private (
           val N_wk = vertex._2
           val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
           val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
-          (eta - 1.0) * brzSum(phi_wk.map(math.log))
+          (eta - 1.0) * sum(phi_wk.map(math.log))
         } else {
           val N_kj = vertex._2
           val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
           val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
-          (alpha - 1.0) * brzSum(theta_kj.map(math.log))
+          (alpha - 1.0) * sum(theta_kj.map(math.log))
         }
     }
     graph.vertices.aggregate(0.0)(seqOp, _ + _)
@@ -448,7 +553,7 @@ class DistributedLDAModel private (
   override def save(sc: SparkContext, path: String): Unit = {
     DistributedLDAModel.SaveLoadV1_0.save(
       sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
-      iterationTimes)
+      iterationTimes, gammaShape)
   }
 }
 
@@ -478,17 +583,20 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
         globalTopicTotals: LDA.TopicCounts,
         k: Int,
         vocabSize: Int,
-        docConcentration: Double,
+        docConcentration: Vector,
         topicConcentration: Double,
-        iterationTimes: Array[Double]): Unit = {
+        iterationTimes: Array[Double],
+        gammaShape: Double): Unit = {
       val sqlContext = SQLContext.getOrCreate(sc)
       import sqlContext.implicits._
 
       val metadata = compact(render
         (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
-         ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
-         ("topicConcentration" -> topicConcentration) ~
-         ("iterationTimes" -> iterationTimes.toSeq)))
+          ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
+          ("docConcentration" -> docConcentration.toArray.toSeq) ~
+          ("topicConcentration" -> topicConcentration) ~
+          ("iterationTimes" -> iterationTimes.toSeq) ~
+          ("gammaShape" -> gammaShape)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
@@ -510,9 +618,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
         sc: SparkContext,
         path: String,
         vocabSize: Int,
-        docConcentration: Double,
+        docConcentration: Vector,
         topicConcentration: Double,
-        iterationTimes: Array[Double]): DistributedLDAModel = {
+        iterationTimes: Array[Double],
+        gammaShape: Double): DistributedLDAModel = {
       val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
       val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
       val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
@@ -536,7 +645,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
       val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
 
       new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
-        docConcentration, topicConcentration, iterationTimes)
+        docConcentration, topicConcentration, gammaShape, iterationTimes)
     }
 
   }
@@ -546,32 +655,35 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     implicit val formats = DefaultFormats
     val expectedK = (metadata \ "k").extract[Int]
     val vocabSize = (metadata \ "vocabSize").extract[Int]
-    val docConcentration = (metadata \ "docConcentration").extract[Double]
+    val docConcentration =
+      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
     val topicConcentration = (metadata \ "topicConcentration").extract[Double]
     val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
+    val gammaShape = (metadata \ "gammaShape").extract[Double]
     val classNameV1_0 = SaveLoadV1_0.classNameV1_0
 
     val model = (loadedClassName, loadedVersion) match {
       case (className, "1.0") if className == classNameV1_0 => {
-        DistributedLDAModel.SaveLoadV1_0.load(
-          sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
+        DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
+          topicConcentration, iterationTimes.toArray, gammaShape)
       }
       case _ => throw new Exception(
         s"DistributedLDAModel.load did not recognize model with (className, format version):" +
-        s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
+          s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
     }
 
     require(model.vocabSize == vocabSize,
       s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
     require(model.docConcentration == docConcentration,
       s"DistributedLDAModel requires $docConcentration docConcentration, " +
-      s"got ${model.docConcentration} docConcentration")
+        s"got ${model.docConcentration} docConcentration")
     require(model.topicConcentration == topicConcentration,
       s"DistributedLDAModel requires $topicConcentration docConcentration, " +
-      s"got ${model.topicConcentration} docConcentration")
+        s"got ${model.topicConcentration} docConcentration")
     require(expectedK == model.k,
       s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
     model
   }
 
 }
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index f4170a3d98..7e75e7083a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
 import java.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, digamma, exp}
+import breeze.numerics.{abs, exp}
 import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.DeveloperApi
@@ -208,7 +208,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
   override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
     require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
     this.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(this, iterationTimes)
+    // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
+    // conversion
+    new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
+      Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
+      100, iterationTimes)
   }
 }
 
@@ -385,71 +389,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     iteration += 1
     val k = this.k
     val vocabSize = this.vocabSize
-    val Elogbeta = dirichletExpectation(lambda).t
-    val expElogbeta = exp(Elogbeta)
+    val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
     val alpha = this.alpha.toBreeze
     val gammaShape = this.gammaShape
 
-    val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
-      val stat = BDM.zeros[Double](k, vocabSize)
-      docs.foreach { doc =>
-        val termCounts = doc._2
-        val (ids: List[Int], cts: Array[Double]) = termCounts match {
-          case v: DenseVector => ((0 until v.size).toList, v.values)
-          case v: SparseVector => (v.indices.toList, v.values)
-          case v => throw new IllegalArgumentException("Online LDA does not support vector type "
-            + v.getClass)
-        }
-        if (!ids.isEmpty) {
-
-          // Initialize the variational distribution q(theta|gamma) for the mini-batch
-          val gammad: BDV[Double] =
-            new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
-          val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
-          val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
-
-          val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
-          var meanchange = 1D
-          val ctsVector = new BDV[Double](cts) // ids
-
-          // Iterate between gamma and phi until convergence
-          while (meanchange > 1e-3) {
-            val lastgamma = gammad.copy
-            //        K                  K * ids               ids
-            gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
-            expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
-            phinorm := expElogbetad * expElogthetad :+ 1e-100
-            meanchange = sum(abs(gammad - lastgamma)) / k
-          }
+    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
+      val nonEmptyDocs = docs.filter(_._2.numActives > 0)
 
-          stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+      val stat = BDM.zeros[Double](k, vocabSize)
+      var gammaPart = List[BDV[Double]]()
+      nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+        val ids: List[Int] = termCounts match {
+          case v: DenseVector => (0 until v.size).toList
+          case v: SparseVector => v.indices.toList
         }
+        val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
+          termCounts, expElogbeta, alpha, gammaShape, k)
+        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
+        gammaPart = gammad :: gammaPart
       }
-      Iterator(stat)
+      Iterator((stat, gammaPart))
     }
-
-    val statsSum: BDM[Double] = stats.reduce(_ += _)
+    val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
+      stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
     val batchResult = statsSum :* expElogbeta.t
 
     // Note that this is an optimization to avoid batch.count
-    update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
+    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
     this
   }
 
-  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
-    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
-  }
-
   /**
    * Update lambda based on the batch submitted. batchSize can be different for each iteration.
    */
-  private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
+  private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = {
     // weight of the mini-batch.
-    val weight = math.pow(getTau0 + iter, -getKappa)
+    val weight = rho()
 
     // Update lambda based on documents.
-    lambda = lambda * (1 - weight) +
-      (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
+    lambda := (1 - weight) * lambda +
+      weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
+  }
+
+  /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+  private def rho(): Double = {
+    math.pow(getTau0 + this.iteration, -getKappa)
   }
 
   /**
@@ -463,15 +448,56 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     new BDM[Double](col, row, temp).t
   }
 
+  override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
+    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+  }
+
+}
+
+/**
+ * Serializable companion object containing helper methods and shared code for
+ * [[OnlineLDAOptimizer]] and [[LocalLDAModel]].
+ */
+private[clustering] object OnlineLDAOptimizer {
   /**
-   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
-   * uses digamma which is accurate but expensive.
+   * Uses variational inference to infer the topic distribution `gammad` given the term counts
+   * for a document. `termCounts` must be non-empty; otherwise Breeze will throw a BLAS error.
+   *
+   * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001)
+   * avoids explicit computation of variational parameter `phi`.
+   * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]]
    */
-  private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
-    val rowSum = sum(alpha(breeze.linalg.*, ::))
-    val digAlpha = digamma(alpha)
-    val digRowSum = digamma(rowSum)
-    val result = digAlpha(::, breeze.linalg.*) - digRowSum
-    result
+  private[clustering] def variationalTopicInference(
+      termCounts: Vector,
+      expElogbeta: BDM[Double],
+      alpha: breeze.linalg.Vector[Double],
+      gammaShape: Double,
+      k: Int): (BDV[Double], BDM[Double]) = {
+    val (ids: List[Int], cts: Array[Double]) = termCounts match {
+      case v: DenseVector => ((0 until v.size).toList, v.values)
+      case v: SparseVector => (v.indices.toList, v.values)
+    }
+    // Initialize the variational distribution q(theta|gamma) for the mini-batch
+    val gammad: BDV[Double] =
+      new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k)                   // K
+    val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad))  // K
+    val expElogbetad = expElogbeta(ids, ::).toDenseMatrix                        // ids * K
+
+    val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100            // ids
+    var meanchange = 1D
+    val ctsVector = new BDV[Double](cts)                                         // ids
+
+    // Iterate between gamma and phi until convergence
+    while (meanchange > 1e-3) {
+      val lastgamma = gammad.copy
+      //        K                  K * ids               ids
+      gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+      expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
+      phinorm := expElogbetad * expElogthetad :+ 1e-100
+      meanchange = sum(abs(gammad - lastgamma)) / k
+    }
+
+    val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+    (gammad, sstatsd)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
new file mode 100644
index 0000000000..f7e5ce1665
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum}
+import breeze.numerics._
+
+/**
+ * Utility methods for LDA.
+ */
+private[clustering] object LDAUtils {
+  /**
+   * Log Sum Exp with overflow protection using the identity:
+   * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\}
+   */
+  private[clustering] def logSumExp(x: BDV[Double]): Double = {
+    val a = max(x)
+    a + log(sum(exp(x :- a)))
+  }
+
+  /**
+   * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
+   * uses [[breeze.numerics.digamma]] which is accurate but expensive.
+   */
+  private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = {
+    digamma(alpha) - digamma(sum(alpha))
+  }
+
+  /**
+   * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha is a vector of
+   * Dirichlet parameters.
+   */
+  private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
+    val rowSum = sum(alpha(breeze.linalg.*, ::))
+    val digAlpha = digamma(alpha)
+    val digRowSum = digamma(rowSum)
+    val result = digAlpha(::, breeze.linalg.*) - digRowSum
+    result
+  }
+
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index b48f190f59..d272a42c85 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.clustering;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
 
 import scala.Tuple2;
 
@@ -59,7 +60,10 @@ public class JavaLDASuite implements Serializable {
 
   @Test
   public void localLDAModel() {
-    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+    Matrix topics = LDASuite$.MODULE$.tinyTopics();
+    double[] docConcentration = new double[topics.numCols()];
+    Arrays.fill(docConcentration, 1.0D / topics.numCols());
+    LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(docConcentration), 1D, 100D);
 
     // Check: basic parameters
     assertEquals(model.k(), tinyK);
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 376a87f051..aa36336ebb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseMatrix => BDM}
+import breeze.linalg.{DenseMatrix => BDM, argmax, max}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.graphx.Edge
@@ -31,7 +31,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
   import LDASuite._
 
   test("LocalLDAModel") {
-    val model = new LocalLDAModel(tinyTopics)
+    val model = new LocalLDAModel(tinyTopics,
+      Vectors.dense(Array.fill(tinyTopics.numCols)(1.0 / tinyTopics.numCols)), 1D, 100D)
 
     // Check: basic parameters
     assert(model.k === tinyK)
@@ -235,6 +236,51 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("LocalLDAModel logPerplexity") {
+    val k = 2
+    val vocabSize = 6
+    val alpha = 0.01
+    val eta = 0.01
+    val gammaShape = 100
+    val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
+      1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
+      0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+
+    def toydata: Array[(Long, Vector)] = Array(
+      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+      Vectors.sparse(6, Array(4, 5), Array(1, 1))
+    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+    val docs = sc.parallelize(toydata)
+
+
+    val ldaModel: LocalLDAModel = new LocalLDAModel(
+      topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+
+    /* Verify results using gensim:
+       import numpy as np
+       from gensim import models
+       corpus = [
+          [(0, 1.0), (1, 1.0)],
+          [(1, 1.0), (2, 1.0)],
+          [(0, 1.0), (2, 1.0)],
+          [(3, 1.0), (4, 1.0)],
+          [(3, 1.0), (5, 1.0)],
+          [(4, 1.0), (5, 1.0)]]
+       np.random.seed(2345)
+       lda = models.ldamodel.LdaModel(
+          corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
+          decay=0.51, offset=1024)
+       print(lda.log_perplexity(corpus))
+       > -3.69051285096
+     */
+
+    assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D)
+  }
+
   test("OnlineLDAOptimizer with asymmetric prior") {
     def toydata: Array[(Long, Vector)] = Array(
       Vectors.sparse(6, Array(0, 1), Array(1, 1)),
@@ -287,7 +333,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("model save/load") {
     // Test for LocalLDAModel.
-    val localModel = new LocalLDAModel(tinyTopics)
+    val localModel = new LocalLDAModel(tinyTopics,
+      Vectors.dense(Array.fill(tinyTopics.numCols)(1.0 / tinyTopics.numCols)), 1D, 100D)
     val tempDir1 = Utils.createTempDir()
     val path1 = tempDir1.toURI.toString
 
-- 
GitLab