From f51fd6fbb4d9822502f98b312251e317d757bc3a Mon Sep 17 00:00:00 2001
From: Feynman Liang <fliang@databricks.com>
Date: Fri, 31 Jul 2015 18:36:22 -0700
Subject: [PATCH] [SPARK-8936] [MLLIB] OnlineLDA document-topic Dirichlet
 hyperparameter optimization

Adds `alpha` (document-topic Dirichlet parameter) hyperparameter optimization to `OnlineLDAOptimizer` following Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters. Also introduces a private `setSampleWithReplacement` to `OnlineLDAOptimizer` for unit testing purposes.

Author: Feynman Liang <fliang@databricks.com>

Closes #7836 from feynmanliang/SPARK-8936-alpha-optimize and squashes the following commits:

4bef484 [Feynman Liang] Documentation improvements
c3c6c1d [Feynman Liang] Fix docs
151e859 [Feynman Liang] Fix style
fa77518 [Feynman Liang] Hyperparameter optimization
---
 .../spark/mllib/clustering/LDAOptimizer.scala | 75 ++++++++++++++++---
 .../spark/mllib/clustering/LDASuite.scala     | 34 +++++++++
 2 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index d6f8b29a43..b0e14cb829 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, exp}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
+import breeze.numerics.{trigamma, abs, exp}
 import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.DeveloperApi
@@ -239,22 +239,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   /** alias for docConcentration */
   private var alpha: Vector = Vectors.dense(0)
 
-  /** (private[clustering] for debugging)  Get docConcentration */
+  /** (for debugging)  Get docConcentration */
   private[clustering] def getAlpha: Vector = alpha
 
   /** alias for topicConcentration */
   private var eta: Double = 0
 
-  /** (private[clustering] for debugging)  Get topicConcentration */
+  /** (for debugging)  Get topicConcentration */
   private[clustering] def getEta: Double = eta
 
   private var randomGenerator: java.util.Random = null
 
+  /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
+  private var sampleWithReplacement: Boolean = true
+
   // Online LDA specific parameters
   // Learning rate is: (tau0 + t)^{-kappa}
   private var tau0: Double = 1024
   private var kappa: Double = 0.51
   private var miniBatchFraction: Double = 0.05
+  private var optimizeAlpha: Boolean = false
 
   // internal data structure
   private var docs: RDD[(Long, Vector)] = null
@@ -262,7 +266,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   /** Dirichlet parameter for the posterior over topics */
   private var lambda: BDM[Double] = null
 
-  /** (private[clustering] for debugging) Get parameter for topics */
+  /** (for debugging) Get parameter for topics */
   private[clustering] def getLambda: BDM[Double] = lambda
 
   /** Current iteration (count of invocations of [[next()]]) */
@@ -325,7 +329,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * (private[clustering])
+   * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution)
+   * will be optimized during training.
+   */
+  def getOptimzeAlpha: Boolean = this.optimizeAlpha
+
+  /**
+   * Sets whether to optimize alpha parameter during training.
+   *
+   * Default: false
+   */
+  def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = {
+    this.optimizeAlpha = optimizeAlpha
+    this
+  }
+
+  /**
    * Set the Dirichlet parameter for the posterior over topics.
    * This is only used for testing now. In the future, it can help support training stop/resume.
    */
@@ -335,7 +354,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * (private[clustering])
    * Used for random initialization of the variational parameters.
    * Larger value produces values closer to 1.0.
    * This is only used for testing currently.
@@ -345,6 +363,15 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     this
   }
 
+  /**
+   * Sets whether to sample mini-batches with or without replacement. (default = true)
+   * This is only used for testing currently.
+   */
+  private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
+    this.sampleWithReplacement = replace
+    this
+  }
+
   override private[clustering] def initialize(
       docs: RDD[(Long, Vector)],
       lda: LDA): OnlineLDAOptimizer = {
@@ -376,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   override private[clustering] def next(): OnlineLDAOptimizer = {
-    val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+    val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
+      randomGenerator.nextLong())
     if (batch.isEmpty()) return this
     submitMiniBatch(batch)
   }
@@ -418,6 +446,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
 
     // Note that this is an optimization to avoid batch.count
     updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
+    if (optimizeAlpha) updateAlpha(gammat)
     this
   }
 
@@ -433,13 +462,39 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
       weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
   }
 
-  /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+  /**
+   * Update alpha based on `gammat`, the inferred topic distributions for documents in the
+   * current mini-batch. Uses Newton-Rhapson method.
+   * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
+   *      (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
+   */
+  private def updateAlpha(gammat: BDM[Double]): Unit = {
+    val weight = rho()
+    val N = gammat.rows.toDouble
+    val alpha = this.alpha.toBreeze.toDenseVector
+    val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N
+    val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector)
+
+    val c = N * trigamma(sum(alpha))
+    val q = -N * trigamma(alpha)
+    val b = sum(gradf / q) / (1D / c + sum(1D / q))
+
+    val dalpha = -(gradf - b) / q
+
+    if (all((weight * dalpha + alpha) :> 0D)) {
+      alpha :+= weight * dalpha
+      this.alpha = Vectors.dense(alpha.toArray)
+    }
+  }
+
+
+  /** Calculate learning rate rho for the current [[iteration]]. */
   private def rho(): Double = {
     math.pow(getTau0 + this.iteration, -getKappa)
   }
 
   /**
-   * Get a random matrix to initialize lambda
+   * Get a random matrix to initialize lambda.
    */
   private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
     val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index f2b94707fd..fdc2554ab8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -400,6 +400,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("OnlineLDAOptimizer alpha hyperparameter optimization") {
+    val k = 2
+    val docs = sc.parallelize(toyData)
+    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+      .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false)
+    val lda = new LDA().setK(k)
+      .setDocConcentration(1D / k)
+      .setTopicConcentration(0.01)
+      .setMaxIterations(100)
+      .setOptimizer(op)
+      .setSeed(12345)
+    val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel]
+
+    /* Verify the results with gensim:
+      import numpy as np
+      from gensim import models
+      corpus = [
+       [(0, 1.0), (1, 1.0)],
+       [(1, 1.0), (2, 1.0)],
+       [(0, 1.0), (2, 1.0)],
+       [(3, 1.0), (4, 1.0)],
+       [(3, 1.0), (5, 1.0)],
+       [(4, 1.0), (5, 1.0)]]
+      np.random.seed(2345)
+      lda = models.ldamodel.LdaModel(
+         corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
+         decay=0.51, offset=1024)
+      print(lda.alpha)
+      > [ 0.42582646  0.43511073]
+     */
+
+    assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05)
+  }
+
   test("model save/load") {
     // Test for LocalLDAModel.
     val localModel = new LocalLDAModel(tinyTopics,
-- 
GitLab