Commit 0cc7b88c authored by Xiangrui Meng

[SPARK-5536] replace old ALS implementation by the new one

The only issue is that `analyzeBlocks` is removed, which was marked as a developer API. I left the other tests in the `ALSSuite` under `spark.mllib` unchanged so that they continue to verify that the new implementation is correct.

CC: srowen coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #4321 from mengxr/SPARK-5536 and squashes the following commits:

5a3cee8 [Xiangrui Meng] update python tests that are too strict
e840acf [Xiangrui Meng] ignore scala style check for ALS.train
e9a721c [Xiangrui Meng] update mima excludes
9ee6a36 [Xiangrui Meng] merge master
9a8aeac [Xiangrui Meng] update tests
d8c3271 [Xiangrui Meng] remove analyzeBlocks
d68eee7 [Xiangrui Meng] add checkpoint to new ALS
22a56f8 [Xiangrui Meng] wrap old ALS
c387dff [Xiangrui Meng] support random seed
3bdf24b [Xiangrui Meng] make storage level configurable in the new ALS
parent b8ebebea
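For orientation before the diff: a minimal, hypothetical sketch of calling the updated `ALS.train` developer API with the knobs this commit adds. The `ratings: RDD[ALS.Rating[Int]]` input is assumed to exist already, and all values are illustrative, not recommendations:

import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.storage.StorageLevel

// Hypothetical call site (not part of the commit). `ratings` is assumed to be
// an existing RDD[ALS.Rating[Int]]. The last three parameters are the ones
// introduced by this change: configurable storage levels and a random seed.
val (userFactors, itemFactors) = ALS.train(
  ratings,
  rank = 10,
  maxIter = 10,
  regParam = 0.1,
  intermediateRDDStorageLevel = StorageLevel.MEMORY_AND_DISK,
  finalRDDStorageLevel = StorageLevel.MEMORY_AND_DISK,
  seed = 42L)
// Returns (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]).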
@@ -22,6 +22,7 @@ import java.{util => ju}
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.util.Sorting
+import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
@@ -37,6 +38,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom
@@ -412,7 +414,7 @@ object ALS extends Logging {
/**
* Implementation of the ALS algorithm.
*/
-def train[ID: ClassTag](
+def train[ID: ClassTag]( // scalastyle:ignore
ratings: RDD[Rating[ID]],
rank: Int = 10,
numUserBlocks: Int = 10,
@@ -421,34 +423,47 @@ object ALS extends Logging {
regParam: Double = 1.0,
implicitPrefs: Boolean = false,
alpha: Double = 1.0,
-nonnegative: Boolean = false)(
+nonnegative: Boolean = false,
+intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+require(intermediateRDDStorageLevel != StorageLevel.NONE,
+"ALS is not designed to run without persisting intermediate RDDs.")
val sc = ratings.sparkContext
val userPart = new HashPartitioner(numUserBlocks)
val itemPart = new HashPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
-val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
-val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+val blockRatings = partitionRatings(ratings, userPart, itemPart)
+.persist(intermediateRDDStorageLevel)
+val (userInBlocks, userOutBlocks) =
+makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
// materialize blockRatings and user blocks
userOutBlocks.count()
val swappedBlockRatings = blockRatings.map {
case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
}
-val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+val (itemInBlocks, itemOutBlocks) =
+makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
// materialize item blocks
itemOutBlocks.count()
-var userFactors = initialize(userInBlocks, rank)
-var itemFactors = initialize(itemInBlocks, rank)
+val seedGen = new XORShiftRandom(seed)
+var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
+var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
-userFactors.setName(s"userFactors-$iter").persist()
+userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
-itemFactors.setName(s"itemFactors-$iter").persist()
+if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
+itemFactors.checkpoint()
+}
+itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
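A note on the checkpointing added above: every third iteration, when a checkpoint directory is set, `itemFactors.checkpoint()` truncates the RDD lineage that otherwise grows with each sweep and can slow recovery or overflow the stack in long implicit-feedback runs. The cadence (`iter % 3 == 0`) and the use of the SparkContext's checkpoint directory come straight from the diff; the sketch below restates the pattern in isolation, with an illustrative directory, RDD, and iteration count:

import org.apache.spark.SparkContext

// Standalone sketch of the periodic-checkpoint pattern used above. The
// directory, RDD contents, cadence, and iteration count are assumptions.
def runWithCheckpoints(sc: SparkContext): Unit = {
  sc.setCheckpointDir("/tmp/als-checkpoints")
  var factors = sc.parallelize(1 to 1000).map(_.toDouble)
  for (iter <- 1 to 12) {
    factors = factors.map(_ * 0.99) // stand-in for one factor-update sweep
    if (iter % 3 == 0) {
      factors.checkpoint() // cut the lineage accumulated so far
    }
    factors.count() // an action must run for the checkpoint to materialize
  }
}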
@@ -467,21 +482,23 @@ object ALS extends Logging {
.join(userFactors)
.values
.setName("userFactors")
-.cache()
-userIdAndFactors.count()
-itemFactors.unpersist()
+.persist(finalRDDStorageLevel)
val itemIdAndFactors = itemInBlocks
.mapValues(_.srcIds)
.join(itemFactors)
.values
.setName("itemFactors")
-.cache()
-itemIdAndFactors.count()
-userInBlocks.unpersist()
-userOutBlocks.unpersist()
-itemInBlocks.unpersist()
-itemOutBlocks.unpersist()
-blockRatings.unpersist()
+.persist(finalRDDStorageLevel)
+if (finalRDDStorageLevel != StorageLevel.NONE) {
+userIdAndFactors.count()
+itemFactors.unpersist()
+itemIdAndFactors.count()
+userInBlocks.unpersist()
+userOutBlocks.unpersist()
+itemInBlocks.unpersist()
+itemOutBlocks.unpersist()
+blockRatings.unpersist()
+}
val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
ids.view.zip(factors)
}
@@ -546,14 +563,15 @@ object ALS extends Logging {
*/
private def initialize[ID](
inBlocks: RDD[(Int, InBlock[ID])],
-rank: Int): RDD[(Int, FactorBlock)] = {
+rank: Int,
+seed: Long): RDD[(Int, FactorBlock)] = {
// Choose a unit vector uniformly at random from the unit sphere, but from the
// "first quadrant" where all elements are nonnegative. This can be done by choosing
// elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
// This appears to create factorizations that have a slightly better reconstruction
// (<1%) compared with picking elements uniformly at random in [0,1].
inBlocks.map { case (srcBlockId, inBlock) =>
-val random = new XORShiftRandom(srcBlockId)
+val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
val factors = Array.fill(inBlock.srcIds.length) {
val factor = Array.fill(rank)(random.nextGaussian().toFloat)
val nrm = blas.snrm2(rank, factor, 1)
@@ -877,7 +895,8 @@ object ALS extends Logging {
prefix: String,
ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
srcPart: Partitioner,
-dstPart: Partitioner)(
+dstPart: Partitioner,
+storageLevel: StorageLevel)(
implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
val inBlocks = ratingBlocks.map {
case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
@@ -914,7 +933,8 @@ object ALS extends Logging {
builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
}
builder.build().compress()
-}.setName(prefix + "InBlocks").cache()
+}.setName(prefix + "InBlocks")
+.persist(storageLevel)
val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
val encoder = new LocalIndexEncoder(dstPart.numPartitions)
val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
@@ -936,7 +956,8 @@ object ALS extends Logging {
activeIds.map { x =>
x.result()
}
-}.setName(prefix + "OutBlocks").cache()
+}.setName(prefix + "OutBlocks")
+.persist(storageLevel)
(inBlocks, outBlocks)
}
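One detail worth calling out from the `initialize` hunk above: each block's RNG is now seeded with `byteswap64(seed ^ srcBlockId)`, so a fixed global seed gives reproducible factors while different block ids still draw well-separated, decorrelated streams. A rough, self-contained illustration of that seeding plus the "first quadrant" unit-vector trick described in the code comment, with `scala.util.Random` standing in for Spark's internal `XORShiftRandom` (seed, rank, and block count are made up):

import scala.util.Random
import scala.util.hashing.byteswap64

// Sketch of per-block seeding and nonnegative unit-vector initialization.
// All concrete values here are illustrative assumptions.
val seed = 17L
val rank = 5
val factorsByBlock = (0 until 4).map { srcBlockId =>
  val random = new Random(byteswap64(seed ^ srcBlockId)) // reproducible, well-mixed
  // abs(Normal(0,1)) entries, then normalize to a unit vector in the
  // nonnegative orthant, as the comment in initialize() describes.
  val factor = Array.fill(rank)(math.abs(random.nextGaussian()).toFloat)
  val nrm = math.sqrt(factor.map(x => x * x).sum).toFloat
  factor.map(_ / nrm)
}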
......
@@ -414,7 +414,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val (training, test) =
genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
-testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
+testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03,
numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
}
}
......
@@ -24,9 +24,7 @@ import scala.util.Random
import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.recommendation.ALS.BlockStats
-import org.apache.spark.storage.StorageLevel
object ALSSuite {
@@ -189,22 +187,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
}
test("analyze one user block and one product block") {
val localRatings = Seq(
Rating(0, 100, 1.0),
Rating(0, 101, 2.0),
Rating(0, 102, 3.0),
Rating(1, 102, 4.0),
Rating(2, 103, 5.0))
val ratings = sc.makeRDD(localRatings, 2)
val stats = ALS.analyzeBlocks(ratings, 1, 1)
assert(stats.size === 2)
assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3))
assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4))
}
// TODO: add tests for analyzing multiple user/product blocks
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
......
@@ -69,7 +69,12 @@ object MimaExcludes {
) ++ Seq(
// SPARK-5540
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.solveLeastSquares")
"org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"),
// SPARK-5536
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock")
) ++ Seq(
// SPARK-3325
ProblemFilters.exclude[MissingMethodProblem](
......
@@ -49,17 +49,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> r3 = (2, 1, 2.0)
>>> ratings = sc.parallelize([r1, r2, r3])
>>> model = ALS.trainImplicit(ratings, 1, seed=10)
->>> model.predict(2,2)
-0.4473...
+>>> model.predict(2, 2)
+0.43...
>>> testset = sc.parallelize([(1, 2), (1, 1)])
->>> model = ALS.train(ratings, 1, seed=10)
+>>> model = ALS.train(ratings, 2, seed=0)
>>> model.predictAll(testset).collect()
-[Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)]
+[Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()
-[(2, array('d', [...])), (1, array('d', [...]))]
+[(1, array('d', [...])), (2, array('d', [...]))]
>>> first_user = model.userFeatures().take(1)[0]
>>> latents = first_user[1]
@@ -67,7 +67,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
True
>>> model.productFeatures().collect()
-[(2, array('d', [...])), (1, array('d', [...]))]
+[(1, array('d', [...])), (2, array('d', [...]))]
>>> first_product = model.productFeatures().take(1)[0]
>>> latents = first_product[1]
@@ -76,11 +76,11 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
-3.735...
+3.8...
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
-0.4473...
+0.43...
"""
def predict(self, user, product):
return self._java_model.predict(int(user), int(product))
......