Skip to content
Snippets Groups Projects
Commit dfb15244 authored by Reynold Xin's avatar Reynold Xin
Browse files

Fixed SVDPlusPlusSuite in Maven build.

parent 74b46acd
No related branches found
No related tags found
No related merge requests found
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3,3,1.0
3,4,5.0
4,1,1.0
4,2,5.0
4,3,1.0
4,4,5.0
......@@ -19,11 +19,7 @@ package org.apache.spark.graphx.lib
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
......@@ -31,16 +27,16 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
val svdppErr = 8.0
val edges = sc.textFile("mllib/data/als/test.data").map { line =>
val edges = sc.textFile(getClass.getResource("/als-test.data").getFile).map { line =>
val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = SVDPlusPlus.run(edges, conf)
graph.cache()
val err = graph.vertices.collect.map{ case (vid, vd) =>
val err = graph.vertices.collect().map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / graph.triplets.collect.size
}.reduce(_ + _) / graph.triplets.collect().size
assert(err <= svdppErr)
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment