Skip to content
Snippets Groups Projects
Commit 8d0c2f73 authored by Hossein Falaki's avatar Hossein Falaki
Browse files

Added python binding for bulk recommendation

parent dfe57fa8
No related branches found
No related tags found
No related merge requests found
......@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
return new Rating(user, product, rating)
}
/**
 * Decodes a pair of 32-bit integers from a raw byte array.
 *
 * The bytes are read in the platform's native byte order: the first four
 * bytes become the first element of the pair, the next four the second.
 * Any bytes beyond the first eight are ignored.
 *
 * @param tupleBytes byte array holding two consecutive native-order Ints
 * @return the decoded pair, in the order the values appear in the array
 * @throws java.nio.BufferUnderflowException if fewer than eight bytes are supplied
 */
private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
  val buffer = ByteBuffer.wrap(tupleBytes).order(ByteOrder.nativeOrder())
  // Tuple elements are evaluated left to right, so the reads stay sequential.
  (buffer.getInt(), buffer.getInt())
}
/**
 * Encodes a rating as three consecutive 8-byte doubles in native byte order.
 *
 * Layout of the returned 24-byte array: user (widened to Double),
 * product (widened to Double), then the rating value itself.
 *
 * @param rate the rating to encode
 * @return a freshly allocated 24-byte array containing the encoded rating
 */
private[spark] def serializeRating(rate: Rating): Array[Byte] = {
  // Three doubles at eight bytes each.
  val out = ByteBuffer.allocate(3 * 8).order(ByteOrder.nativeOrder())
  out.putDouble(rate.user.toDouble)
  out.putDouble(rate.product.toDouble)
  out.putDouble(rate.rating)
  // A heap buffer from allocate() exposes its exact-size backing array.
  out.array()
}
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
......
......@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.PythonMLLibAPI
import org.jblas._
import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.api.java.JavaRDD
/**
* Model representing the result of matrix factorization.
......@@ -65,6 +67,12 @@ class MatrixFactorizationModel(
}
}
/**
 * Bulk prediction entry point operating on byte-encoded records.
 *
 * Each input element is decoded into an (Int, Int) pair with
 * `PythonMLLibAPI.unpackTuple`, the pairs are passed to `predict`, and every
 * resulting rating is encoded with `PythonMLLibAPI.serializeRating`.
 *
 * @param usersProductsJRDD byte-encoded pairs, one pair per element
 * @return byte-encoded ratings, one per prediction
 */
def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
  val codec = new PythonMLLibAPI()
  val decodedPairs = usersProductsJRDD.rdd.map(bytes => codec.unpackTuple(bytes))
  val ratings = predict(decodedPairs)
  ratings.map(rating => codec.serializeRating(rating))
}
// TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa.
}
......@@ -213,6 +213,16 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba
def _deserialize_rating(ba):
ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
return ar.copy()
def _serialize_tuple(t):
ba = bytearray(8)
intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
intpart[0], intpart[1] = t
return ba
def _test():
import doctest
globs = globals().copy()
......
......@@ -20,7 +20,10 @@ from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
_serialize_tuple, _deserialize_rating
from pyspark.serializers import BatchedSerializer
from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating
......@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object):
def predict(self, user, product):
return self._java_model.predict(user, product)
def predictAll(self, usersProducts):
    """Predict ratings for every entry of an RDD in one call.

    The input RDD is converted with ``_get_unmangled_rdd`` using
    ``_serialize_tuple``, handed to the Java model's ``predictJavaRDD``, and
    the Java result is wrapped in a Python ``RDD`` on ``self._context``.
    """
    serialized_pairs = _get_unmangled_rdd(usersProducts, _serialize_tuple)
    java_predictions = self._java_model.predictJavaRDD(serialized_pairs._jrdd)
    context = self._context
    # NOTE(review): _deserialize_rating is a plain function here; confirm that
    # BatchedSerializer accepts it as its first argument.
    rating_serializer = BatchedSerializer(_deserialize_rating, context._batchSize)
    return RDD(java_predictions, context, rating_serializer)
class ALS(object):
@classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment