Skip to content
Snippets Groups Projects
Commit 8e5c7324 authored by Reynold Xin's avatar Reynold Xin
Browse files

Moved SVDPlusPlusConf into SVDPlusPlus object itself.

parent 1dce9ce4
No related branches found
No related tags found
No related merge requests found
......@@ -5,19 +5,21 @@ import org.apache.commons.math.linear._
import org.apache.spark.rdd._
import org.apache.spark.graphx._
/** Configuration parameters for SVDPlusPlus. */
class SVDPlusPlusConf(
var rank: Int,
var maxIters: Int,
var minVal: Double,
var maxVal: Double,
var gamma1: Double,
var gamma2: Double,
var gamma6: Double,
var gamma7: Double) extends Serializable
/** Implementation of SVD++ algorithm. */
object SVDPlusPlus {
/** Configuration parameters for SVDPlusPlus. */
class Conf(
var rank: Int,
var maxIters: Int,
var minVal: Double,
var maxVal: Double,
var gamma1: Double,
var gamma2: Double,
var gamma6: Double,
var gamma7: Double)
extends Serializable
/**
* Implement SVD++ based on "Factorization Meets the Neighborhood:
* a Multifaceted Collaborative Filtering Model",
......@@ -32,7 +34,7 @@ object SVDPlusPlus {
*
* @return a graph with vertex attributes containing the trained model
*/
def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf)
def run(edges: RDD[Edge[Double]], conf: Conf)
: (Graph[(RealVector, RealVector, Double, Double), Double], Double) =
{
// Generate default vertex attribute
......@@ -64,7 +66,7 @@ object SVDPlusPlus {
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
}
def mapTrainF(conf: SVDPlusPlusConf, u: Double)
def mapTrainF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
: Iterator[(VertexID, (RealVector, RealVector, Double))] = {
val (usr, itm) = (et.srcAttr, et.dstAttr)
......@@ -112,7 +114,7 @@ object SVDPlusPlus {
}
// calculate error on training set
def mapTestF(conf: SVDPlusPlusConf, u: Double)
def mapTestF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
: Iterator[(VertexID, Double)] =
{
......
......@@ -18,7 +18,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = SVDPlusPlus.run(edges, conf)
graph.cache()
val err = graph.vertices.collect.map{ case (vid, vd) =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment