Skip to content
Snippets Groups Projects
Commit 46927696 authored by Yu ISHIKAWA's avatar Yu ISHIKAWA Committed by Joseph K. Bradley
Browse files

[SPARK-6259] [MLLIB] Python API for LDA

I implemented the Python API for LDA. But I didn't implemented a method for `LDAModel.describeTopics()`, beause it's a little hard to implement it now. And adding document about that and an example code would fit for another issue.

TODO: LDAModel.describeTopics() in Python must be also implemented. But it would be nice to fit for another issue. Implementing it is a little hard, since the return value of `describeTopics` in Scala consists of Tuple classes.

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #6791 from yu-iskw/SPARK-6259 and squashes the following commits:

6855f59 [Yu ISHIKAWA] LDA inherits object
28bd165 [Yu ISHIKAWA] Change the place of testing code
d7a332a [Yu ISHIKAWA] Remove the doc comment about the optimizer's default value
083e226 [Yu ISHIKAWA] Add the comment about the supported values and the default value of `optimizer`
9f8bed8 [Yu ISHIKAWA] Simplify casting
faa9764 [Yu ISHIKAWA] Add some comments for the LDA paramters
98f645a [Yu ISHIKAWA] Remove the interface for `describeTopics`. Because it is not implemented.
57ac03d [Yu ISHIKAWA] Remove the unnecessary import in Python unit testing
73412c3 [Yu ISHIKAWA] Fix the typo
2278829 [Yu ISHIKAWA] Fix the indentation
39514ec [Yu ISHIKAWA] Modify how to cast the input data
8117e18 [Yu ISHIKAWA] Fix the validation problems by `lint-scala`
77fd1b7 [Yu ISHIKAWA] Not use LabeledPoint
68f0653 [Yu ISHIKAWA] Support some parameters for `ALS.train()` in Python
25ef2ac [Yu ISHIKAWA] Resolve conflicts with rebasing
parent c6b1a9e7
No related branches found
No related tags found
No related merge requests found
......@@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable {
new MatrixFactorizationModelWrapper(model)
}
/**
* Java stub for Python mllib LDA.run()
*/
def trainLDAModel(
data: JavaRDD[java.util.List[Any]],
k: Int,
maxIterations: Int,
docConcentration: Double,
topicConcentration: Double,
seed: java.lang.Long,
checkpointInterval: Int,
optimizer: String): LDAModel = {
val algo = new LDA()
.setK(k)
.setMaxIterations(maxIterations)
.setDocConcentration(docConcentration)
.setTopicConcentration(topicConcentration)
.setCheckpointInterval(checkpointInterval)
.setOptimizer(optimizer)
if (seed != null) algo.setSeed(seed)
val documents = data.rdd.map(_.asScala.toArray).map { r =>
r(0) match {
case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
case _ => throw new IllegalArgumentException("input values contains invalid type value.")
}
}
algo.run(documents)
}
/**
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
......
......@@ -31,13 +31,15 @@ from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
from pyspark.streaming import DStream
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
'PowerIterationClusteringModel', 'PowerIterationClustering',
'StreamingKMeans', 'StreamingKMeansModel']
'StreamingKMeans', 'StreamingKMeansModel',
'LDA', 'LDAModel']
@inherit_doc
......@@ -563,6 +565,68 @@ class StreamingKMeans(object):
return dstream.mapValues(lambda x: self._model.predict(x))
class LDAModel(JavaModelWrapper):
""" A clustering model derived from the LDA method.
Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
Terminology
- "word" = "term": an element of the vocabulary
- "token": instance of a term appearing in a document
- "topic": multinomial distribution over words representing some concept
References:
- Original LDA paper (journal version):
Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
>>> from pyspark.mllib.linalg import Vectors
>>> from numpy.testing import assert_almost_equal
>>> data = [
... [1, Vectors.dense([0.0, 1.0])],
... [2, SparseVector(2, {0: 1.0})],
... ]
>>> rdd = sc.parallelize(data)
>>> model = LDA.train(rdd, k=2)
>>> model.vocabSize()
2
>>> topics = model.topicsMatrix()
>>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
>>> assert_almost_equal(topics, topics_expect, 1)
"""
def topicsMatrix(self):
"""Inferred topics, where each topic is represented by a distribution over terms."""
return self.call("topicsMatrix").toArray()
def vocabSize(self):
"""Vocabulary size (number of terms or terms in the vocabulary)"""
return self.call("vocabSize")
class LDA(object):
@classmethod
def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
"""Train a LDA model.
:param rdd: RDD of data points
:param k: Number of clusters you want
:param maxIterations: Number of iterations. Default to 20
:param docConcentration: Concentration parameter (commonly named "alpha")
for the prior placed on documents' distributions over topics ("theta").
:param topicConcentration: Concentration parameter (commonly named "beta" or "eta")
for the prior placed on topics' distributions over terms.
:param seed: Random Seed
:param checkpointInterval: Period (in iterations) between checkpoints.
:param optimizer: LDAOptimizer used to perform the actual calculation.
Currently "em", "online" are supported. Default to "em".
"""
model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
docConcentration, topicConcentration, seed,
checkpointInterval, optimizer)
return LDAModel(model)
def _test():
import doctest
import pyspark.mllib.clustering
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment