Commit 2c39d809 authored by Reynold Xin

Merge pull request #69 from jegonzal/MissingVertices

Addressing issue in Graph creation
parents a81fcb74 33b2deaf
@@ -404,6 +404,30 @@ object Graph {
  }
  /**
   * Construct a graph from a collection of attributed vertices and
   * edges. Duplicate vertices are merged arbitrarily (the first
   * attribute encountered is kept), and vertices found in the edge
   * collection but not in the input vertices are assigned the default
   * attribute `defaultVertexAttr`.
   *
   * @tparam VD the vertex attribute type
   * @tparam ED the edge attribute type
   * @param vertices the "set" of vertices and their attributes
   * @param edges the collection of edges in the graph
   * @param defaultVertexAttr the default vertex attribute to use for
   * vertices that are mentioned in `edges` but not in `vertices`
   */
  def apply[VD: ClassManifest, ED: ClassManifest](
      vertices: RDD[(Vid, VD)],
      edges: RDD[Edge[ED]],
      defaultVertexAttr: VD): Graph[VD, ED] = {
    GraphImpl(vertices, edges, defaultVertexAttr, (a, b) => a)
  }
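For illustration, a minimal sketch of calling this overload (the `users`/`follows` names and the live SparkContext `sc` are assumptions, not part of this change): an edge referencing vertex 3L, which is absent from `users`, causes that vertex to be materialized with the default attribute.

  // Hypothetical usage: vertex 3L appears only in the edge list, so it
  // is created with the default attribute "unknown".
  val users: RDD[(Vid, String)] =
    sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
  val follows: RDD[Edge[Int]] =
    sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1)))
  val graph: Graph[String, Int] = Graph(users, follows, "unknown")
  // graph.vertices contains (1L, "alice"), (2L, "bob"), (3L, "unknown")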
  /**
   * Construct a graph from a collection of attributed vertices and
   * edges. Duplicate vertices are combined using the `mergeFunc` and
......
@@ -204,6 +204,31 @@ class VertexSetRDD[@specialized V: ClassManifest](
    new VertexSetRDD[U](index, newValuesRDD)
  } // end of mapValues

  /**
   * Fill in missing values for all vertices in the index.
   *
   * @param missingValue the value to be used for vertices in the
   * index that don't currently have values.
   * @return A VertexSetRDD with a value for all vertices.
   */
  def fillMissing(missingValue: V): VertexSetRDD[V] = {
    val newValuesRDD: RDD[(Array[V], BitSet)] =
      valuesRDD.zipPartitions(index.rdd) { (valuesIter, indexIter) =>
        val index = indexIter.next
        assert(!indexIter.hasNext)
        val (values, bs: BitSet) = valuesIter.next
        assert(!valuesIter.hasNext)
        // Allocate a new values array with the missing value as the default
        val newValues = Array.fill(values.size)(missingValue)
        // Copy over the old values
        bs.iterator.foreach { ind => newValues(ind) = values(ind) }
        // Create a new bitset matching the keyset
        val newBS = index.getBitSet
        Iterator((newValues, newBS))
      }
    new VertexSetRDD[V](index, newValuesRDD)
  }
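Outside of Spark, the per-partition step of `fillMissing` reduces to: given a values array, a bitset of populated slots, and the index's full key bitset, produce an array in which every indexed slot holds a value. A self-contained sketch of that step, using `scala.collection.immutable.BitSet` as a stand-in for Spark's `org.apache.spark.util.collection.BitSet` (an assumption made for the sketch):

  import scala.collection.immutable.BitSet

  // Slots marked in `populated` keep their value; every other slot in
  // the index receives `missing`.
  def fillPartition[V: ClassManifest](values: Array[V], populated: BitSet,
      indexKeys: BitSet, missing: V): (Array[V], BitSet) = {
    val newValues = Array.fill(values.length)(missing)
    populated.foreach { i => newValues(i) = values(i) }
    (newValues, indexKeys)
  }

  // fillPartition(Array("a", "", "c", ""), BitSet(0, 2), BitSet(0, 1, 2, 3), "default")
  //   returns (Array("a", "default", "c", "default"), BitSet(0, 1, 2, 3))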
  /**
   * Pass each vertex attribute along with the vertex id through a map
   * function and retain the original RDD's partitioning and index.
@@ -380,7 +405,6 @@ class VertexSetRDD[@specialized V: ClassManifest](
       // this vertex set then we use the much more efficient leftZipJoin
       case other: VertexSetRDD[_] if index == other.index => {
         leftZipJoin(other)(cleanF)
-        // @todo handle case where other is a VertexSetRDD with a different index
       }
       case _ => {
         val indexedOther: VertexSetRDD[W] = VertexSetRDD(other, index, cleanMerge)
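The fast path above relies on the core invariant of `VertexSetRDD`: two sets built over the same `index` store their values in position-aligned arrays, so joining them needs no hashing and no shuffle, only an elementwise zip. A toy illustration of the aligned-array layout (hypothetical data; a plain `Map` stands in for the index):

  // Shared index: vertex id -> slot in the values arrays.
  val slotOf = Map(10L -> 0, 20L -> 1, 30L -> 2)
  val ranks = Array(0.15, 0.85, 0.50)          // this set's values
  val degrees = Array(Some(2), None, Some(1))  // the other set's values
  // leftZipJoin, locally: slot i of both arrays is the same vertex.
  val joined = ranks.zip(degrees)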
@@ -599,28 +623,24 @@ object VertexSetRDD {
   * can be used to build VertexSets over subsets of the vertices in
   * the input.
   */
-  def makeIndex(keys: RDD[Vid],
-    partitioner: Option[Partitioner] = None): VertexSetIndex = {
-    // @todo: I don't need the boolean, it's only there to be the second type since I want to shuffle a single RDD
-    // Ugly hack :-(. In order to partition the keys they must have values.
-    val tbl = keys.mapPartitions(_.map(k => (k, false)), true)
-    // Shuffle the table (if necessary)
-    val shuffledTbl = partitioner match {
-      case None => {
-        if (tbl.partitioner.isEmpty) {
-          // @todo: I don't need the boolean, it's only there to be the second type of the shuffle.
-          new ShuffledRDD[Vid, Boolean, (Vid, Boolean)](tbl, Partitioner.defaultPartitioner(tbl))
-        } else { tbl }
-      }
-      case Some(partitioner) =>
-        tbl.partitionBy(partitioner)
-    }
-    val index = shuffledTbl.mapPartitions( iter => {
-      val index = new VertexIdToIndexMap
-      for ((k, _) <- iter) { index.add(k) }
-      Iterator(index)
-    }, true).cache
+  def makeIndex(keys: RDD[Vid], partitionerOpt: Option[Partitioner] = None): VertexSetIndex = {
+    val partitioner = partitionerOpt match {
+      case Some(p) => p
+      case None => Partitioner.defaultPartitioner(keys)
+    }
+    val preAgg: RDD[(Vid, Unit)] = keys.mapPartitions(iter => {
+      val keys = new VertexIdToIndexMap
+      while (iter.hasNext) { keys.add(iter.next) }
+      keys.iterator.map(k => (k, ()))
+    }, true).partitionBy(partitioner)
+    val index = preAgg.mapPartitions(iter => {
+      val index = new VertexIdToIndexMap
+      while (iter.hasNext) { index.add(iter.next._1) }
+      Iterator(index)
+    }, true).cache
     new VertexSetIndex(index)
   }
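The rewritten `makeIndex` is effectively a two-phase `distinct`: deduplicate keys within each source partition (shrinking the shuffle, since a key is sent at most once per source partition), hash-partition the survivors, then build one hash index per target partition. A single-machine sketch of the same dataflow, with `Seq[Seq[Long]]` standing in for an RDD's partitions and `mutable.HashSet` standing in for `VertexIdToIndexMap` (both stand-ins are assumptions):

  import scala.collection.mutable

  def makeIndexLocal(partitions: Seq[Seq[Long]], numParts: Int): Seq[mutable.HashSet[Long]] = {
    // Phase 1: per-partition dedup before the "shuffle".
    val preAgg: Seq[Seq[Long]] = partitions.map(_.distinct)
    // "Shuffle": route each surviving key to a target partition by hash.
    val shuffled: Seq[Seq[Long]] =
      (0 until numParts).map { p =>
        preAgg.flatten.filter(k => (k.hashCode.abs % numParts) == p)
      }
    // Phase 2: one hash index per target partition; duplicates that
    // arrived from different source partitions collapse here.
    shuffled.map { keys =>
      val index = new mutable.HashSet[Long]
      keys.foreach(index.add)
      index
    }
  }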
......
@@ -9,6 +9,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner
 import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.Partitioner
+import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
@@ -320,20 +321,21 @@ object GraphImpl {
       defaultVertexAttr: VD,
       mergeFunc: (VD, VD) => VD): GraphImpl[VD, ED] = {
-    val vtable = VertexSetRDD(vertices, mergeFunc)
-    /**
-     * @todo Verify that there are no edges that contain vertices
-     * that are not in vTable. This should probably be resolved:
-     *
-     * edges.flatMap{ e => Array((e.srcId, null), (e.dstId, null)) }
-     *   .cogroup(vertices).map{
-     *     case (vid, _, attr) =>
-     *       if (attr.isEmpty) (vid, defaultValue)
-     *       else (vid, attr)
-     *   }
-     */
-    val etable = createETable(edges)
+    vertices.cache
+    val etable = createETable(edges).cache
+    // Get the set of all vids, preserving partitions
+    val partitioner = Partitioner.defaultPartitioner(vertices)
+    val implicitVids = etable.flatMap {
+      case (pid, partition) => Array.concat(partition.srcIds, partition.dstIds)
+    }.map(vid => (vid, ())).partitionBy(partitioner)
+    val allVids = vertices.zipPartitions(implicitVids) {
+      (a, b) => a.map(_._1) ++ b.map(_._1)
+    }
+    // Index the set of all vids
+    val index = VertexSetRDD.makeIndex(allVids, Some(partitioner))
+    // Index the vertices and fill in missing attributes with the default
+    val vtable = VertexSetRDD(vertices, index, mergeFunc).fillMissing(defaultVertexAttr)
     val vid2pid = new Vid2Pid(etable, vtable.index)
     val localVidMap = createLocalVidMap(etable)
     new GraphImpl(vtable, vid2pid, localVidMap, etable)
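Conceptually, the new construction path is an outer join between the supplied vertices and the vertex ids mentioned by edges, with `defaultVertexAttr` filling the gaps; the shared `index` lets GraphX do this without re-hashing on every subsequent join. The removed @todo sketched the same idea with `cogroup`; a corrected version of that sketch (a plain-RDD equivalent written against the enclosing scope, not the code GraphX actually runs):

  // Union of edge-endpoint ids and supplied vertex ids; missing
  // attributes become the default, duplicates are merged.
  val edgeVids: RDD[(Vid, Unit)] =
    edges.flatMap(e => Iterator((e.srcId, ()), (e.dstId, ())))
  val resolved: RDD[(Vid, VD)] =
    edgeVids.cogroup(vertices).map {
      case (vid, (_, attrs)) =>
        if (attrs.isEmpty) (vid, defaultVertexAttr)
        else (vid, attrs.reduce(mergeFunc))
    }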
......
@@ -4,7 +4,7 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.graph.LocalSparkContext._
import org.apache.spark.rdd._
class GraphSuite extends FunSuite with LocalSparkContext {
@@ -20,6 +20,21 @@ class GraphSuite extends FunSuite with LocalSparkContext {
    }
  }
test("Graph Creation with invalid vertices") {
withSpark(new SparkContext("local", "test")) { sc =>
val rawEdges = (0L to 98L).zip((1L to 99L) :+ 0L)
val edges: RDD[Edge[Int]] = sc.parallelize(rawEdges).map { case (s, t) => Edge(s, t, 1) }
val vertices: RDD[(Vid, Boolean)] = sc.parallelize((0L until 10L).map(id => (id, true)))
val graph = Graph(vertices, edges, false)
assert( graph.edges.count() === rawEdges.size )
assert( graph.vertices.count() === 100)
graph.triplets.map { et =>
assert( (et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr) )
assert( (et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr) )
}
}
}
test("mapEdges") {
withSpark(new SparkContext("local", "test")) { sc =>
val n = 3
......