Skip to content
Snippets Groups Projects
Commit 754f5300 authored by Hossein Falaki's avatar Hossein Falaki
Browse files

Added predictAll python function to MatrixFactorizationModel

parent 04132ea9
No related branches found
No related tags found
No related merge requests found
......@@ -21,8 +21,7 @@ from pyspark.mllib._common import \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
_serialize_tuple, _deserialize_rating
from pyspark.serializers import BatchedSerializer
_serialize_tuple, RatingDeserializer
from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
......@@ -36,6 +35,9 @@ class MatrixFactorizationModel(object):
>>> model = ALS.trainImplicit(sc, ratings, 1)
>>> model.predict(2,2) is not None
True
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model.predictAll(testset).count == 2
True
"""
def __init__(self, sc, java_model):
......@@ -50,8 +52,8 @@ class MatrixFactorizationModel(object):
def predictAll(self, usersProducts):
usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
self._context, RatingDeserializer())
class ALS(object):
@classmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment