Commit bc04fa2e authored by Xiangrui Meng

[SPARK-6642][MLLIB] use 1.2 lambda scaling and remove addImplicit from NormalEquation


This PR changes lambda scaling from the number of users/items to the number of explicit ratings, which is the behavior in 1.2. It also slightly refactors NormalEquation to make it independent of the ALS models. srowen coderxiang
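Background on the scaling: in the ALS-WR formulation (Zhou et al., "Large-Scale Parallel Collaborative Filtering for the Netflix Prize"), each user's subproblem is regularized in proportion to how many ratings that user has:

    \min_{x_u} \sum_{j \in R(u)} (r_{uj} - x_u^\top y_j)^2 + \lambda \, |R(u)| \, \|x_u\|^2

where |R(u)| is the number of explicit ratings by user u (and symmetrically for items). Scaling by |R(u)| is what the new `numExplicits * regParam` below implements; the 1.3 code instead scaled lambda by NormalEquation.n, which in the implicit case counted every factor folded in through YtY.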

Author: Xiangrui Meng <meng@databricks.com>

Closes #5314 from mengxr/SPARK-6642 and squashes the following commits:

dc655a1 [Xiangrui Meng] relax python tests
f410df2 [Xiangrui Meng] use 1.2 scaling and remove addImplicit from NormalEquation

(cherry picked from commit ccafd757)
Signed-off-by: Xiangrui Meng <meng@databricks.com>

Conflicts:
	mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
parent 1c31ebd1
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -321,7 +321,7 @@ object ALS extends Logging {
   /** Trait for least squares solvers applied to the normal equation. */
   private[recommendation] trait LeastSquaresNESolver extends Serializable {
-    /** Solves a least squares problem (possibly with other constraints). */
+    /** Solves a least squares problem with regularization (possibly with other constraints). */
     def solve(ne: NormalEquation, lambda: Double): Array[Float]
   }
@@ -333,20 +333,19 @@ object ALS extends Logging {
     /**
      * Solves a least squares problem with L2 regularization:
      *
-     *   min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+     *   min norm(A x - b)^2^ + lambda * norm(x)^2^
      *
      * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
-     * @param lambda regularization constant, which will be scaled by n
+     * @param lambda regularization constant
      * @return the solution x
      */
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val k = ne.k
       // Add scaled lambda to the diagonals of AtA.
-      val scaledlambda = lambda * ne.n
       var i = 0
       var j = 2
       while (i < ne.triK) {
-        ne.ata(i) += scaledlambda
+        ne.ata(i) += lambda
         i += j
         j += 1
       }
@@ -392,7 +391,7 @@ object ALS extends Logging {
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val rank = ne.k
       initialize(rank)
-      fillAtA(ne.ata, lambda * ne.n)
+      fillAtA(ne.ata, lambda)
       val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace)
       ne.reset()
       x.map(x => x.toFloat)
@@ -422,7 +421,15 @@ object ALS extends Logging {
     }
   }

-  /** Representing a normal equation (ALS' subproblem). */
+  /**
+   * Representing a normal equation to solve the following weighted least squares problem:
+   *
+   *   minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
+   *
+   * Its normal equation is given by
+   *
+   *   \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
+   */
   private[recommendation] class NormalEquation(val k: Int) extends Serializable {

     /** Number of entries in the upper triangular part of a k-by-k matrix. */
@@ -431,8 +438,6 @@ object ALS extends Logging {
     val ata = new Array[Double](triK)

     /** A^T^ * b */
     val atb = new Array[Double](k)

-    /** Number of observations. */
-    var n = 0

     private val da = new Array[Double](k)
     private val upper = "U"
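In plain notation (the ,,i,, and ^T^ markers above are ScalaDoc's wiki syntax for subscripts and superscripts), the new class doc says:

    \min_x \; \sum_i c_i (a_i^\top x - b_i)^2 + \lambda \, x^\top x
    \quad\Longrightarrow\quad
    \Big(\sum_i c_i a_i a_i^\top + \lambda I\Big) x = \sum_i c_i b_i a_i

so the class only needs to accumulate ata = \sum_i c_i a_i a_i^\top and atb = \sum_i c_i b_i a_i; the lambda term is added later by the solver.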
@@ -446,28 +451,13 @@ object ALS extends Logging {
     }

     /** Adds an observation. */
-    def add(a: Array[Float], b: Float): this.type = {
-      require(a.length == k)
-      copyToDouble(a)
-      blas.dspr(upper, k, 1.0, da, 1, ata)
-      blas.daxpy(k, b.toDouble, da, 1, atb, 1)
-      n += 1
-      this
-    }
-
-    /**
-     * Adds an observation with implicit feedback. Note that this does not increment the counter.
-     */
-    def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
+      require(c >= 0.0)
       require(a.length == k)
-      // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
-      // so that it is never negative.
-      val confidence = 1.0 + alpha * math.abs(b)
       copyToDouble(a)
-      blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
-      // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
-      if (b > 0) {
-        blas.daxpy(k, confidence, da, 1, atb, 1)
+      blas.dspr(upper, k, c, da, 1, ata)
+      if (b != 0.0) {
+        blas.daxpy(k, c * b, da, 1, atb, 1)
       }
       this
     }
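For readers without the BLAS signatures at hand: dspr is a symmetric packed rank-1 update and daxpy a scaled vector addition. A plain-loop sketch of what the two calls in the new add do (illustration only; argument names are mine, not the implementation's):

    // ata += c * a * a^T over the packed upper triangle (column-major),
    // then atb += (c * b) * a. Mirrors the dspr + daxpy calls in add().
    def addPlain(ata: Array[Double], atb: Array[Double],
                 a: Array[Double], b: Double, c: Double): Unit = {
      val k = atb.length
      var idx = 0
      var col = 0
      while (col < k) {
        var row = 0
        while (row <= col) {
          ata(idx) += c * a(row) * a(col)
          idx += 1
          row += 1
        }
        col += 1
      }
      var i = 0
      while (i < k) {
        atb(i) += c * b * a(i)
        i += 1
      }
    }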
@@ -477,7 +467,6 @@ object ALS extends Logging {
       require(other.k == k)
       blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
       blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
-      n += other.n
       this
     }
@@ -485,7 +474,6 @@ object ALS extends Logging {
     def reset(): Unit = {
       ju.Arrays.fill(ata, 0.0)
       ju.Arrays.fill(atb, 0.0)
-      n = 0
     }
   }
@@ -1116,6 +1104,7 @@ object ALS extends Logging {
           ls.merge(YtY.get)
         }
         var i = srcPtrs(j)
+        var numExplicits = 0
         while (i < srcPtrs(j + 1)) {
           val encoded = srcEncodedIndices(i)
           val blockId = srcEncoder.blockId(encoded)
@@ -1123,13 +1112,23 @@ object ALS extends Logging {
           val srcFactor = sortedSrcFactors(blockId)(localIndex)
           val rating = ratings(i)
           if (implicitPrefs) {
-            ls.addImplicit(srcFactor, rating, alpha)
+            // Extension to the original paper to handle b < 0. confidence is a function of |b|
+            // instead so that it is never negative. c1 is confidence - 1.0.
+            val c1 = alpha * math.abs(rating)
+            // For rating <= 0, the corresponding preference is 0. So the term below is only added
+            // for rating > 0. Because YtY is already added, we need to adjust the scaling here.
+            if (rating > 0) {
+              numExplicits += 1
+              ls.add(srcFactor, (c1 + 1.0) / c1, c1)
+            }
           } else {
             ls.add(srcFactor, rating)
+            numExplicits += 1
           }
           i += 1
         }
-        dstFactors(j) = solver.solve(ls, regParam)
+        // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
+        dstFactors(j) = solver.solve(ls, numExplicits * regParam)
         j += 1
       }
       dstFactors
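The (c1 + 1.0) / c1 argument deserves a line of justification (mine, not the commit's). In the implicit-feedback formulation that Spark's ALS follows (Hu, Koren, and Volinsky), a positive rating carries confidence c = 1 + alpha|r| and preference p = 1. Since Y^T Y is merged in once for all items, the per-rating contribution must be (c - 1) y y^T to AtA and c y to Atb. Calling add(y, b, w) contributes w y y^T and (w b) y, so with w = c1 = c - 1 the right-hand side works out to

    w \, b = c_1 \cdot \frac{c_1 + 1}{c_1} = c_1 + 1 = 1 + \alpha |r| = c

which is exactly the confidence-weighted coefficient the normal equation needs.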
@@ -1143,7 +1142,7 @@ object ALS extends Logging {
   private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
     factorBlocks.values.aggregate(new NormalEquation(rank))(
       seqOp = (ne, factors) => {
-        factors.foreach(ne.add(_, 0.0f))
+        factors.foreach(ne.add(_, 0.0))
         ne
       },
       combOp = (ne1, ne2) => ne1.merge(ne2))
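Worth noting: with addImplicit gone, computeYtY expresses Y^T Y through the same add API. ne.add(y, 0.0) uses the default weight c = 1.0, so dspr accumulates y y^T into ata, while the b != 0.0 guard in add skips the atb update entirely.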
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -68,39 +68,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }
   }

-  test("normal equation construction with explict feedback") {
+  test("normal equation construction") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 3.0f)
-      .add(Array(4.0f, 5.0f), 6.0f)
+      .add(Array(1.0f, 2.0f), 3.0)
+      .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted
     assert(ne0.k === k)
     assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 2)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5")
     // b = np.matrix("3; 6")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8)

     val ne1 = new NormalEquation(2)
-      .add(Array(7.0f, 8.0f), 9.0f)
+      .add(Array(7.0f, 8.0f), 9.0)
     ne0.merge(ne1)
-    assert(ne0.n === 3)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5; 7 8")
     // b = np.matrix("3; 6; 9")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2, 1]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8)

     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f), 2.0f)
+      ne0.add(Array(1.0f), 2.0)
     }
     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0)
+    }
+    intercept[IllegalArgumentException] {
+      ne0.add(Array(1.0f, 2.0f), 0.0, -1.0)
     }
     intercept[IllegalArgumentException] {
       val ne2 = new NormalEquation(3)
@@ -108,41 +111,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }

     ne0.reset()
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))
   }

-  test("normal equation construction with implicit feedback") {
-    val k = 2
-    val alpha = 0.5
-    val ne0 = new NormalEquation(k)
-      .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
-      .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
-      .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
-    assert(ne0.k === k)
-    assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 0) // addImplicit doesn't increase the count.
-    // NumPy code that computes the expected values:
-    // alpha = 0.5
-    // A = np.matrix("-5 -4; -2 -1; 1 2")
-    // b = np.matrix("-3; 0; 3")
-    // b1 = b > 0
-    // c = 1.0 + alpha * np.abs(b)
-    // C = np.diag(c.A1)
-    // I = np.eye(3)
-    // ata = A.transpose() * (C - I) * A
-    // atb = A.transpose() * C * b1
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
-  }
-
   test("CholeskySolver") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 4.0f)
-      .add(Array(1.0f, 3.0f), 9.0f)
-      .add(Array(1.0f, 4.0f), 16.0f)
+      .add(Array(1.0f, 2.0f), 4.0)
+      .add(Array(1.0f, 3.0f), 9.0)
+      .add(Array(1.0f, 4.0f), 16.0)
     val ne1 = new NormalEquation(k)
       .merge(ne0)
@@ -154,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     // x0 = np.linalg.lstsq(A, b)[0]
     assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))

-    val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+    val x1 = chol.solve(ne1, 1.5).map(_.toDouble)
     // NumPy code that computes the expected solution, where lambda is scaled by n:
-    // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+    // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b)
     assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
   }
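A quick check of the new expected values (my arithmetic, not part of the commit): with A = [1 2; 1 3; 1 4] and b = [4; 9; 16],

    A^\top A = \begin{pmatrix} 3 & 9 \\ 9 & 29 \end{pmatrix}, \quad
    A^\top b = \begin{pmatrix} 29 \\ 99 \end{pmatrix}, \quad
    (A^\top A + 1.5\,I)^{-1} A^\top b
      = \frac{1}{56.25} \begin{pmatrix} -6.5 \\ 184.5 \end{pmatrix}
      \approx \begin{pmatrix} -0.1155556 \\ 3.28 \end{pmatrix}

which matches the assertion, with lambda no longer multiplied by the observation count n = 3.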
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     >>> ratings = sc.parallelize([r1, r2, r3])
     >>> model = ALS.trainImplicit(ratings, 1, seed=10)
     >>> model.predict(2, 2)
-    0.43...
+    0.4...

     >>> testset = sc.parallelize([(1, 2), (1, 1)])
     >>> model = ALS.train(ratings, 2, seed=0)
@@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
     >>> model.predict(2,2)
-    0.43...
+    0.4...

     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
     >>> sameModel = MatrixFactorizationModel.load(sc, path)
     >>> sameModel.predict(2,2)
-    0.43...
+    0.4...

     >>> sameModel.predictAll(testset).collect()
     [Rating(...
     >>> try: