Skip to content
Snippets Groups Projects
Commit 46de6c05 authored by lewuathe's avatar lewuathe Committed by Xiangrui Meng
Browse files

[SPARK-6598][MLLIB] Python API for IDFModel

This is the sub-task of SPARK-6254.
Wrapping IDFModel `idf` member function for pyspark.

Author: lewuathe <lewuathe@me.com>

Closes #5264 from Lewuathe/SPARK-6598 and squashes the following commits:

1dc522c [lewuathe] [SPARK-6598] Python API for IDFModel
parent cd48ca50
No related branches found
No related tags found
No related merge requests found
......@@ -244,6 +244,12 @@ class IDFModel(JavaVectorTransformer):
x = _convert_to_vector(x)
return JavaVectorTransformer.transform(self, x)
def idf(self):
"""
Returns the current IDF vector.
"""
return self.call('idf')
class IDF(object):
"""
......
......@@ -41,6 +41,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import IDF
from pyspark.serializers import PickleSerializer
from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
......@@ -620,6 +621,19 @@ class ChiSqTestTests(PySparkTestCase):
self.assertEqual(len(chi), num_cols)
self.assertIsNotNone(chi[1000])
class FeatureTest(PySparkTestCase):
def test_idf_model(self):
data = [
Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
]
model = IDF().fit(self.sc.parallelize(data, 2))
idf = model.idf()
self.assertEqual(len(idf), 11)
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment