Commit 4698a0d6 authored by Matei Zaharia

Shuffle ratings in a more efficient way at start of ALS
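The diff below converts each (user, product, rating) triple to the existing Rating class before the initial shuffle instead of after it, and registers Rating with Kryo with reference tracking disabled, so the shuffled records serialize as compact registered objects rather than as generic Scala tuples.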

parent d47c16f7
@@ -6,8 +6,10 @@ import scala.util.Sorting
 import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
 import spark.storage.StorageLevel
+import spark.KryoRegistrator
 import spark.SparkContext._
 
+import com.esotericsoftware.kryo.Kryo
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
@@ -98,8 +100,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
     val partitioner = new HashPartitioner(numBlocks)
-    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, (u, p, r)) }
-    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, (p, u, r)) }
+    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
+    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
     val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
     val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
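The Rating class itself is defined outside these hunks; judging by how it is constructed from (u, p, r) triples here and in the old makeLinkRDDs body below, it is presumably a flat class along these lines (a sketch; the field names are assumptions, not shown in the diff):

  // Assumed shape of Rating, inferred from Rating(u, p, r) above; the
  // actual field names in ALS.scala are not visible in this diff.
  case class Rating(user: Int, product: Int, rating: Double)

Shuffling Rating objects rather than Tuple3 values pays off once Kryo is enabled: a registered class (see the ALSRegistrator hunk below) is written as a small integer tag plus three unboxed primitive fields, while a tuple of Any fields carries per-field class information on every record.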
@@ -179,12 +181,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
   * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
   * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
   */
-  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, (Int, Int, Double))])
+  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)])
     : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
   {
     val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
     val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
-      val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray
+      val ratings = elements.map{_._2}.toArray
       val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
       val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
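With the conversion moved before partitionBy, the per-partition pass only has to drop the block key. A standalone sketch of the shuffle-once pattern makeLinkRDDs relies on (not part of the patch; perBlockCounts is a hypothetical stand-in for the in/out-link construction):

  import spark.{HashPartitioner, RDD}
  import spark.SparkContext._

  // Key each record by its block, hash-partition so block i lands in exactly
  // one partition, then derive per-block results in a single pass; the keyed
  // RDD is shuffled once even though several structures can be built from it.
  def perBlockCounts(numBlocks: Int, ratings: RDD[(Int, Rating)]): RDD[(Int, Int)] = {
    ratings
      .partitionBy(new HashPartitioner(numBlocks))
      .mapPartitionsWithIndex { (blockId, elements) =>
        val rs = elements.map(_._2).toArray  // values are already Ratings, as above
        Iterator.single((blockId, rs.length))
      }
  }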
@@ -383,6 +385,12 @@ object ALS {
     train(ratings, rank, iterations, 0.01, -1)
   }
 
+  private class ALSRegistrator extends KryoRegistrator {
+    override def registerClasses(kryo: Kryo) {
+      kryo.register(classOf[Rating])
+    }
+  }
+
   def main(args: Array[String]) {
     if (args.length != 5 && args.length != 6) {
       println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
@@ -392,6 +400,8 @@ object ALS {
       (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
     val blocks = if (args.length == 6) args(5).toInt else -1
     System.setProperty("spark.serializer", "spark.KryoSerializer")
+    System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+    System.setProperty("spark.kryo.referenceTracking", "false")
     System.setProperty("spark.locality.wait", "10000")
     val sc = new SparkContext(master, "ALS")
     val ratings = sc.textFile(ratingsFile).map { line =>
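The body of the textFile(...).map is truncated in this view. Since train consumes (user, product, rating) triples, the body presumably parses each line into such a triple; a purely hypothetical version (the real separator and field layout are not shown):

  // Hypothetical parse body -- the actual one is cut off above.
  val ratings = sc.textFile(ratingsFile).map { line =>
    val fields = line.split(',')
    (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
  }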