Skip to content
Snippets Groups Projects
Commit c5c38d19 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Some optimizations to loading phase of ALS

parent b91a218c
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ package spark.mllib.recommendation
import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.util.Random
import scala.util.Sorting
import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
import spark.storage.StorageLevel
......@@ -33,6 +34,12 @@ private[recommendation] case class InLinkBlock(
elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]])
/**
 * A more compact class to represent a rating than Tuple3[Int, Int, Double].
 *
 * @param user    ID of the user the rating belongs to
 * @param product ID of the product that was rated
 * @param rating  the rating value
 */
private[recommendation] case class Rating(user: Int, product: Int, rating: Double)
/**
* Alternating Least Squares matrix factorization.
*
......@@ -126,13 +133,13 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
private def makeOutLinkBlock(numBlocks: Int, ratings: Array[(Int, Int, Double)]): OutLinkBlock = {
val userIds = ratings.map(_._1).distinct.sorted
private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
for ((u, p, r) <- ratings) {
shouldSend(userIdToPos(u))(p % numBlocks) = true
for (r <- ratings) {
shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true
}
OutLinkBlock(userIds, shouldSend)
}
......@@ -141,18 +148,28 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* Make the in-links table for a block of the users (or products) dataset given a list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
private def makeInLinkBlock(numBlocks: Int, ratings: Array[(Int, Int, Double)]): InLinkBlock = {
val userIds = ratings.map(_._1).distinct.sorted
private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
// Split out our ratings by product block
val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating])
for (r <- ratings) {
blockRatings(r.product % numBlocks) += r
}
val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
for (productBlock <- 0 until numBlocks) {
val ratingsInBlock = ratings.filter(t => t._2 % numBlocks == productBlock)
val ratingsByProduct = ratingsInBlock.groupBy(_._2) // (p, Seq[(u, p, r)])
.toArray
.sortBy(_._1)
.map{case (p, rs) => (rs.map(t => userIdToPos(t._1)), rs.map(_._3))}
ratingsForBlock(productBlock) = ratingsByProduct
// Create an array of (product, Seq(Rating)) ratings
val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray
// Sort them by product ID
// Orders the grouped entries by their Int key (the tuple's first element), ascending.
// Explicit comparisons are used instead of `a._1 - b._1`: the subtraction wraps around
// when the two keys are more than Int.MaxValue apart and would then report the wrong sign.
val ordering = new Ordering[(Int, ArrayBuffer[Rating])] {
  def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int =
    if (a._1 < b._1) -1 else if (a._1 > b._1) 1 else 0
}
Sorting.quickSort(groupedRatings)(ordering)
// Translate the user IDs to indices based on userIdToPos
ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) =>
(rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray)
}
}
InLinkBlock(userIds, ratingsForBlock)
}
......@@ -167,7 +184,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
{
val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
val ratings = elements.map(_._2).toArray
val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray
val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
......@@ -373,6 +390,8 @@ object ALS {
}
val (master, ratingsFile, rank, iters, outputDir) =
(args(0), args(1), args(2).toInt, args(3).toInt, args(4))
System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.locality.wait", "10000")
val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment