diff --git a/graphx/src/test/resources/als-test.data b/graphx/src/test/resources/als-test.data new file mode 100644 index 0000000000000000000000000000000000000000..e476cc23e047d78aa4e2bda8fa2f5cea631ef708 --- /dev/null +++ b/graphx/src/test/resources/als-test.data @@ -0,0 +1,16 @@ +1,1,5.0 +1,2,1.0 +1,3,5.0 +1,4,1.0 +2,1,5.0 +2,2,1.0 +2,3,5.0 +2,4,1.0 +3,1,1.0 +3,2,5.0 +3,3,1.0 +3,4,5.0 +4,1,1.0 +4,2,5.0 +4,3,1.0 +4,4,5.0 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 057d9b3d518e010a0d4883d63208a4dc335e5a3f..e01df56e94de937a49e358dc35f1f7f59a1d641a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -19,11 +19,7 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { @@ -31,16 +27,16 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => val svdppErr = 8.0 - val edges = sc.textFile("mllib/data/als/test.data").map { line => + val edges = sc.textFile(getClass.getResource("/als-test.data").getFile).map { line => val fields = line.split(",") Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations var (graph, u) = SVDPlusPlus.run(edges, conf) graph.cache() - val err = graph.vertices.collect.map{ case (vid, vd) => + val err = graph.vertices.collect().map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 - }.reduce(_ + _) / graph.triplets.collect.size + }.reduce(_ + _) / graph.triplets.collect().size assert(err <= svdppErr) } }