Skip to content
Snippets Groups Projects
Commit 4698a0d6 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Shuffle ratings in a more efficient way at start of ALS

parent d47c16f7
No related branches found
No related tags found
No related merge requests found
...@@ -6,8 +6,10 @@ import scala.util.Sorting ...@@ -6,8 +6,10 @@ import scala.util.Sorting
import spark.{HashPartitioner, Partitioner, SparkContext, RDD} import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
import spark.storage.StorageLevel import spark.storage.StorageLevel
import spark.KryoRegistrator
import spark.SparkContext._ import spark.SparkContext._
import com.esotericsoftware.kryo.Kryo
import org.jblas.{DoubleMatrix, SimpleBlas, Solve} import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
...@@ -98,8 +100,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l ...@@ -98,8 +100,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
val partitioner = new HashPartitioner(numBlocks) val partitioner = new HashPartitioner(numBlocks)
val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, (u, p, r)) } val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, (p, u, r)) } val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
...@@ -179,12 +181,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l ...@@ -179,12 +181,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
* having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
*/ */
private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, (Int, Int, Double))]) private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)])
: (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
{ {
val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
val links = grouped.mapPartitionsWithIndex((blockId, elements) => { val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray val ratings = elements.map{_._2}.toArray
val inLinkBlock = makeInLinkBlock(numBlocks, ratings) val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
Iterator.single((blockId, (inLinkBlock, outLinkBlock))) Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
...@@ -383,6 +385,12 @@ object ALS { ...@@ -383,6 +385,12 @@ object ALS {
train(ratings, rank, iterations, 0.01, -1) train(ratings, rank, iterations, 0.01, -1)
} }
private class ALSRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[Rating])
}
}
def main(args: Array[String]) { def main(args: Array[String]) {
if (args.length != 5 && args.length != 6) { if (args.length != 5 && args.length != 6) {
println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]") println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
...@@ -392,6 +400,8 @@ object ALS { ...@@ -392,6 +400,8 @@ object ALS {
(args(0), args(1), args(2).toInt, args(3).toInt, args(4)) (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
val blocks = if (args.length == 6) args(5).toInt else -1 val blocks = if (args.length == 6) args(5).toInt else -1
System.setProperty("spark.serializer", "spark.KryoSerializer") System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
System.setProperty("spark.kryo.referenceTracking", "false")
System.setProperty("spark.locality.wait", "10000") System.setProperty("spark.locality.wait", "10000")
val sc = new SparkContext(master, "ALS") val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line => val ratings = sc.textFile(ratingsFile).map { line =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment