Commit ccafd757 authored by Xiangrui Meng

[SPARK-6642][MLLIB] use 1.2 lambda scaling and remove addImplicit from NormalEquation

This PR changes lambda scaling from the number of users/items to the number of explicit ratings. The latter is the behavior in 1.2. Slight refactor of NormalEquation to make it independent of ALS models. srowen coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #5314 from mengxr/SPARK-6642 and squashes the following commits:

dc655a1 [Xiangrui Meng] relax python tests
f410df2 [Xiangrui Meng] use 1.2 scaling and remove addImplicit from NormalEquation
parent f084c5de
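
For context on the behavior change (a sketch, not part of the patch): before this PR, the solvers scaled lambda by ne.n, the number of observations accumulated in the NormalEquation; after it, the caller passes an already-scaled constant, so for a user u with n_u explicit ratings the subproblem solved is effectively

    min ||A_u x - b_u||^2 + (n_u * lambda) * ||x||^2

which is the weighted-lambda regularization scheme of the ALS-WR paper and matches the Spark 1.2 behavior.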
@@ -320,7 +320,7 @@ object ALS extends Logging {
   /** Trait for least squares solvers applied to the normal equation. */
   private[recommendation] trait LeastSquaresNESolver extends Serializable {
-    /** Solves a least squares problem (possibly with other constraints). */
+    /** Solves a least squares problem with regularization (possibly with other constraints). */
     def solve(ne: NormalEquation, lambda: Double): Array[Float]
   }
@@ -332,20 +332,19 @@ object ALS extends Logging {
     /**
      * Solves a least squares problem with L2 regularization:
      *
-     *   min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+     *   min norm(A x - b)^2^ + lambda * norm(x)^2^
      *
      * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
-     * @param lambda regularization constant, which will be scaled by n
+     * @param lambda regularization constant
      * @return the solution x
      */
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val k = ne.k
       // Add scaled lambda to the diagonals of AtA.
-      val scaledlambda = lambda * ne.n
       var i = 0
       var j = 2
       while (i < ne.triK) {
-        ne.ata(i) += scaledlambda
+        ne.ata(i) += lambda
         i += j
         j += 1
       }
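
The `i += j; j += 1` walk above touches exactly the diagonal entries of AtA in LAPACK's upper-triangular packed ("U") layout. A standalone sketch (illustrative code, not from the patch) demonstrating the index sequence:

    // Column c of a k-by-k upper-packed matrix holds c + 1 entries, so the
    // diagonal entry (c, c) lands at packed index c * (c + 3) / 2 = 0, 2, 5, 9, ...
    val k = 4
    val triK = k * (k + 1) / 2
    val diag = scala.collection.mutable.ArrayBuffer.empty[Int]
    var i = 0
    var j = 2
    while (i < triK) {
      diag += i
      i += j
      j += 1
    }
    assert(diag.toList == List(0, 2, 5, 9)) // c * (c + 3) / 2 for c = 0..3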
@@ -391,7 +390,7 @@ object ALS extends Logging {
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val rank = ne.k
       initialize(rank)
-      fillAtA(ne.ata, lambda * ne.n)
+      fillAtA(ne.ata, lambda)
       val x = NNLS.solve(ata, ne.atb, workspace)
       ne.reset()
       x.map(x => x.toFloat)
@@ -420,7 +419,15 @@ object ALS extends Logging {
     }
   }

-  /** Representing a normal equation (ALS' subproblem). */
+  /**
+   * Representing a normal equation to solve the following weighted least squares problem:
+   *
+   * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
+   *
+   * Its normal equation is given by
+   *
+   * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
+   */
   private[recommendation] class NormalEquation(val k: Int) extends Serializable {

     /** Number of entries in the upper triangular part of a k-by-k matrix. */
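
Setting the gradient of the objective above to zero gives, equivalently, the linear system that the solvers factorize (a plain-text rendering of the scaladoc markup, where ,,i,, is a subscript and ^T^ a superscript):

    (\sum_i c_i a_i a_i^T + lambda * I) x = \sum_i c_i b_i a_i

so `ata` accumulates the packed upper triangle of \sum_i c_i a_i a_i^T, `atb` accumulates \sum_i c_i b_i a_i, and the solver adds lambda to the diagonal just before solving.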
@@ -429,8 +436,6 @@
     val ata = new Array[Double](triK)
     /** A^T^ * b */
     val atb = new Array[Double](k)

-    /** Number of observations. */
-    var n = 0
     private val da = new Array[Double](k)
     private val upper = "U"
@@ -444,28 +449,13 @@
     }

     /** Adds an observation. */
-    def add(a: Array[Float], b: Float): this.type = {
-      require(a.length == k)
-      copyToDouble(a)
-      blas.dspr(upper, k, 1.0, da, 1, ata)
-      blas.daxpy(k, b.toDouble, da, 1, atb, 1)
-      n += 1
-      this
-    }
-
-    /**
-     * Adds an observation with implicit feedback. Note that this does not increment the counter.
-     */
-    def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
+      require(c >= 0.0)
       require(a.length == k)
-      // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
-      // so that it is never negative.
-      val confidence = 1.0 + alpha * math.abs(b)
       copyToDouble(a)
-      blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
-      // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
-      if (b > 0) {
-        blas.daxpy(k, confidence, da, 1, atb, 1)
+      blas.dspr(upper, k, c, da, 1, ata)
+      if (b != 0.0) {
+        blas.daxpy(k, c * b, da, 1, atb, 1)
       }
       this
     }
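
A minimal usage sketch of the new weighted `add` (mirroring the updated test further down; with the default weight c = 1.0 it reproduces the old explicit-feedback behavior):

    val ne = new NormalEquation(2)
    ne.add(Array(1.0f, 2.0f), 3.0)      // c defaults to 1.0
    ne.add(Array(4.0f, 5.0f), 6.0, 2.0) // weight 2.0: dspr/daxpy both scale by c
    // ne.ata now packs sum_i c_i * a_i * a_i^T; ne.atb holds sum_i c_i * b_i * a_i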
@@ -475,7 +465,6 @@
       require(other.k == k)
       blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
       blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
-      n += other.n
       this
     }
@@ -483,7 +472,6 @@
     def reset(): Unit = {
       ju.Arrays.fill(ata, 0.0)
       ju.Arrays.fill(atb, 0.0)
-      n = 0
     }
   }
@@ -1114,6 +1102,7 @@
           ls.merge(YtY.get)
         }
         var i = srcPtrs(j)
+        var numExplicits = 0
         while (i < srcPtrs(j + 1)) {
           val encoded = srcEncodedIndices(i)
           val blockId = srcEncoder.blockId(encoded)
@@ -1121,13 +1110,23 @@
           val srcFactor = sortedSrcFactors(blockId)(localIndex)
           val rating = ratings(i)
           if (implicitPrefs) {
-            ls.addImplicit(srcFactor, rating, alpha)
+            // Extension to the original paper to handle b < 0. confidence is a function of |b|
+            // instead so that it is never negative. c1 is confidence - 1.0.
+            val c1 = alpha * math.abs(rating)
+            // For rating <= 0, the corresponding preference is 0. So the term below is only added
+            // for rating > 0. Because YtY is already added, we need to adjust the scaling here.
+            if (rating > 0) {
+              numExplicits += 1
+              ls.add(srcFactor, (c1 + 1.0) / c1, c1)
+            }
           } else {
             ls.add(srcFactor, rating)
+            numExplicits += 1
           }
           i += 1
         }
-        dstFactors(j) = solver.solve(ls, regParam)
+        // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
+        dstFactors(j) = solver.solve(ls, numExplicits * regParam)
         j += 1
       }
       dstFactors
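
The `(c1 + 1.0) / c1` argument is the trick that lets the single weighted `add` replace `addImplicit` (a worked check, using the definitions above): the call `ls.add(srcFactor, (c1 + 1.0) / c1, c1)` contributes

    c * a a^T = c1 * a a^T                               to AtA, and
    c * b * a = c1 * ((c1 + 1) / c1) * a = (1 + c1) * a  to Atb.

Since YtY was already merged in and contributes 1 * a a^T per factor, AtA receives (1 + c1) * a a^T in total, i.e. confidence * a a^T with confidence = 1 + alpha * |rating|, exactly the implicit-feedback normal equation.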
@@ -1141,7 +1140,7 @@
   private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
     factorBlocks.values.aggregate(new NormalEquation(rank))(
       seqOp = (ne, factors) => {
-        factors.foreach(ne.add(_, 0.0f))
+        factors.foreach(ne.add(_, 0.0))
         ne
       },
       combOp = (ne1, ne2) => ne1.merge(ne2))
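
The change from 0.0f to 0.0 tracks the new signature (b is now a Double). With b = 0 and the default weight c = 1.0, `add` reduces to accumulating y y^T and skips the Atb update entirely (the `b != 0.0` guard), so this aggregation computes exactly Y^T Y.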
@@ -68,39 +68,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }
   }

-  test("normal equation construction with explict feedback") {
+  test("normal equation construction") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 3.0f)
-      .add(Array(4.0f, 5.0f), 6.0f)
+      .add(Array(1.0f, 2.0f), 3.0)
+      .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted
     assert(ne0.k === k)
     assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 2)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5")
     // b = np.matrix("3; 6")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8)
     val ne1 = new NormalEquation(2)
-      .add(Array(7.0f, 8.0f), 9.0f)
+      .add(Array(7.0f, 8.0f), 9.0)
     ne0.merge(ne1)
-    assert(ne0.n === 3)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5; 7 8")
     // b = np.matrix("3; 6; 9")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2, 1]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8)
     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f), 2.0f)
+      ne0.add(Array(1.0f), 2.0)
     }
     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0)
     }
+    intercept[IllegalArgumentException] {
+      ne0.add(Array(1.0f, 2.0f), 0.0, -1.0)
+    }
     intercept[IllegalArgumentException] {
       val ne2 = new NormalEquation(3)
@@ -108,41 +111,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }
     ne0.reset()
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))
   }

-  test("normal equation construction with implicit feedback") {
-    val k = 2
-    val alpha = 0.5
-    val ne0 = new NormalEquation(k)
-      .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
-      .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
-      .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
-    assert(ne0.k === k)
-    assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 0) // addImplicit doesn't increase the count.
-    // NumPy code that computes the expected values:
-    // alpha = 0.5
-    // A = np.matrix("-5 -4; -2 -1; 1 2")
-    // b = np.matrix("-3; 0; 3")
-    // b1 = b > 0
-    // c = 1.0 + alpha * np.abs(b)
-    // C = np.diag(c.A1)
-    // I = np.eye(3)
-    // ata = A.transpose() * (C - I) * A
-    // atb = A.transpose() * C * b1
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
-  }

   test("CholeskySolver") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 4.0f)
-      .add(Array(1.0f, 3.0f), 9.0f)
-      .add(Array(1.0f, 4.0f), 16.0f)
+      .add(Array(1.0f, 2.0f), 4.0)
+      .add(Array(1.0f, 3.0f), 9.0)
+      .add(Array(1.0f, 4.0f), 16.0)
     val ne1 = new NormalEquation(k)
       .merge(ne0)
@@ -154,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     // x0 = np.linalg.lstsq(A, b)[0]
     assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))
-    val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+    val x1 = chol.solve(ne1, 1.5).map(_.toDouble)
     // NumPy code that computes the expected solution, where lambda is scaled by n:
-    // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+    // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b)
     assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
   }
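
A hand check of the new expected solution: with A = [[1, 2], [1, 3], [1, 4]] and b = [4, 9, 16],

    A^T A = [[3, 9], [9, 29]],   A^T b = [29, 99].

Adding lambda = 1.5 to the diagonal gives [[4.5, 9], [9, 30.5]] x = [29, 99], with determinant 4.5 * 30.5 - 81 = 56.25, so

    x = ((30.5 * 29 - 9 * 99) / 56.25, (-9 * 29 + 4.5 * 99) / 56.25)
      = (-6.5 / 56.25, 184.5 / 56.25)
      ~ (-0.1155556, 3.28),

matching the assertion.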
@@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     >>> ratings = sc.parallelize([r1, r2, r3])
     >>> model = ALS.trainImplicit(ratings, 1, seed=10)
     >>> model.predict(2, 2)
-    0.43...
+    0.4...
     >>> testset = sc.parallelize([(1, 2), (1, 1)])
     >>> model = ALS.train(ratings, 2, seed=0)
@@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
     >>> model.predict(2,2)
-    0.43...
+    0.4...
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
     >>> sameModel = MatrixFactorizationModel.load(sc, path)
     >>> sameModel.predict(2,2)
-    0.43...
+    0.4...
     >>> sameModel.predictAll(testset).collect()
     [Rating(...
     >>> try: