diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/SVDPlusPlus.scala similarity index 89% rename from graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala rename to graph/src/main/scala/org/apache/spark/graph/algorithms/SVDPlusPlus.scala index 85fa23d30946963da2a77c6e29c14e5aa32d99b0..083aa305388e864b63523a7571119e5ab18bdc58 100644 --- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala +++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/SVDPlusPlus.scala @@ -5,7 +5,7 @@ import org.apache.spark.graph._ import scala.util.Random import org.apache.commons.math.linear._ -class SvdppConf( // Svdpp parameters +class SVDPlusPlusConf( // SVDPlusPlus parameters var rank: Int, var maxIters: Int, var minVal: Double, @@ -15,7 +15,7 @@ class SvdppConf( // Svdpp parameters var gamma6: Double, var gamma7: Double) extends Serializable -object Svdpp { +object SVDPlusPlus { /** * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model", * paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. @@ -23,12 +23,12 @@ object Svdpp { * * @param edges edges for constructing the graph * - * @param conf Svdpp parameters + * @param conf SVDPlusPlus parameters * * @return a graph with vertex attributes containing the trained model */ - def run(edges: RDD[Edge[Double]], conf: SvdppConf): (Graph[(RealVector, RealVector, Double, Double), Double], Double) = { + def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf): (Graph[(RealVector, RealVector, Double, Double), Double], Double) = { // generate default vertex attribute def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = { @@ -55,7 +55,7 @@ object Svdpp { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) + def mapTrainF(conf: SVDPlusPlusConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) : Iterator[(VertexID, (RealVector, RealVector, Double))] = { val (usr, itm) = (et.srcAttr, et.dstAttr) val (p, q) = (usr._1, itm._1) @@ -85,7 +85,7 @@ object Svdpp { } // calculate error on training set - def mapTestF(conf: SvdppConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]): Iterator[(VertexID, Double)] = { + def mapTestF(conf: SVDPlusPlusConf, u: Double)(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]): Iterator[(VertexID, Double)] = { val (usr, itm) = (et.srcAttr, et.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) diff --git a/graph/src/test/scala/org/apache/spark/graph/algorithms/SvdppSuite.scala b/graph/src/test/scala/org/apache/spark/graph/algorithms/SVDPlusPlusSuite.scala similarity index 72% rename from graph/src/test/scala/org/apache/spark/graph/algorithms/SvdppSuite.scala rename to graph/src/test/scala/org/apache/spark/graph/algorithms/SVDPlusPlusSuite.scala index 411dd3d336c2a478f752d93c6e8a92c361fb14ac..a0a6eb33e36fc942a263bfbe976197fae54228ba 100644 --- a/graph/src/test/scala/org/apache/spark/graph/algorithms/SvdppSuite.scala +++ b/graph/src/test/scala/org/apache/spark/graph/algorithms/SVDPlusPlusSuite.scala @@ -9,21 +9,21 @@ import org.apache.spark.graph.util.GraphGenerators import org.apache.spark.rdd._ -class SvdppSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => - val SvdppErr = 8.0 + val svdppErr = 8.0 val edges = sc.textFile("mllib/data/als/test.data").map { line => val fields = line.split(",") Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } - val conf = new SvdppConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = Svdpp.run(edges, conf) + val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations + var (graph, u) = SVDPlusPlus.run(edges, conf) val err = graph.vertices.collect.map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 }.reduce(_ + _) / graph.triplets.collect.size - assert(err <= SvdppErr) + assert(err <= svdppErr) } }