Skip to content
Snippets Groups Projects
Commit f133dece authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-5534] [graphx] Graph getStorageLevel fix

This fixes getStorageLevel for EdgeRDDImpl and VertexRDDImpl (and therefore for Graph).

See code example on JIRA which failed before but works with this patch: [https://issues.apache.org/jira/browse/SPARK-5534]
(The added unit tests also failed before but work with this fix.)

Note: I used partitionsRDD, assuming that getStorageLevel will only be called on the driver.

CC: mengxr  (related to LDA PR), rxin  ankurdave   Thanks in advance!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4317 from jkbradley/graphx-storagelevel and squashes the following commits:

1c21e49 [Joseph K. Bradley] made graph getStorageLevel test more robust
18d64ca [Joseph K. Bradley] Added tests for getStorageLevel in VertexRDDSuite, EdgeRDDSuite, GraphSuite
17b488b [Joseph K. Bradley] overrode getStorageLevel in Vertex/EdgeRDDImpl to use partitionsRDD
parent 8aa3cfff
No related branches found
No related tags found
No related merge requests found
......@@ -70,6 +70,8 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
this
}
override def getStorageLevel = partitionsRDD.getStorageLevel
override def checkpoint() = {
partitionsRDD.checkpoint()
}
......
......@@ -71,6 +71,8 @@ class VertexRDDImpl[VD] private[graphx] (
this
}
override def getStorageLevel = partitionsRDD.getStorageLevel
override def checkpoint() = {
partitionsRDD.checkpoint()
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.graphx
import org.scalatest.FunSuite
import org.apache.spark.storage.StorageLevel
class EdgeRDDSuite extends FunSuite with LocalSparkContext {
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value after caching
withSpark { sc =>
val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
assert(edges.getStorageLevel == StorageLevel.NONE)
edges.cache()
assert(edges.getStorageLevel == StorageLevel.MEMORY_ONLY)
}
}
}
......@@ -25,6 +25,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
class GraphSuite extends FunSuite with LocalSparkContext {
......@@ -390,6 +391,20 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value
withSpark { sc =>
val verts = sc.parallelize(List((1: VertexId, "a"), (2: VertexId, "b")), 1)
val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2)
val graph = Graph(verts, edges, "", StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY)
// Note: Before caching, graph.vertices is cached, but graph.edges is not (but graph.edges'
// parent RDD is cached).
graph.cache()
assert(graph.vertices.getStorageLevel == StorageLevel.MEMORY_ONLY)
assert(graph.edges.getStorageLevel == StorageLevel.MEMORY_ONLY)
}
}
test("non-default number of edge partitions") {
val n = 10
val defaultParallelism = 3
......
......@@ -17,12 +17,11 @@
package org.apache.spark.graphx
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.impl.EdgePartition
import org.apache.spark.rdd._
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel
class VertexRDDSuite extends FunSuite with LocalSparkContext {
def vertices(sc: SparkContext, n: Int) = {
......@@ -110,4 +109,16 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
}
}
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value after caching
withSpark { sc =>
val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b)
assert(rdd.getStorageLevel == StorageLevel.NONE)
rdd.cache()
assert(rdd.getStorageLevel == StorageLevel.MEMORY_ONLY)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment