Commit ca26a212 authored by Xiangrui Meng

[SPARK-4378][MLLIB] make ALS more Java-friendly

Add Java-friendly versions of `run` and `predict`, and use bulk prediction in the Java unit tests. The user guide update will come later (though we may not save many lines of code there). srowen

Author: Xiangrui Meng <meng@databricks.com>

Closes #3240 from mengxr/SPARK-4378 and squashes the following commits:

6581503 [Xiangrui Meng] check number of predictions
6c8bbd1 [Xiangrui Meng] make ALS more Java-friendly
parent ce0333f9
mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

@@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.{abs, sqrt}
-import scala.util.Random
-import scala.util.Sorting
+import scala.util.{Random, Sorting}
 import scala.util.hashing.byteswap32
 
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
 
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.optimization.NNLS
 
 /**
  * Out-link information for a user or product block. This includes the original user/product IDs
@@ -325,6 +325,11 @@ class ALS private (
     new MatrixFactorizationModel(rank, usersOut, productsOut)
   }
 
+  /**
+   * Java-friendly version of [[ALS.run]].
+   */
+  def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)
+
   /**
    * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
    * for each user (or product), in a distributed fashion.
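For context, a minimal sketch of how the new overload would be called from Java, with no `.rdd()` bridging. This example is not part of the commit; the class name, dataset, and ALS parameters are made up for illustration:

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public class JavaFriendlyRunExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaFriendlyRunExample");

    // A tiny explicit-feedback dataset of (user, product, rating) triples.
    JavaRDD<Rating> ratings = sc.parallelize(Arrays.asList(
      new Rating(0, 0, 5.0),
      new Rating(0, 1, 1.0),
      new Rating(1, 0, 4.0),
      new Rating(1, 1, 2.0)));

    // Before this commit a Java caller had to write run(ratings.rdd());
    // the new overload accepts the JavaRDD directly.
    MatrixFactorizationModel model = new ALS()
      .setRank(2)
      .setIterations(10)
      .run(ratings);

    System.out.println("Predicted rating for (0, 0): " + model.predict(0, 0));
    sc.stop();
  }
}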
mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

@@ -17,13 +17,13 @@
 
 package org.apache.spark.mllib.recommendation
 
+import java.lang.{Integer => JavaInteger}
+
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.api.python.SerDe
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.rdd.RDD
 
 /**
  * Model representing the result of matrix factorization.
@@ -65,6 +65,13 @@ class MatrixFactorizationModel private[mllib] (
     }
   }
 
+  /**
+   * Java-friendly version of [[MatrixFactorizationModel.predict]].
+   */
+  def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
+    predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
+  }
+
   /**
    * Recommends products to a user.
    *
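Likewise, a hypothetical Java caller can now score many (user, product) pairs in one job through the new overload. Again a sketch, not part of the commit; the method name and pairs are invented:

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public class BulkPredictExample {
  // Scores every listed (user, product) pair in a single distributed pass
  // instead of calling model.predict(user, product) once per pair.
  static void printPredictions(JavaSparkContext sc, MatrixFactorizationModel model) {
    JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(Arrays.asList(
      new Tuple2<Integer, Integer>(0, 0),
      new Tuple2<Integer, Integer>(0, 1),
      new Tuple2<Integer, Integer>(1, 0),
      new Tuple2<Integer, Integer>(1, 1)));

    // New in this commit: predict on a JavaPairRDD returns a JavaRDD<Rating>.
    List<Rating> predictions = model.predict(usersProducts).collect();
    for (Rating r : predictions) {
      System.out.println(r.user() + "\t" + r.product() + "\t" + r.rating());
    }
  }
}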
mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java

@@ -23,13 +23,14 @@ import java.util.List;
 import scala.Tuple2;
 import scala.Tuple3;
 
+import com.google.common.collect.Lists;
 import org.jblas.DoubleMatrix;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -47,61 +48,48 @@ public class JavaALSSuite implements Serializable {
     sc = null;
   }
 
-  static void validatePrediction(
+  void validatePrediction(
       MatrixFactorizationModel model,
       int users,
       int products,
-      int features,
       DoubleMatrix trueRatings,
       double matchThreshold,
       boolean implicitPrefs,
       DoubleMatrix truePrefs) {
-    DoubleMatrix predictedU = new DoubleMatrix(users, features);
-    List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
-    for (int i = 0; i < features; ++i) {
-      for (Tuple2<Object, double[]> userFeature : userFeatures) {
-        predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
-      }
-    }
-    DoubleMatrix predictedP = new DoubleMatrix(products, features);
-    List<Tuple2<Object, double[]>> productFeatures =
-      model.productFeatures().toJavaRDD().collect();
-    for (int i = 0; i < features; ++i) {
-      for (Tuple2<Object, double[]> productFeature : productFeatures) {
-        predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
-      }
-    }
-
-    DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
-
+    List<Tuple2<Integer, Integer>> localUsersProducts =
+      Lists.newArrayListWithCapacity(users * products);
+    for (int u=0; u < users; ++u) {
+      for (int p=0; p < products; ++p) {
+        localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
+      }
+    }
+    JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+    List<Rating> predictedRatings = model.predict(usersProducts).collect();
+    Assert.assertEquals(users * products, predictedRatings.size());
     if (!implicitPrefs) {
-      for (int u = 0; u < users; ++u) {
-        for (int p = 0; p < products; ++p) {
-          double prediction = predictedRatings.get(u, p);
-          double correct = trueRatings.get(u, p);
-          Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
-            prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
-        }
+      for (Rating r: predictedRatings) {
+        double prediction = r.rating();
+        double correct = trueRatings.get(r.user(), r.product());
+        Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+          prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
       }
     } else {
       // For implicit prefs we use the confidence-weighted RMSE to test
       // (ref Mahout's implicit ALS tests)
       double sqErr = 0.0;
       double denom = 0.0;
-      for (int u = 0; u < users; ++u) {
-        for (int p = 0; p < products; ++p) {
-          double prediction = predictedRatings.get(u, p);
-          double truePref = truePrefs.get(u, p);
-          double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
-          double err = confidence * (truePref - prediction) * (truePref - prediction);
-          sqErr += err;
-          denom += confidence;
-        }
+      for (Rating r: predictedRatings) {
+        double prediction = r.rating();
+        double truePref = truePrefs.get(r.user(), r.product());
+        double confidence = 1.0 +
+          /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product()));
+        double err = confidence * (truePref - prediction) * (truePref - prediction);
+        sqErr += err;
+        denom += confidence;
       }
       double rmse = Math.sqrt(sqErr / denom);
       Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
        rmse, matchThreshold), rmse < matchThreshold);
     }
   }
@@ -116,7 +104,7 @@ public class JavaALSSuite implements Serializable {
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
-    validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
   }
 
   @Test
@@ -132,8 +120,8 @@ public class JavaALSSuite implements Serializable {
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
-      .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+      .run(data);
+    validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
   }
 
   @Test
@@ -147,7 +135,7 @@ public class JavaALSSuite implements Serializable {
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test
@@ -165,7 +153,7 @@ public class JavaALSSuite implements Serializable {
       .setIterations(iterations)
       .setImplicitPrefs(true)
       .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test
@@ -183,7 +171,7 @@ public class JavaALSSuite implements Serializable {
       .setImplicitPrefs(true)
      .setSeed(8675309L)
       .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test