Commit bbea8885 authored by Yuhao Yang, committed by Joseph K. Bradley

[SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel

jira: https://issues.apache.org/jira/browse/SPARK-10809

This provides a single-document topicDistribution method for LocalLDAModel to allow quick queries that avoid RDD operations. Currently, the user must wrap even a single document in an RDD of documents.

Also adds some missing asserts in the test suite.
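For illustration, a minimal usage sketch of the new API (the names model, sc, doc, and vocabSize are placeholders, not part of this patch; assumes an existing LocalLDAModel and SparkContext):

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// A single document as a term-count vector (indices and counts are made up).
val doc: Vector = Vectors.sparse(vocabSize, Seq((0, 3.0), (5, 1.0)))

// New in this patch: query one document directly, no RDD required.
val theta: Vector = model.topicDistribution(doc)

// Existing batch API: the document must be wrapped in an RDD[(Long, Vector)].
val batchThetas = model.topicDistributions(sc.parallelize(Seq((0L, doc))))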

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #9484 from hhbyyh/ldaTopicPre.
parent 4f8eefa3
@@ -387,6 +387,32 @@ class LocalLDAModel private[spark] (
     }
   }
 
+  /**
+   * Predicts the topic mixture distribution for a document (often called "theta" in the
+   * literature). Returns a vector of zeros for an empty document.
+   *
+   * Note: this is to allow a quick query for a single document. For batches of documents,
+   * please refer to [[topicDistributions()]] to avoid overhead.
+   *
+   * @param document document to predict topic mixture distributions for
+   * @return topic mixture distribution for the document
+   */
+  @Since("2.0.0")
+  def topicDistribution(document: Vector): Vector = {
+    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+    if (document.numNonzeros == 0) {
+      Vectors.zeros(this.k)
+    } else {
+      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+        document,
+        expElogbeta,
+        this.docConcentration.toBreeze,
+        gammaShape,
+        this.k)
+      Vectors.dense(normalize(gamma, 1.0).toArray)
+    }
+  }
+
   /**
    * Java-friendly version of [[topicDistributions]]
    */
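As context for the last line of the new method: variationalTopicInference returns the unnormalized variational parameter gamma, and normalize(gamma, 1.0) rescales it to sum to 1, giving the returned topic mixture ("theta"). A minimal Breeze sketch of that step, with made-up values for a 3-topic model:

import breeze.linalg.{DenseVector, normalize}

// Hypothetical gamma for k = 3 topics; L1-normalizing yields the topic mixture.
val gamma = DenseVector(2.0, 3.0, 5.0)
val theta = normalize(gamma, 1.0)  // DenseVector(0.2, 0.3, 0.5)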
@@ -366,7 +366,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       (0, 0.99504), (1, 0.99504),
       (1, 0.99504), (1, 0.99504))
 
-    val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
+    val actualPredictions = ldaModel.topicDistributions(docs).cache()
+    val topTopics = actualPredictions.map { case (id, topics) =>
       // convert results to expectedPredictions format, which only has highest probability topic
       val topicsBz = topics.toBreeze.toDenseVector
       (id, (argmax(topicsBz), max(topicsBz)))
@@ -374,9 +375,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       .values
       .collect()
 
-    expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
-      expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)
+    expectedPredictions.zip(topTopics).foreach { case (expected, actual) =>
+      assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D))
     }
+
+    docs.collect()
+      .map(doc => ldaModel.topicDistribution(doc._2))
+      .zip(actualPredictions.map(_._2).collect())
+      .foreach { case (single, batch) =>
+        assert(single ~== batch relTol 1E-3D)
+      }
+    actualPredictions.unpersist()
   }
 
   test("OnlineLDAOptimizer with asymmetric prior") {