Commit 46d50f15 authored by Xiangrui Meng

[SPARK-5513][MLLIB] Add nonnegative option to ml's ALS

This PR ports the NNLS solver to the new ALS implementation.

CC: coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #4302 from mengxr/SPARK-5513 and squashes the following commits:

4cbdab0 [Xiangrui Meng] fix serialization
88de634 [Xiangrui Meng] add NNLS to ml's ALS
parent 1646f89d
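
For orientation, a minimal sketch of how the new option is exercised through the ml API; the DataFrame df and the setter values are illustrative, not part of this commit:

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setMaxIter(10)
      .setRegParam(0.1)
      .setNonnegative(true)  // new in this PR: solve each least squares subproblem with NNLS
    val model = als.fit(df)  // df: a DataFrame with user, item, and rating columns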
@@ -25,12 +25,14 @@ import scala.util.Sorting
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.jblas.DoubleMatrix
import org.netlib.util.intW
import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
@@ -80,6 +82,10 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
def getRatingCol: String = get(ratingCol)
val nonnegative = new BooleanParam(
this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false))
val getNonnegative: Boolean = get(nonnegative)
/**
* Validates and transforms the input schema.
* @param schema input schema
@@ -186,6 +192,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
def setPredictionCol(value: String): this.type = set(predictionCol, value)
def setMaxIter(value: Int): this.type = set(maxIter, value)
def setRegParam(value: Double): this.type = set(regParam, value)
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
/** Sets both numUserBlocks and numItemBlocks to the specified value. */
def setNumBlocks(value: Int): this.type = {
@@ -207,7 +214,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
alpha = map(alpha))
alpha = map(alpha), nonnegative = map(nonnegative))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -232,11 +239,16 @@ object ALS extends Logging {
/** Rating class for better code readability. */
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {
/** Solves a least squares problem (possibly with other constraints). */
def solve(ne: NormalEquation, lambda: Double): Array[Float]
}
/** Cholesky solver for least squares problems. */
private[recommendation] class CholeskySolver {
private[recommendation] class CholeskySolver extends LeastSquaresNESolver {
private val upper = "U"
private val info = new intW(0)
/**
* Solves a least squares problem with L2 regularization:
@@ -247,7 +259,7 @@ object ALS extends Logging {
* @param lambda regularization constant, which will be scaled by n
* @return the solution x
*/
def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
val k = ne.k
// Add scaled lambda to the diagonals of AtA.
val scaledlambda = lambda * ne.n
@@ -258,6 +270,7 @@ object ALS extends Logging {
i += j
j += 1
}
val info = new intW(0)
lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info)
val code = info.`val`
assert(code == 0, s"lapack.dppsv returned $code.")
@@ -272,6 +285,63 @@ object ALS extends Logging {
}
}
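
As a reference point for the solver the NNLS path complements, here is a self-contained sketch of the lapack.dppsv call used above, solving a 2x2 positive definite system in the same upper-packed storage that NormalEquation uses; the matrix values are illustrative:

    import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
    import org.netlib.util.intW

    // A = [[4, 1], [1, 3]] in upper-packed column-major order: a11, a12, a22.
    val ata = Array(4.0, 1.0, 3.0)
    val atb = Array(1.0, 2.0)  // right-hand side b; overwritten in place with the solution x
    val info = new intW(0)
    lapack.dppsv("U", 2, 1, ata, atb, 2, info)
    val code = info.`val`
    assert(code == 0, s"lapack.dppsv returned $code.")
    // atb now holds x.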
/** NNLS solver. */
private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
private var rank: Int = -1
private var workspace: NNLS.Workspace = _
private var ata: DoubleMatrix = _
private var initialized: Boolean = false
private def initialize(rank: Int): Unit = {
if (!initialized) {
this.rank = rank
workspace = NNLS.createWorkspace(rank)
ata = new DoubleMatrix(rank, rank)
initialized = true
} else {
require(this.rank == rank)
}
}
/**
* Solves a nonnegative least squares problem with L2 regularization:
*
* min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^
* subject to x >= 0
*/
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
val rank = ne.k
initialize(rank)
fillAtA(ne.ata, lambda * ne.n)
val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace)
ne.reset()
x.map(x => x.toFloat)
}
/**
* Given the packed upper triangular part of AtA (in the order stored by NormalEquation),
* computes the full symmetric square matrix it represents, storing it into ata and adding
* lambda to the diagonal.
*/
private def fillAtA(triAtA: Array[Double], lambda: Double) {
var i = 0
var pos = 0
var a = 0.0
val data = ata.data
while (i < rank) {
var j = 0
while (j <= i) {
a = triAtA(pos)
data(i * rank + j) = a
data(j * rank + i) = a
pos += 1
j += 1
}
data(i * rank + i) += lambda
i += 1
}
}
}
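
To make fillAtA concrete, a standalone sketch of the same packed-to-dense expansion for rank = 2; the input values are illustrative:

    // Packed triangle of AtA in the order (0,0), (1,0), (1,1).
    val triAtA = Array(4.0, 1.0, 3.0)
    val rank = 2
    val lambda = 0.5
    val full = new Array[Double](rank * rank)
    var i = 0
    var pos = 0
    while (i < rank) {
      var j = 0
      while (j <= i) {
        full(i * rank + j) = triAtA(pos)  // fill the lower entry
        full(j * rank + i) = triAtA(pos)  // mirror into the upper entry
        pos += 1
        j += 1
      }
      full(i * rank + i) += lambda        // regularize the diagonal
      i += 1
    }
    // full is now Array(4.5, 1.0, 1.0, 3.5): the dense symmetric matrix plus lambda * I.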
/** Representing a normal equation (ALS' subproblem). */
private[recommendation] class NormalEquation(val k: Int) extends Serializable {
@@ -350,12 +420,14 @@ object ALS extends Logging {
maxIter: Int = 10,
regParam: Double = 1.0,
implicitPrefs: Boolean = false,
alpha: Double = 1.0)(
alpha: Double = 1.0,
nonnegative: Boolean = false)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
val userPart = new HashPartitioner(numUserBlocks)
val itemPart = new HashPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
// materialize blockRatings and user blocks
@@ -374,20 +446,20 @@ object ALS extends Logging {
userFactors.setName(s"userFactors-$iter").persist()
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha)
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist()
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha)
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder)
userLocalIndexEncoder, solver = solver)
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder)
itemLocalIndexEncoder, solver = solver)
}
}
val userIdAndFactors = userInBlocks
@@ -879,6 +951,7 @@ object ALS extends Logging {
* @param srcEncoder encoder for src local indices
* @param implicitPrefs whether to use implicit preference
* @param alpha the alpha constant in the implicit preference formulation
* @param solver solver for least squares problems
*
* @return dst factors
*/
@@ -890,7 +963,8 @@ object ALS extends Logging {
regParam: Double,
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
alpha: Double = 1.0,
solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -909,7 +983,6 @@ object ALS extends Logging {
val dstFactors = new Array[Array[Float]](dstIds.length)
var j = 0
val ls = new NormalEquation(rank)
val solver = new CholeskySolver // TODO: add NNLS solver
while (j < dstIds.length) {
ls.reset()
if (implicitPrefs) {
...
@@ -19,13 +19,11 @@ package org.apache.spark.mllib.optimization
import org.jblas.{DoubleMatrix, SimpleBlas}
import org.apache.spark.annotation.DeveloperApi
/**
* Object used to solve nonnegative least squares problems using a modified
* projected gradient method.
*/
private[mllib] object NNLS {
private[spark] object NNLS {
class Workspace(val n: Int) {
val scratch = new DoubleMatrix(n, 1)
val grad = new DoubleMatrix(n, 1)
...
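
With the visibility widened to private[spark], the ml package can call NNLS directly on the normal equation, as NNLSSolver does above. A hedged sketch of such a call, with illustrative values:

    import org.jblas.DoubleMatrix
    import org.apache.spark.mllib.optimization.NNLS

    val rank = 2
    val ata = new DoubleMatrix(rank, rank, 4.0, 1.0, 1.0, 3.0)  // A^T A, column-major
    val atb = new DoubleMatrix(rank, 1, 1.0, 2.0)               // A^T b
    val ws = NNLS.createWorkspace(rank)
    val x = NNLS.solve(ata, atb, ws)  // returns Array[Double] with every entry >= 0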
@@ -444,4 +444,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
assert(strUserFactors.first()._1.getClass === classOf[String])
}
test("nonnegative constraint") {
val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true)
def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = {
factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _)
}
assert(isNonnegative(userFactors))
assert(isNonnegative(itemFactors))
// TODO: Validate the solution.
}
}