Commit 1a623b2e authored by Michelangelo D'Agostino, committed by Xiangrui Meng

SPARK-3770: Make userFeatures accessible from python

https://issues.apache.org/jira/browse/SPARK-3770

We need access to the underlying latent user features from Python. However, the userFeatures RDD of MatrixFactorizationModel isn't accessible from the Python bindings. I've added a method to the underlying Scala class to turn the RDD[(Int, Array[Double])] into an RDD[String], which is then accessed from the Python side in recommendation.py.

Author: Michelangelo D'Agostino <mdagostino@civisanalytics.com>

Closes #2636 from mdagost/mf_user_features and squashes the following commits:

c98f9e2 [Michelangelo D'Agostino] Added unit tests for userFeatures and productFeatures and merged master.
d5eadf8 [Michelangelo D'Agostino] Merge branch 'master' into mf_user_features
2481a2a [Michelangelo D'Agostino] Merged master and resolved conflict.
a6ffb96 [Michelangelo D'Agostino] Eliminated a function from our first approach to this problem that is no longer needed now that we added the fromTuple2RDD function.
2aa1bf8 [Michelangelo D'Agostino] Implemented a function called fromTuple2RDD in PythonMLLibAPI and used it to expose the MF userFeatures and productFeatures in python.
34cb2a2 [Michelangelo D'Agostino] A couple of lint cleanups and a comment.
cdd98e3 [Michelangelo D'Agostino] It's working now.
e1fbe5e [Michelangelo D'Agostino] Added scala function to stringify userFeatures for access in python.
parent 61ca7742
@@ -673,6 +673,11 @@ private[spark] object SerDe extends Serializable {
    rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
  }

  /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
  def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
    rdd.map(x => Array(x._1, x._2))
  }

  /**
   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
   * PySpark.
......
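As a rough illustration of what fromTuple2RDD does, the sketch below mirrors the conversion on a plain Python list instead of an RDD. The helper name from_tuple2 and the sample feature values are hypothetical; only the shape of the data (an integer id paired with an array of doubles) comes from the commit.

```python
def from_tuple2(pairs):
    """Mimic SerDe.fromTuple2RDD on a plain list: each (a, b) becomes [a, b]."""
    return [[first, second] for first, second in pairs]


# Hypothetical stand-in for userFeatures: (user id, latent vector) pairs.
user_features = [(1, [0.1, 0.2, 0.3, 0.4]), (2, [0.5, 0.6, 0.7, 0.8])]

rows = from_tuple2(user_features)

# Same access pattern as the doctest in recommendation.py:
# index 0 is the id, index 1 is the latent vector.
first_user = rows[0]
latents = first_user[1]
print(first_user[0], len(latents))
```

Turning each pair into a two-element array presumably lets the existing Java-to-Python pickling path handle the rows without a dedicated tuple serializer.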
@@ -53,6 +53,23 @@ class MatrixFactorizationModel(object):
    >>> model = ALS.train(ratings, 1)
    >>> model.predictAll(testset).count() == 2
    True

    >>> model = ALS.train(ratings, 4)
    >>> model.userFeatures().count() == 2
    True

    >>> first_user = model.userFeatures().take(1)[0]
    >>> latents = first_user[1]
    >>> len(latents) == 4
    True

    >>> model.productFeatures().count() == 2
    True

    >>> first_product = model.productFeatures().take(1)[0]
    >>> latents = first_product[1]
    >>> len(latents) == 4
    True
    """

    def __init__(self, sc, java_model):
@@ -83,6 +100,20 @@ class MatrixFactorizationModel(object):
        return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
                   AutoBatchedSerializer(PickleSerializer()))

    def userFeatures(self):
        sc = self._context
        juf = self._java_model.userFeatures()
        juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
        return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
                   AutoBatchedSerializer(PickleSerializer()))

    def productFeatures(self):
        sc = self._context
        jpf = self._java_model.productFeatures()
        jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
        return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
                   AutoBatchedSerializer(PickleSerializer()))


class ALS(object):
......
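One use for having both feature sets available in Python is scoring pairs by hand: a matrix factorization model rates a (user, product) pair as the dot product of the two latent vectors. The sketch below shows that computation on plain dictionaries, standing in for features that have already been collected to the driver. The helper predict_rating and all ids and values are made up for illustration and are not part of this commit.

```python
def predict_rating(user_vec, product_vec):
    """Dot product of two latent vectors of equal rank."""
    if len(user_vec) != len(product_vec):
        raise ValueError("latent vectors must have the same rank")
    return sum(u * p for u, p in zip(user_vec, product_vec))


# Hypothetical rank-2 features keyed by id.
user_features = {1: [1.0, 0.5], 2: [0.0, 2.0]}
product_features = {10: [2.0, 4.0], 20: [1.0, 1.0]}

score = predict_rating(user_features[1], product_features[10])
print(score)
```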