Skip to content
Snippets Groups Projects
Commit 087487e9 authored by Patrick Wendell's avatar Patrick Wendell
Browse files

Merge pull request #434 from rxin/graphxmaven

Fixed SVDPlusPlusSuite in Maven build.

This should go into 0.9.0 also.
parents 3a386e23 dfb15244
No related branches found
No related tags found
No related merge requests found
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3,3,1.0
3,4,5.0
4,1,1.0
4,2,5.0
4,3,1.0
4,4,5.0
......@@ -19,11 +19,7 @@ package org.apache.spark.graphx.lib
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
......@@ -31,16 +27,16 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
val svdppErr = 8.0
val edges = sc.textFile("mllib/data/als/test.data").map { line =>
val edges = sc.textFile(getClass.getResource("/als-test.data").getFile).map { line =>
val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = SVDPlusPlus.run(edges, conf)
graph.cache()
val err = graph.vertices.collect.map{ case (vid, vd) =>
val err = graph.vertices.collect().map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / graph.triplets.collect.size
}.reduce(_ + _) / graph.triplets.collect().size
assert(err <= svdppErr)
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment