Skip to content
Snippets Groups Projects
Commit 04132ea9 authored by Hossein Falaki's avatar Hossein Falaki
Browse files

Added Rating deserializer

parent 11a93fb5
No related branches found
No related tags found
No related merge requests found
...@@ -67,7 +67,14 @@ class MatrixFactorizationModel( ...@@ -67,7 +67,14 @@ class MatrixFactorizationModel(
} }
} }
def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { /**
* Predict the rating of many users for many products.
* This is a Java stub for python predictAll()
*
* @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
* @return JavaRDD of serialized Rating objects.
*/
def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI() val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
......
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext from pyspark import SparkContext
from pyspark.serializers import Serializer
import struct
# Double vector format: # Double vector format:
# #
# [8-byte 1] [8-byte length] [length*8 bytes of data] # [8-byte 1] [8-byte length] [length*8 bytes of data]
...@@ -213,9 +216,21 @@ def _serialize_rating(r): ...@@ -213,9 +216,21 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r intpart[0], intpart[1], doublepart[0] = r
return ba return ba
def _deserialize_rating(ba): class RatingDeserializer(Serializer):
ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C') def loads(self, stream):
return ar.copy() length = struct.unpack("!i", stream.read(4))[0]
ba = stream.read(length)
res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
return int(res[0]), int(res[1]), res[2]
def load_stream(self, stream):
while True:
try:
yield self.loads(stream)
except struct.error:
return
except EOFError:
return
def _serialize_tuple(t): def _serialize_tuple(t):
ba = bytearray(8) ba = bytearray(8)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment