Skip to content
Snippets Groups Projects
Commit b54c6ab3 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-4396] allow lookup by index in Python's Rating

In PySpark, ALS can take an RDD of (user, product, rating) tuples as input. However, model.predict outputs an RDD of Rating. So on the input side, users can use r[0], r[1], r[2], while on the output side, users have to use r.user, r.product, r.rating. We should allow lookup by index in Rating by making Rating a namedtuple.

davies

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3261)
<!-- Reviewable:end -->

Author: Xiangrui Meng <meng@databricks.com>

Closes #3261 from mengxr/SPARK-4396 and squashes the following commits:

543aef0 [Xiangrui Meng] use named tuple to implement ALS
0b61bae [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4396
d3bd7d4 [Xiangrui Meng] allow lookup by index in Python's Rating
parent 8fbf72b7
No related branches found
No related tags found
No related merge requests found
......@@ -15,24 +15,28 @@
# limitations under the License.
#
from collections import namedtuple
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
__all__ = ['MatrixFactorizationModel', 'ALS']
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
class Rating(object):
def __init__(self, user, product, rating):
self.user = int(user)
self.product = int(product)
self.rating = float(rating)
class Rating(namedtuple("Rating", ["user", "product", "rating"])):
"""
Represents a (user, product, rating) tuple.
def __reduce__(self):
return Rating, (self.user, self.product, self.rating)
>>> r = Rating(1, 2, 5.0)
>>> (r.user, r.product, r.rating)
(1, 2, 5.0)
>>> (r[0], r[1], r[2])
(1, 2, 5.0)
"""
def __repr__(self):
return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating)
def __reduce__(self):
return Rating, (int(self.user), int(self.product), float(self.rating))
class MatrixFactorizationModel(JavaModelWrapper):
......@@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model = ALS.train(ratings, 1, seed=10)
>>> model.predictAll(testset).collect()
[Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)]
[Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment