diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 14ae50e6657fd1eb1f4026156c54af0c73258a66..4db45c9af8fae765c9edbc132be3e647caa82290 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -138,7 +138,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} * */ - def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] /** * Transforms each edge attribute in the graph using the map function. The map function is not @@ -348,7 +349,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} */ def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (mapFunc: (VertexId, VD, Option[U]) => VD2) + (mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null) : Graph[VD2, ED] /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 15ea05cbe281d55a9788533d7d114724de9605b9..ccdaa82eb91626092560dc65b1c2e7d3f04e8b7e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -104,8 +104,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse()) } - override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + override def mapVertices[VD2: ClassTag] + (f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // The map preserves type, so we can use incremental replication val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() @@ -232,8 +235,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) - (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + (updateF: (VertexId, VD, Option[U]) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // updateF preserves type, so we can use incremental replication val newVerts = vertices.leftJoin(other)(updateF).cache() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 776bfb8dd6bfa8e3e4e6855bbc7efa67961eeded..82e9e065151794edf3f7b474eec6469ee77e5577 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -41,7 +41,7 @@ object LabelPropagation { * * @return a graph with vertex attributes containing the label of community affiliation */ - def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(e: EdgeTriplet[VertexId, ED]) = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index bba070f256d801448e5e2756386c3dd7204faa2d..590f0474957dd739431284ba31161bdcddbfdaf6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -49,7 +49,7 @@ object ShortestPaths { * @return a graph where each vertex attribute is a map containing the shortest-path distance to * each reachable landmark vertex. */ - def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { val spGraph = graph.mapVertices { (vid, attr) => if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index abc25d0671133986adaf814b651672ee5beaf3bc..6506bac73d71c9aca6467199c24fc7c94431c9c9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("mapVertices changing type with same erased type") { + withSpark { sc => + val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])]( + (1L, Some(1)), + (2L, Some(2)), + (3L, Some(3)) + )) + val edges = sc.parallelize(Array( + Edge(1L, 2L, 0), + Edge(2L, 3L, 0), + Edge(3L, 1L, 0) + )) + val graph0 = Graph(vertices, edges) + // Trigger initial vertex replication + graph0.triplets.foreach(x => {}) + // Change type of replicated vertices, but preserve erased type + val graph1 = graph0.mapVertices { + case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + } + // Access replicated vertices, exposing the erased type + val graph2 = graph1.mapTriplets(t => t.srcAttr.get) + assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + } + } + test("mapEdges") { withSpark { sc => val n = 3