Skip to content
Snippets Groups Projects
Commit 686dd742 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark

SchemaRDD works with ALS.train in 1.2, so we should continue to support DataFrames for compatibility. coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #5619 from mengxr/SPARK-7036 and squashes the following commits:

dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark
parent 7fe6142c
No related branches found
No related tags found
No related merge requests found
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
from pyspark.mllib.util import JavaLoader, JavaSaveable
from pyspark.sql import DataFrame

__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
...@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ...@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
True
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2, 2)
3.8...
>>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
>>> model = ALS.train(df, 1, nonnegative=True, seed=10)
>>> model.predict(2, 2)
3.8...
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2, 2)
0.4...
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = MatrixFactorizationModel.load(sc, path)
>>> sameModel.predict(2, 2)
0.4...
>>> sameModel.predictAll(testset).collect()
[Rating(...
...@@ -125,13 +131,20 @@ class ALS(object): ...@@ -125,13 +131,20 @@ class ALS(object):
@classmethod
def _prepare(cls, ratings):
    """Normalize user-supplied ratings into an RDD of Rating objects.

    Accepts either an RDD or a DataFrame whose elements/rows are
    Rating instances or (user, product, rating) tuples/lists, so that
    ALS.train and ALS.trainImplicit can share one input contract.

    :param ratings: an RDD or a DataFrame of ratings data.
    :return: an RDD whose elements are Rating instances.
    :raise TypeError: if ``ratings`` is neither an RDD nor a DataFrame,
        or if its first element is not a Rating, tuple, or list.
    """
    if isinstance(ratings, RDD):
        pass
    elif isinstance(ratings, DataFrame):
        # A DataFrame exposes its data as an RDD of Rows; Rows behave
        # like tuples, so the tuple branch below converts them.
        ratings = ratings.rdd
    else:
        raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
                        "but got %s." % type(ratings))
    # Inspect only the first element to decide whether a conversion map
    # is needed -- avoids scanning the whole dataset up front.
    first = ratings.first()
    if isinstance(first, Rating):
        pass
    elif isinstance(first, (tuple, list)):
        ratings = ratings.map(lambda x: Rating(*x))
    else:
        raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
    return ratings
@classmethod @classmethod
...@@ -152,8 +165,11 @@ class ALS(object): ...@@ -152,8 +165,11 @@ class ALS(object):
def _test():
    import doctest
    import pyspark.mllib.recommendation
    from pyspark.sql import SQLContext
    globs = pyspark.mllib.recommendation.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment