Skip to content
Snippets Groups Projects
Commit b13162b3 authored by Yuhao Yang, committed by Joseph K. Bradley
Browse files

[SPARK-7475] [MLLIB] adjust ldaExample for online LDA

jira: https://issues.apache.org/jira/browse/SPARK-7475

Add a new argument to specify the algorithm applied to LDA, to exhibit the basic usage of LDAOptimizer.

cc jkbradley

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #6000 from hhbyyh/ldaExample and squashes the following commits:

0a7e2bc [Yuhao Yang] fix according to comments
5810b0f [Yuhao Yang] adjust ldaExample for online LDA
parent bd74301f
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ import scopt.OptionParser ...@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger} import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
...@@ -48,6 +48,7 @@ object LDAExample { ...@@ -48,6 +48,7 @@ object LDAExample {
topicConcentration: Double = -1, topicConcentration: Double = -1,
vocabSize: Int = 10000, vocabSize: Int = 10000,
stopwordFile: String = "", stopwordFile: String = "",
algorithm: String = "em",
checkpointDir: Option[String] = None, checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params] checkpointInterval: Int = 10) extends AbstractParams[Params]
...@@ -78,6 +79,10 @@ object LDAExample { ...@@ -78,6 +79,10 @@ object LDAExample {
.text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
s" default: ${defaultParams.stopwordFile}") s" default: ${defaultParams.stopwordFile}")
.action((x, c) => c.copy(stopwordFile = x)) .action((x, c) => c.copy(stopwordFile = x))
opt[String]("algorithm")
.text(s"inference algorithm to use. em and online are supported." +
s" default: ${defaultParams.algorithm}")
.action((x, c) => c.copy(algorithm = x))
opt[String]("checkpointDir") opt[String]("checkpointDir")
.text(s"Directory for checkpointing intermediate results." + .text(s"Directory for checkpointing intermediate results." +
s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
...@@ -128,7 +133,17 @@ object LDAExample { ...@@ -128,7 +133,17 @@ object LDAExample {
// Run LDA. // Run LDA.
val lda = new LDA() val lda = new LDA()
lda.setK(params.k)
val optimizer = params.algorithm.toLowerCase match {
case "em" => new EMLDAOptimizer
// add (1.0 / actualCorpusSize) to MiniBatchFraction to be more robust on tiny datasets.
case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize)
case _ => throw new IllegalArgumentException(
s"Only em, online are supported but got ${params.algorithm}.")
}
lda.setOptimizer(optimizer)
.setK(params.k)
.setMaxIterations(params.maxIterations) .setMaxIterations(params.maxIterations)
.setDocConcentration(params.docConcentration) .setDocConcentration(params.docConcentration)
.setTopicConcentration(params.topicConcentration) .setTopicConcentration(params.topicConcentration)
...@@ -137,14 +152,18 @@ object LDAExample { ...@@ -137,14 +152,18 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get) sc.setCheckpointDir(params.checkpointDir.get)
} }
val startTime = System.nanoTime() val startTime = System.nanoTime()
val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val ldaModel = lda.run(corpus)
val elapsed = (System.nanoTime() - startTime) / 1e9 val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:") println(s"Finished training LDA model. Summary:")
println(s"\t Training time: $elapsed sec") println(s"\t Training time: $elapsed sec")
val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
println(s"\t Training data average log likelihood: $avgLogLikelihood") if (ldaModel.isInstanceOf[DistributedLDAModel]) {
println() val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
println(s"\t Training data average log likelihood: $avgLogLikelihood")
println()
}
// Print the topics, showing the top-weighted terms for each topic. // Print the topics, showing the top-weighted terms for each topic.
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment