Skip to content
Snippets Groups Projects
Commit e300a5a1 authored by Nick Pentreath's avatar Nick Pentreath
Browse files

[SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items

Add Python API for `ALSModel` methods `recommendForAllUsers`, `recommendForAllItems`

## How was this patch tested?

New doc tests.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #17622 from MLnick/SPARK-20300-pyspark-recall.
parent 86174ea8
No related branches found
No related tags found
No related merge requests found
......@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0)\
.select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2)\
.select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
......@@ -384,6 +392,28 @@ class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("itemFactors")
@since("2.2.0")
def recommendForAllUsers(self, numItems):
"""
Returns top `numItems` items recommended for each user, for all users.
:param numItems: max number of recommendations for each user
:return: a DataFrame of (userCol, recommendations), where recommendations are
stored as an array of (itemCol, rating) Rows.
"""
return self._call_java("recommendForAllUsers", numItems)
@since("2.2.0")
def recommendForAllItems(self, numUsers):
"""
Returns top `numUsers` users recommended for each item, for all items.
:param numUsers: max number of recommendations for each item
:return: a DataFrame of (itemCol, recommendations), where recommendations are
stored as an array of (userCol, rating) Rows.
"""
return self._call_java("recommendForAllItems", numUsers)
if __name__ == "__main__":
import doctest
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment