From dfb152446d0b987ac15afac77b9c27d77c686d90 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@apache.org> Date: Tue, 14 Jan 2014 22:18:43 -0800 Subject: [PATCH] Fixed SVDPlusPlusSuite in Maven build. --- graphx/src/test/resources/als-test.data | 16 ++++++++++++++++ .../spark/graphx/lib/SVDPlusPlusSuite.scala | 10 +++------- 2 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 graphx/src/test/resources/als-test.data diff --git a/graphx/src/test/resources/als-test.data b/graphx/src/test/resources/als-test.data new file mode 100644 index 0000000000..e476cc23e0 --- /dev/null +++ b/graphx/src/test/resources/als-test.data @@ -0,0 +1,16 @@ +1,1,5.0 +1,2,1.0 +1,3,5.0 +1,4,1.0 +2,1,5.0 +2,2,1.0 +2,3,5.0 +2,4,1.0 +3,1,1.0 +3,2,5.0 +3,3,1.0 +3,4,5.0 +4,1,1.0 +4,2,5.0 +4,3,1.0 +4,4,5.0 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 057d9b3d51..e01df56e94 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -19,11 +19,7 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { @@ -31,16 +27,16 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => val svdppErr = 8.0 - val edges = sc.textFile("mllib/data/als/test.data").map { line => + val edges = sc.textFile(getClass.getResource("/als-test.data").getFile).map { line => val fields = line.split(",") Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations var (graph, u) = SVDPlusPlus.run(edges, conf) graph.cache() - val err = graph.vertices.collect.map{ case (vid, vd) => + val err = graph.vertices.collect().map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 - }.reduce(_ + _) / graph.triplets.collect.size + }.reduce(_ + _) / graph.triplets.collect().size assert(err <= svdppErr) } } -- GitLab