Skip to content
Snippets Groups Projects
Commit 100718bc authored by Ankur Dave's avatar Ankur Dave
Browse files

Svdpp -> SVDPlusPlus

parent 43e1bdc8
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,7 @@ import org.apache.spark.graph._ ...@@ -5,7 +5,7 @@ import org.apache.spark.graph._
import scala.util.Random import scala.util.Random
import org.apache.commons.math.linear._ import org.apache.commons.math.linear._
class SvdppConf( // Svdpp parameters class SVDPlusPlusConf( // SVDPlusPlus parameters
var rank: Int, var rank: Int,
var maxIters: Int, var maxIters: Int,
var minVal: Double, var minVal: Double,
...@@ -15,7 +15,7 @@ class SvdppConf( // Svdpp parameters ...@@ -15,7 +15,7 @@ class SvdppConf( // Svdpp parameters
var gamma6: Double, var gamma6: Double,
var gamma7: Double) extends Serializable var gamma7: Double) extends Serializable
object Svdpp { object SVDPlusPlus {
/** /**
* Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model", * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model",
* paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. * paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
...@@ -23,12 +23,12 @@ object Svdpp { ...@@ -23,12 +23,12 @@ object Svdpp {
* *
* @param edges edges for constructing the graph * @param edges edges for constructing the graph
* *
* @param conf Svdpp parameters * @param conf SVDPlusPlus parameters
* *
* @return a graph with vertex attributes containing the trained model * @return a graph with vertex attributes containing the trained model
*/ */
def run(edges: RDD[Edge[Double]], conf: SvdppConf): (Graph[(RealVector, RealVector, Double, Double), Double], Double) = { def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf): (Graph[(RealVector, RealVector, Double, Double), Double], Double) = {
// generate default vertex attribute // generate default vertex attribute
def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = { def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = {
...@@ -55,7 +55,7 @@ object Svdpp { ...@@ -55,7 +55,7 @@ object Svdpp {
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
} }
def mapTrainF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) def mapTrainF(conf: SVDPlusPlusConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
: Iterator[(VertexID, (RealVector, RealVector, Double))] = { : Iterator[(VertexID, (RealVector, RealVector, Double))] = {
val (usr, itm) = (et.srcAttr, et.dstAttr) val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr._1, itm._1) val (p, q) = (usr._1, itm._1)
...@@ -85,7 +85,7 @@ object Svdpp { ...@@ -85,7 +85,7 @@ object Svdpp {
} }
// calculate error on training set // calculate error on training set
def mapTestF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]): Iterator[(VertexID, Double)] = { def mapTestF(conf: SVDPlusPlusConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]): Iterator[(VertexID, Double)] = {
val (usr, itm) = (et.srcAttr, et.dstAttr) val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr._1, itm._1) val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
......
...@@ -9,21 +9,21 @@ import org.apache.spark.graph.util.GraphGenerators ...@@ -9,21 +9,21 @@ import org.apache.spark.graph.util.GraphGenerators
import org.apache.spark.rdd._ import org.apache.spark.rdd._
class SvdppSuite extends FunSuite with LocalSparkContext { class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") { test("Test SVD++ with mean square error on training set") {
withSpark { sc => withSpark { sc =>
val SvdppErr = 8.0 val svdppErr = 8.0
val edges = sc.textFile("mllib/data/als/test.data").map { line => val edges = sc.textFile("mllib/data/als/test.data").map { line =>
val fields = line.split(",") val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
} }
val conf = new SvdppConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = Svdpp.run(edges, conf) var (graph, u) = SVDPlusPlus.run(edges, conf)
val err = graph.vertices.collect.map{ case (vid, vd) => val err = graph.vertices.collect.map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0 if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / graph.triplets.collect.size }.reduce(_ + _) / graph.triplets.collect.size
assert(err <= SvdppErr) assert(err <= svdppErr)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment