Skip to content
Snippets Groups Projects
Commit 5307c9d3 authored by MechCoder's avatar MechCoder Committed by Xiangrui Meng
Browse files

[SPARK-9223] [PYSPARK] [MLLIB] Support model save/load in LDA

Since save / load has been merged in LDA, it takes no time to write the wrappers in Python as well.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7587 from MechCoder/python_lda_save_load and squashes the following commits:

c8e4ea7 [MechCoder] [SPARK-9223] [PySpark] Support model save/load in LDA
parent 430cd781
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ import array as pyarray
if sys.version > '3':
xrange = range
basestring = str
from math import exp, log
......@@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper):
Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
>>> from pyspark.mllib.linalg import Vectors
>>> from numpy.testing import assert_almost_equal
>>> from numpy.testing import assert_almost_equal, assert_equal
>>> data = [
... [1, Vectors.dense([0.0, 1.0])],
... [2, SparseVector(2, {0: 1.0})],
......@@ -591,6 +592,19 @@ class LDAModel(JavaModelWrapper):
>>> topics = model.topicsMatrix()
>>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
>>> assert_almost_equal(topics, topics_expect, 1)
>>> import os, tempfile
>>> from shutil import rmtree
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = LDAModel.load(sc, path)
>>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
>>> sameModel.vocabSize() == model.vocabSize()
True
>>> try:
... rmtree(path)
... except OSError:
... pass
"""
def topicsMatrix(self):
......@@ -601,6 +615,33 @@ class LDAModel(JavaModelWrapper):
"""Vocabulary size (number of terms in the vocabulary)"""
return self.call("vocabSize")
def save(self, sc, path):
    """Persist this LDAModel to the given location.

    :param sc: SparkContext
    :param path: str, destination the model is written to.
    :raises TypeError: if ``sc`` is not a SparkContext or ``path`` is
        not a string.
    """
    # Arguments are validated in order: the context first, then the path.
    sc_is_valid = isinstance(sc, SparkContext)
    if not sc_is_valid:
        raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
    path_is_valid = isinstance(path, basestring)
    if not path_is_valid:
        raise TypeError("path should be a basestring, got type %s" % type(path))

    # Delegate to the wrapped Java model, handing it the object
    # returned by sc._jsc.sc() together with the destination path.
    write_model = self._java_model.save
    write_model(sc._jsc.sc(), path)
@classmethod
def load(cls, sc, path):
    """Read a previously saved LDAModel back from the given location.

    :param sc: SparkContext
    :param path: str, location the model was saved to.
    :raises TypeError: if ``sc`` is not a SparkContext or ``path`` is
        not a string.
    :return: a new instance of ``cls`` wrapping the loaded Java model.
    """
    # Arguments are validated in order: the context first, then the path.
    sc_is_valid = isinstance(sc, SparkContext)
    if not sc_is_valid:
        raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
    path_is_valid = isinstance(path, basestring)
    if not path_is_valid:
        raise TypeError("path should be a basestring, got type %s" % type(path))

    # The JVM-side loader is reached through the context's gateway.
    read_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load
    loaded = read_model(sc._jsc.sc(), path)
    return cls(loaded)
class LDA(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment