Skip to content
Snippets Groups Projects
Commit 3539cb7d authored by Yuhao Yang's avatar Yuhao Yang Committed by Joseph K. Bradley
Browse files

[SPARK-5563] [MLLIB] LDA with online variational inference

JIRA: https://issues.apache.org/jira/browse/SPARK-5563
The PR contains the implementation for [Online LDA] (https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) based on the research of  Matt Hoffman and David M. Blei, which provides an efficient option for LDA users. Major advantages for the algorithm are the stream compatibility and economic time/memory consumption due to the corpus split. For more details, please refer to the jira.

Online LDA can act as a fast option for LDA, and will be especially helpful for the users who needs a quick result or with large corpus.

 Correctness test.
I have tested current PR with https://github.com/Blei-Lab/onlineldavb and the results are identical. I've uploaded the result and code to https://github.com/hhbyyh/LDACrossValidation.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4419 from hhbyyh/ldaonline and squashes the following commits:

1045eec [Yuhao Yang] Merge pull request #2 from jkbradley/hhbyyh-ldaonline2
cf376ff [Joseph K. Bradley] For private vars needed for testing, I made them private and added accessors.  Java doesn’t understand package-private tags, so this minimizes the issues Java users might encounter.
6149ca6 [Yuhao Yang] fix for setOptimizer
cf0007d [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
54cf8da [Yuhao Yang] some style change
68c2318 [Yuhao Yang] add a java ut
4041723 [Yuhao Yang] add ut
138bfed [Yuhao Yang] Merge pull request #1 from jkbradley/hhbyyh-ldaonline-update
9e910d9 [Joseph K. Bradley] small fix
61d60df [Joseph K. Bradley] Minor cleanups: * Update *Concentration parameter documentation * EM Optimizer: createVertices() does not need to be a function * OnlineLDAOptimizer: typos in doc * Clean up the core code for online LDA (Scala style)
a996a82 [Yuhao Yang] respond to comments
b1178cf [Yuhao Yang] fit into the optimizer framework
dbe3cff [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
15be071 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
b29193b [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
d19ef55 [Yuhao Yang] change OnlineLDA to class
97b9e1a [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
e7bf3b0 [Yuhao Yang] move to seperate file
f367cc9 [Yuhao Yang] change to optimization
8cb16a6 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
62405cc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
02d0373 [Yuhao Yang] fix style in comment
f6d47ca [Yuhao Yang] Merge branch 'ldaonline' of https://github.com/hhbyyh/spark into ldaonline
d86cdec [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
a570c9a [Yuhao Yang] use sample to pick up batch
4a3f27e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline
e271eb1 [Yuhao Yang] remove non ascii
581c623 [Yuhao Yang] seperate API and adjust batch split
37af91a [Yuhao Yang] iMerge remote-tracking branch 'upstream/master' into ldaonline
20328d1 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline i
aa365d1 [Yuhao Yang] merge upstream master
3a06526 [Yuhao Yang] merge with new example
0dd3947 [Yuhao Yang] kMerge remote-tracking branch 'upstream/master' into ldaonline
0d0f3ee [Yuhao Yang] replace random split with sliding
fa408a8 [Yuhao Yang] ssMerge remote-tracking branch 'upstream/master' into ldaonline
45884ab [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s
f41c5ca [Yuhao Yang] style fix
26dca1b [Yuhao Yang] style fix and make class private
043e786 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaonline s Conflicts: 	mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
d640d9c [Yuhao Yang] online lda initial checkin
parent 9646018b
No related branches found
No related tags found
No related merge requests found
......@@ -78,35 +78,29 @@ class LDA private (
*
* This is the parameter to a symmetric Dirichlet distribution.
*/
def getDocConcentration: Double = {
if (this.docConcentration == -1) {
(50.0 / k) + 1.0
} else {
this.docConcentration
}
}
def getDocConcentration: Double = this.docConcentration
/**
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
* distributions over topics ("theta").
*
* This is the parameter to a symmetric Dirichlet distribution.
* This is the parameter to a symmetric Dirichlet distribution, where larger values
* mean more smoothing (more regularization).
*
* This value should be > 1.0, where larger values mean more smoothing (more regularization).
* If set to -1, then docConcentration is set automatically.
* (default = -1 = automatic)
*
* Automatic setting of parameter:
* - For EM: default = (50 / k) + 1.
* - The 50/k is common in LDA libraries.
* - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
*
* Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
* but values in (0,1) are not yet supported.
* Optimizer-specific parameter settings:
* - EM
* - Value should be > 1.0
* - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
* Asuncion et al. (2009), who recommend a +1 adjustment for EM.
* - Online
* - Value should be >= 0
* - default = (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
def setDocConcentration(docConcentration: Double): this.type = {
require(docConcentration > 1.0 || docConcentration == -1.0,
s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration")
this.docConcentration = docConcentration
this
}
......@@ -126,13 +120,7 @@ class LDA private (
* Note: The topics' distributions over terms are called "beta" in the original LDA paper
* by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
*/
def getTopicConcentration: Double = {
if (this.topicConcentration == -1) {
1.1
} else {
this.topicConcentration
}
}
def getTopicConcentration: Double = this.topicConcentration
/**
* Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
......@@ -143,21 +131,20 @@ class LDA private (
* Note: The topics' distributions over terms are called "beta" in the original LDA paper
* by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
*
* This value should be > 0.0.
* If set to -1, then topicConcentration is set automatically.
* (default = -1 = automatic)
*
* Automatic setting of parameter:
* - For EM: default = 0.1 + 1.
* - The 0.1 gives a small amount of smoothing.
* - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
*
* Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
* but values in (0,1) are not yet supported.
* Optimizer-specific parameter settings:
* - EM
* - Value should be > 1.0
* - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows
* Asuncion et al. (2009), who recommend a +1 adjustment for EM.
* - Online
* - Value should be >= 0
* - default = (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
def setTopicConcentration(topicConcentration: Double): this.type = {
require(topicConcentration > 1.0 || topicConcentration == -1.0,
s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration")
this.topicConcentration = topicConcentration
this
}
......@@ -223,14 +210,15 @@ class LDA private (
/**
* Set the LDAOptimizer used to perform the actual calculation by algorithm name.
* Currently "em" is supported.
* Currently "em", "online" is supported.
*/
def setOptimizer(optimizerName: String): this.type = {
this.ldaOptimizer =
optimizerName.toLowerCase match {
case "em" => new EMLDAOptimizer
case "online" => new OnlineLDAOptimizer
case other =>
throw new IllegalArgumentException(s"Only em is supported but got $other.")
throw new IllegalArgumentException(s"Only em, online are supported but got $other.")
}
this
}
......@@ -245,8 +233,7 @@ class LDA private (
* @return Inferred LDA model
*/
def run(documents: RDD[(Long, Vector)]): LDAModel = {
val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
seed, checkpointInterval)
val state = ldaOptimizer.initialize(documents, this)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
......
......@@ -19,13 +19,15 @@ package org.apache.spark.mllib.clustering
import java.util.Random
import breeze.linalg.{DenseVector => BDV, normalize}
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
import breeze.numerics.{digamma, exp, abs}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.Experimental
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
import org.apache.spark.rdd.RDD
/**
......@@ -35,7 +37,7 @@ import org.apache.spark.rdd.RDD
* hold optimizer-specific parameters for users to set.
*/
@Experimental
trait LDAOptimizer{
trait LDAOptimizer {
/*
DEVELOPERS NOTE:
......@@ -49,13 +51,7 @@ trait LDAOptimizer{
* Initializer for the optimizer. LDA passes the common parameters to the optimizer and
* the internal structure can be initialized properly.
*/
private[clustering] def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): LDAOptimizer
private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer
private[clustering] def next(): LDAOptimizer
......@@ -80,12 +76,12 @@ trait LDAOptimizer{
*
*/
@Experimental
class EMLDAOptimizer extends LDAOptimizer{
class EMLDAOptimizer extends LDAOptimizer {
import LDA._
/**
* Following fields will only be initialized through initialState method
* The following fields will only be initialized through the initialize() method
*/
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
private[clustering] var k: Int = 0
......@@ -98,13 +94,23 @@ class EMLDAOptimizer extends LDAOptimizer{
/**
* Compute bipartite term/doc graph.
*/
private[clustering] override def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): LDAOptimizer = {
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
val docConcentration = lda.getDocConcentration
val topicConcentration = lda.getTopicConcentration
val k = lda.getK
// Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
// but values in (0,1) are not yet supported.
require(docConcentration > 1.0 || docConcentration == -1.0, s"LDA docConcentration must be" +
s" > 1.0 (or -1 for auto) for EM Optimizer, but was set to $docConcentration")
require(topicConcentration > 1.0 || topicConcentration == -1.0, s"LDA topicConcentration " +
s"must be > 1.0 (or -1 for auto) for EM Optimizer, but was set to $topicConcentration")
this.docConcentration = if (docConcentration == -1) (50.0 / k) + 1.0 else docConcentration
this.topicConcentration = if (topicConcentration == -1) 1.1 else topicConcentration
val randomSeed = lda.getSeed
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
// Add edges for terms with non-zero counts.
......@@ -113,11 +119,9 @@ class EMLDAOptimizer extends LDAOptimizer{
}
}
val vocabSize = docs.take(1).head._2.size
// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
def createVertices(): RDD[(VertexId, TopicCounts)] = {
val docTermVertices: RDD[(VertexId, TopicCounts)] = {
val verticesTMP: RDD[(VertexId, TopicCounts)] =
edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
val random = new Random(partIndex + randomSeed)
......@@ -130,22 +134,18 @@ class EMLDAOptimizer extends LDAOptimizer{
verticesTMP.reduceByKey(_ + _)
}
val docTermVertices = createVertices()
// Partition such that edges are grouped by document
this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D)
this.k = k
this.vocabSize = vocabSize
this.docConcentration = docConcentration
this.topicConcentration = topicConcentration
this.checkpointInterval = checkpointInterval
this.vocabSize = docs.take(1).head._2.size
this.checkpointInterval = lda.getCheckpointInterval
this.graphCheckpointer = new
PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
private[clustering] override def next(): EMLDAOptimizer = {
override private[clustering] def next(): EMLDAOptimizer = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
val eta = topicConcentration
......@@ -202,9 +202,269 @@ class EMLDAOptimizer extends LDAOptimizer{
graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}
private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
this.graphCheckpointer.deleteAllCheckpoints()
new DistributedLDAModel(this, iterationTimes)
}
}
/**
* :: Experimental ::
*
* An online optimizer for LDA. The Optimizer implements the Online variational Bayes LDA
* algorithm, which processes a subset of the corpus on each iteration, and updates the term-topic
* distribution adaptively.
*
* Original Online LDA paper:
* Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
*/
@Experimental
class OnlineLDAOptimizer extends LDAOptimizer {
// LDA common parameters
private var k: Int = 0
private var corpusSize: Long = 0
private var vocabSize: Int = 0
/** alias for docConcentration */
private var alpha: Double = 0
/** (private[clustering] for debugging) Get docConcentration */
private[clustering] def getAlpha: Double = alpha
/** alias for topicConcentration */
private var eta: Double = 0
/** (private[clustering] for debugging) Get topicConcentration */
private[clustering] def getEta: Double = eta
private var randomGenerator: java.util.Random = null
// Online LDA specific parameters
// Learning rate is: (tau_0 + t)^{-kappa}
private var tau_0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.05
// internal data structure
private var docs: RDD[(Long, Vector)] = null
/** Dirichlet parameter for the posterior over topics */
private var lambda: BDM[Double] = null
/** (private[clustering] for debugging) Get parameter for topics */
private[clustering] def getLambda: BDM[Double] = lambda
/** Current iteration (count of invocations of [[next()]]) */
private var iteration: Int = 0
private var gammaShape: Double = 100
/**
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
*/
def getTau_0: Double = this.tau_0
/**
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
* Default: 1024, following the original Online LDA paper.
*/
def setTau_0(tau_0: Double): this.type = {
require(tau_0 > 0, s"LDA tau_0 must be positive, but was set to $tau_0")
this.tau_0 = tau_0
this
}
/**
* Learning rate: exponential decay rate
*/
def getKappa: Double = this.kappa
/**
* Learning rate: exponential decay rate---should be between
* (0.5, 1.0] to guarantee asymptotic convergence.
* Default: 0.51, based on the original Online LDA paper.
*/
def setKappa(kappa: Double): this.type = {
require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa")
this.kappa = kappa
this
}
/**
* Mini-batch fraction, which sets the fraction of document sampled and used in each iteration
*/
def getMiniBatchFraction: Double = this.miniBatchFraction
/**
* Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in
* each iteration.
*
* Note that this should be adjusted in synch with [[LDA.setMaxIterations()]]
* so the entire corpus is used. Specifically, set both so that
* maxIterations * miniBatchFraction >= 1.
*
* Default: 0.05, i.e., 5% of total documents.
*/
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0,
s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction")
this.miniBatchFraction = miniBatchFraction
this
}
/**
* (private[clustering])
* Set the Dirichlet parameter for the posterior over topics.
* This is only used for testing now. In the future, it can help support training stop/resume.
*/
private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
this.lambda = lambda
this
}
/**
* (private[clustering])
* Used for random initialization of the variational parameters.
* Larger value produces values closer to 1.0.
* This is only used for testing currently.
*/
private[clustering] def setGammaShape(shape: Double): this.type = {
this.gammaShape = shape
this
}
override private[clustering] def initialize(
docs: RDD[(Long, Vector)],
lda: LDA): OnlineLDAOptimizer = {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size
this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
this.randomGenerator = new Random(lda.getSeed)
this.docs = docs
// Initialize the variational distribution q(beta|lambda)
this.lambda = getGammaMatrix(k, vocabSize)
this.iteration = 0
this
}
override private[clustering] def next(): OnlineLDAOptimizer = {
val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
if (batch.isEmpty()) return this
submitMiniBatch(batch)
}
/**
* Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA
* model, and it will update the topic distribution adaptively for the terms appearing in the
* subset.
*/
private[clustering] def submitMiniBatch(batch: RDD[(Long, Vector)]): OnlineLDAOptimizer = {
iteration += 1
val k = this.k
val vocabSize = this.vocabSize
val Elogbeta = dirichletExpectation(lambda)
val expElogbeta = exp(Elogbeta)
val alpha = this.alpha
val gammaShape = this.gammaShape
val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
val stat = BDM.zeros[Double](k, vocabSize)
docs.foreach { doc =>
val termCounts = doc._2
val (ids: List[Int], cts: Array[Double]) = termCounts match {
case v: DenseVector => ((0 until v.size).toList, v.values)
case v: SparseVector => (v.indices.toList, v.values)
case v => throw new IllegalArgumentException("Online LDA does not support vector type "
+ v.getClass)
}
// Initialize the variational distribution q(theta|gamma) for the mini-batch
var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
var expElogthetad = exp(Elogthetad) // 1 * K
val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
var meanchange = 1D
val ctsVector = new BDV[Double](cts).t // 1 * ids
// Iterate between gamma and phi until convergence
while (meanchange > 1e-3) {
val lastgamma = gammad
// 1*K 1 * ids ids * k
gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
Elogthetad = digamma(gammad) - digamma(sum(gammad))
expElogthetad = exp(Elogthetad)
phinorm = expElogthetad * expElogbetad + 1e-100
meanchange = sum(abs(gammad - lastgamma)) / k
}
val m1 = expElogthetad.t
val m2 = (ctsVector / phinorm).t.toDenseVector
var i = 0
while (i < ids.size) {
stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
i += 1
}
}
Iterator(stat)
}
val statsSum: BDM[Double] = stats.reduce(_ += _)
val batchResult = statsSum :* expElogbeta
// Note that this is an optimization to avoid batch.count
update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
this
}
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
}
/**
* Update lambda based on the batch submitted. batchSize can be different for each iteration.
*/
private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
val tau_0 = this.getTau_0
val kappa = this.getKappa
// weight of the mini-batch.
val weight = math.pow(tau_0 + iter, -kappa)
// Update lambda based on documents.
lambda = lambda * (1 - weight) +
(stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight
}
/**
* Get a random matrix to initialize lambda
*/
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
randomGenerator.nextLong()))
val gammaRandomGenerator = new Gamma(gammaShape, 1.0 / gammaShape)(randBasis)
val temp = gammaRandomGenerator.sample(row * col).toArray
new BDM[Double](col, row, temp).t
}
/**
* For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation
* uses digamma which is accurate but expensive.
*/
private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
val rowSum = sum(alpha(breeze.linalg.*, ::))
val digAlpha = digamma(alpha)
val digRowSum = digamma(rowSum)
val result = digAlpha(::, breeze.linalg.*) - digRowSum
result
}
}
......@@ -20,7 +20,6 @@ package org.apache.spark.mllib.clustering;
import java.io.Serializable;
import java.util.ArrayList;
import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;
import org.junit.After;
......@@ -30,6 +29,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
......@@ -109,11 +109,45 @@ public class JavaLDASuite implements Serializable {
assert(model.logPrior() < 0.0);
}
@Test
public void OnlineOptimizerCompatibility() {
int k = 3;
double topicSmoothing = 1.2;
double termSmoothing = 1.2;
// Train a model
OnlineLDAOptimizer op = new OnlineLDAOptimizer()
.setTau_0(1024)
.setKappa(0.51)
.setGammaShape(1e40)
.setMiniBatchFraction(0.5);
LDA lda = new LDA();
lda.setK(k)
.setDocConcentration(topicSmoothing)
.setTopicConcentration(termSmoothing)
.setMaxIterations(5)
.setSeed(12345)
.setOptimizer(op);
LDAModel model = lda.run(corpus);
// Check: basic parameters
assertEquals(model.k(), k);
assertEquals(model.vocabSize(), tinyVocabSize);
// Check: topic summaries
Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
assertEquals(roundedTopicSummary.length, k);
Tuple2<int[], double[]>[] roundedLocalTopicSummary = model.describeTopics();
assertEquals(roundedLocalTopicSummary.length, k);
}
private static int tinyK = LDASuite$.MODULE$.tinyK();
private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite$.MODULE$.tinyTopicDescription();
JavaPairRDD<Long, Vector> corpus;
private JavaPairRDD<Long, Vector> corpus;
}
......@@ -17,6 +17,8 @@
package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
......@@ -37,7 +39,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
// Check: describeTopics() with all terms
val fullTopicSummary = model.describeTopics()
assert(fullTopicSummary.size === tinyK)
assert(fullTopicSummary.length === tinyK)
fullTopicSummary.zip(tinyTopicDescription).foreach {
case ((algTerms, algTermWeights), (terms, termWeights)) =>
assert(algTerms === terms)
......@@ -54,7 +56,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
}
}
test("running and DistributedLDAModel") {
test("running and DistributedLDAModel with default Optimizer (EM)") {
val k = 3
val topicSmoothing = 1.2
val termSmoothing = 1.2
......@@ -99,7 +101,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
// Check: per-doc topic distributions
val topicDistributions = model.topicDistributions.collect()
// Ensure all documents are covered.
assert(topicDistributions.size === tinyCorpus.size)
assert(topicDistributions.length === tinyCorpus.length)
assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
// Ensure we have proper distributions
topicDistributions.foreach { case (docId, topicDistribution) =>
......@@ -131,6 +133,87 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
assert(lda.getBeta === 3.0)
assert(lda.getTopicConcentration === 3.0)
}
test("OnlineLDAOptimizer initialization") {
val lda = new LDA().setK(2)
val corpus = sc.parallelize(tinyCorpus, 2)
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
assert(op.getAlpha == 0.5) // default 1.0 / k
assert(op.getEta == 0.5) // default 1.0 / k
assert(op.getKappa == 0.9876)
assert(op.getMiniBatchFraction == 0.123)
assert(op.getTau_0 == 567)
}
test("OnlineLDAOptimizer one iteration") {
// run OnlineLDAOptimizer for 1 iteration to verify it's consistency with Blei-lab,
// [[https://github.com/Blei-Lab/onlineldavb]]
val k = 2
val vocabSize = 6
def docs: Array[(Long, Vector)] = Array(
Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana
Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
val corpus = sc.parallelize(docs, 2)
// Set GammaShape large to avoid the stochastic impact.
val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
.setMiniBatchFraction(1)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)
val state = op.initialize(corpus, lda)
// override lambda to simulate an intermediate state
// [[ 1.1 1.2 1.3 0.9 0.8 0.7]
// [ 0.9 0.8 0.7 1.1 1.2 1.3]]
op.setLambda(new BDM[Double](k, vocabSize,
Array(1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 0.9, 1.1, 0.8, 1.2, 0.7, 1.3)))
// run for one iteration
state.submitMiniBatch(corpus)
// verify the result, Note this generate the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
}
test("OnlineLDAOptimizer with toy data") {
def toydata: Array[(Long, Vector)] = Array(
Vectors.sparse(6, Array(0, 1), Array(1, 1)),
Vectors.sparse(6, Array(1, 2), Array(1, 1)),
Vectors.sparse(6, Array(0, 2), Array(1, 1)),
Vectors.sparse(6, Array(3, 4), Array(1, 1)),
Vectors.sparse(6, Array(3, 5), Array(1, 1)),
Vectors.sparse(6, Array(4, 5), Array(1, 1))
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
val docs = sc.parallelize(toydata)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau_0(1024).setKappa(0.51)
.setGammaShape(1e10)
val lda = new LDA().setK(2)
.setDocConcentration(0.01)
.setTopicConcentration(0.01)
.setMaxIterations(100)
.setOptimizer(op)
.setSeed(12345)
val ldaModel = lda.run(docs)
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
val topics = topicIndices.map { case (terms, termWeights) =>
terms.zip(termWeights)
}
// check distribution for each topic, typical distribution is (0.3, 0.3, 0.3, 0.02, 0.02, 0.02)
topics.foreach { topic =>
val smalls = topic.filter(t => t._2 < 0.1).map(_._2)
assert(smalls.length == 3 && smalls.sum < 0.2)
}
}
}
private[clustering] object LDASuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment