diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala index a67cc44f6ec8d4889cf67fc3edb27c8827f970c6..6beaea07fa060094aa0ad217a1240a12c0df7f26 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala @@ -18,7 +18,7 @@ object Analytics extends Logging { * Run PageRank for a fixed number of iterations returning a graph * with vertex attributes containing the PageRank and edge * attributes the normalized edge weight. - * + * * The following PageRank fixed point is computed for each vertex. * * {{{ @@ -35,7 +35,7 @@ object Analytics extends Logging { * where `alpha` is the random reset probability (typically 0.15), * `inNbrs[i]` is the set of neighbors whick link to `i` and * `outDeg[j]` is the out degree of vertex `j`. - * + * * Note that this is not the "normalized" PageRank and as a * consequence pages that have no inlinks will have a PageRank of * alpha. @@ -52,7 +52,7 @@ object Analytics extends Logging { * */ def pagerank[VD: Manifest, ED: Manifest]( - graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): + graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = { /** @@ -76,13 +76,13 @@ object Analytics extends Logging { // version of Pregel def vertexProgram(id: Vid, attr: Double, msgSum: Double): Double = resetProb + (1.0 - resetProb) * msgSum - def sendMessage(id: Vid, edge: EdgeTriplet[Double, Double]): Option[Double] = - Some(edge.srcAttr * edge.attr) + def sendMessage(edge: EdgeTriplet[Double, Double]) = + Array((edge.dstId, edge.srcAttr * edge.attr)) def messageCombiner(a: Double, b: Double): Double = a + b // The initial message received by all vertices in PageRank - val initialMessage = 0.0 + val initialMessage = 0.0 - // Execute pregel for a fixed number of iterations. + // Execute pregel for a fixed number of iterations. Pregel(pagerankGraph, initialMessage, numIter)( vertexProgram, sendMessage, messageCombiner) } @@ -107,7 +107,7 @@ object Analytics extends Logging { * where `alpha` is the random reset probability (typically 0.15), * `inNbrs[i]` is the set of neighbors whick link to `i` and * `outDeg[j]` is the out degree of vertex `j`. - * + * * Note that this is not the "normalized" PageRank and as a * consequence pages that have no inlinks will have a PageRank of * alpha. @@ -124,11 +124,11 @@ object Analytics extends Logging { * PageRank and each edge containing the normalized weight. */ def deltaPagerank[VD: Manifest, ED: Manifest]( - graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = { /** - * Initialize the pagerankGraph with each edge attribute + * Initialize the pagerankGraph with each edge attribute * having weight 1/outDegree and each vertex with attribute 1.0. 
*/ val pagerankGraph: Graph[(Double, Double), Double] = graph @@ -136,7 +136,7 @@ object Analytics extends Logging { .outerJoinVertices(graph.outDegrees){ (vid, vdata, deg) => deg.getOrElse(0) } - // Set the weight on the edges based on the degree + // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initalPR, delta = 0) .mapVertices( (id, attr) => (0.0, 0.0) ) @@ -151,16 +151,16 @@ object Analytics extends Logging { val newPR = oldPR + (1.0 - resetProb) * msgSum (newPR, newPR - oldPR) } - def sendMessage(id: Vid, edge: EdgeTriplet[(Double, Double), Double]): Option[Double] = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { if (edge.srcAttr._2 > tol) { - Some(edge.srcAttr._2 * edge.attr) - } else { None } - } + Array((edge.dstId, edge.srcAttr._2 * edge.attr)) + } else { Array.empty[(Vid, Double)] } + } def messageCombiner(a: Double, b: Double): Double = a + b // The initial message received by all vertices in PageRank val initialMessage = resetProb / (1.0 - resetProb) - // Execute a dynamic version of Pregel. + // Execute a dynamic version of Pregel. Pregel(pagerankGraph, initialMessage)( vertexProgram, sendMessage, messageCombiner) .mapVertices( (vid, attr) => attr._1 ) @@ -182,26 +182,28 @@ object Analytics extends Logging { * @return a graph with vertex attributes containing the smallest * vertex in each connected component */ - def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]): + def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]): Graph[Vid, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = { - val thisAttr = edge.vertexAttr(id) - val otherAttr = edge.otherVertexAttr(id) - if(thisAttr < otherAttr) { Some(thisAttr) } - else { None } + def sendMessage(edge: EdgeTriplet[Vid, ED]) = { + if (edge.srcAttr < edge.dstAttr) { + Array((edge.dstId, edge.srcAttr)) + } else if (edge.srcAttr > edge.dstAttr) { + Array((edge.srcId, edge.dstAttr)) + } else { + Array.empty[(Vid, Vid)] + } } - val initialMessage = Long.MaxValue - Pregel(ccGraph, initialMessage, EdgeDirection.Both)( + Pregel(ccGraph, initialMessage)( (id, attr, msg) => math.min(attr, msg), - sendMessage, + sendMessage, (a,b) => math.min(a,b) ) - } // end of connectedComponents + } // end of connectedComponents + - def main(args: Array[String]) = { val host = args(0) @@ -213,7 +215,7 @@ object Analytics extends Logging { case _ => throw new IllegalArgumentException("Invalid argument: " + arg) } } - + def setLogLevels(level: org.apache.log4j.Level, loggers: TraversableOnce[String]) = { loggers.map{ loggerName => @@ -265,7 +267,7 @@ object Analytics extends Logging { val sc = new SparkContext(host, "PageRank(" + fname + ")") - val graph = GraphLoader.textFile(sc, fname, a => 1.0F, + val graph = GraphLoader.textFile(sc, fname, a => 1.0F, minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache() val startTime = System.currentTimeMillis @@ -314,7 +316,7 @@ object Analytics extends Logging { val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")") //val graph = GraphLoader.textFile(sc, fname, a => 1.0F) - val graph = GraphLoader.textFile(sc, fname, a => 1.0F, + val graph = GraphLoader.textFile(sc, fname, a => 1.0F, minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache() val cc = Analytics.connectedComponents(graph) //val cc = if(isDynamic) Analytics.dynamicConnectedComponents(graph, 
numIter) diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala index 8c7f4c25e295231dd35c4baa4ad07b70d8e36350..5fd8cd699106e5b93a9249a9643353fc15492f07 100644 --- a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala +++ b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala @@ -11,21 +11,21 @@ import org.apache.spark.util.ClosureCleaner * the graph type and is implicitly constructed for each Graph object. * All operations in `GraphOps` are expressed in terms of the * efficient GraphX API. - * + * * @tparam VD the vertex attribute type - * @tparam ED the edge attribute type + * @tparam ED the edge attribute type * */ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { /** - * Compute the number of edges in the graph. + * Compute the number of edges in the graph. */ lazy val numEdges: Long = graph.edges.count() /** - * Compute the number of vertices in the graph. + * Compute the number of vertices in the graph. */ lazy val numVertices: Long = graph.vertices.count() @@ -39,7 +39,7 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { /** - * Compute the out-degree of each vertex in the Graph returning an RDD. + * Compute the out-degree of each vertex in the Graph returning an RDD. * @note Vertices with no out edges are not returned in the resulting RDD. */ lazy val outDegrees: VertexSetRDD[Int] = degreesRDD(EdgeDirection.Out) @@ -60,7 +60,13 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { * neighboring vertex attributes. */ private def degreesRDD(edgeDirection: EdgeDirection): VertexSetRDD[Int] = { - graph.aggregateNeighbors((vid, edge) => Some(1), _+_, edgeDirection) + if (edgeDirection == EdgeDirection.In) { + graph.mapReduceTriplets(et => Array((et.dstId,1)), _+_) + } else if (edgeDirection == EdgeDirection.Out) { + graph.mapReduceTriplets(et => Array((et.srcId,1)), _+_) + } else { // EdgeDirection.both + graph.mapReduceTriplets(et => Array((et.srcId,1), (et.dstId,1)), _+_) + } } @@ -89,7 +95,7 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { * * @example We can use this function to compute the average follower * age for each user - * + * * {{{ * val graph: Graph[Int,Int] = loadGraph() * val averageFollowerAge: RDD[(Int, Int)] = @@ -113,15 +119,15 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { ClosureCleaner.clean(mapFunc) ClosureCleaner.clean(reduceFunc) - // Define a new map function over edge triplets + // Define a new map function over edge triplets val mf = (et: EdgeTriplet[VD,ED]) => { // Compute the message to the dst vertex - val dst = + val dst = if (dir == EdgeDirection.In || dir == EdgeDirection.Both) { mapFunc(et.dstId, et) } else { Option.empty[A] } // Compute the message to the source vertex - val src = + val src = if (dir == EdgeDirection.Out || dir == EdgeDirection.Both) { mapFunc(et.srcId, et) } else { Option.empty[A] } @@ -130,7 +136,7 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { case (None, None) => Array.empty[(Vid, A)] case (Some(srcA),None) => Array((et.srcId, srcA)) case (None, Some(dstA)) => Array((et.dstId, dstA)) - case (Some(srcA), Some(dstA)) => + case (Some(srcA), Some(dstA)) => Array((et.srcId, srcA), (et.dstId, dstA)) } } @@ -141,14 +147,14 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { /** - * Return the Ids of the neighboring vertices. 
+ * Return the Ids of the neighboring vertices. * * @param edgeDirection the direction along which to collect * neighboring vertices * * @return the vertex set of neighboring ids for each vertex. */ - def collectNeighborIds(edgeDirection: EdgeDirection) : + def collectNeighborIds(edgeDirection: EdgeDirection) : VertexSetRDD[Array[Vid]] = { val nbrs = graph.aggregateNeighbors[Array[Vid]]( (vid, edge) => Some(Array(edge.otherVertexId(vid))), @@ -171,10 +177,10 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](graph: Graph[VD, ED]) { * @return the vertex set of neighboring vertex attributes for each * vertex. */ - def collectNeighbors(edgeDirection: EdgeDirection) : + def collectNeighbors(edgeDirection: EdgeDirection) : VertexSetRDD[ Array[(Vid, VD)] ] = { val nbrs = graph.aggregateNeighbors[Array[(Vid,VD)]]( - (vid, edge) => + (vid, edge) => Some(Array( (edge.otherVertexId(vid), edge.otherVertexAttr(vid)) )), (a, b) => a ++ b, edgeDirection) diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala index 1750b7f8dcabc9a7d1b99bb4686df27a8624287e..501e593e917eae3cf4df6940b138d2ff9a5d2c0b 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala @@ -34,8 +34,8 @@ import org.apache.spark.rdd.RDD * def sendMessage(id: Vid, edge: EdgeTriplet[Double, Double]): Option[Double] = * Some(edge.srcAttr * edge.attr) * def messageCombiner(a: Double, b: Double): Double = a + b - * val initialMessage = 0.0 - * // Execute pregel for a fixed number of iterations. + * val initialMessage = 0.0 + * // Execute pregel for a fixed number of iterations. * Pregel(pagerankGraph, initialMessage, numIter)( * vertexProgram, sendMessage, messageCombiner) * }}} @@ -64,7 +64,7 @@ object Pregel { * @tparam ED the edge data type * @tparam A the Pregel message type * - * @param graph the input graph. + * @param graph the input graph. * * @param initialMsg the message each vertex will receive at the on * the first iteration. @@ -93,78 +93,17 @@ object Pregel { def apply[VD: ClassManifest, ED: ClassManifest, A: ClassManifest] (graph: Graph[VD, ED], initialMsg: A, numIter: Int)( vprog: (Vid, VD, A) => VD, - sendMsg: (Vid, EdgeTriplet[VD, ED]) => Option[A], + sendMsg: EdgeTriplet[VD, ED] => Array[(Vid,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { - apply(graph, initialMsg, numIter, EdgeDirection.Out)(vprog, sendMsg, mergeMsg) - } // end of Apply - - - /** - * Execute a Pregel-like iterative vertex-parallel abstraction. The - * user-defined vertex-program `vprog` is executed in parallel on - * each vertex receiving any inbound messages and computing a new - * value for the vertex. The `sendMsg` function is then invoked on - * all out-edges and is used to compute an optional message to the - * destination vertex. The `mergeMsg` function is a commutative - * associative function used to combine messages destined to the - * same vertex. - * - * On the first iteration all vertices receive the `initialMsg` and - * on subsequent iterations if a vertex does not receive a message - * then the vertex-program is not invoked. - * - * This function iterates a fixed number (`numIter`) of iterations. - * - * @tparam VD the vertex data type - * @tparam ED the edge data type - * @tparam A the Pregel message type - * - * @param graph the input graph. - * - * @param initialMsg the message each vertex will receive at the on - * the first iteration. 
- * - * @param numIter the number of iterations to run this computation. - * - * @param sendDir the edge direction along which the `sendMsg` - * function is invoked. - * - * @param vprog the user-defined vertex program which runs on each - * vertex and receives the inbound message and computes a new vertex - * value. On the first iteration the vertex program is invoked on - * all vertices and is passed the default message. On subsequent - * iterations the vertex program is only invoked on those vertices - * that receive messages. - * - * @param sendMsg a user supplied function that is applied to each - * edge in the direction `sendDir` adjacent to vertices that - * received messages in the current iteration. - * - * @param mergeMsg a user supplied function that takes two incoming - * messages of type A and merges them into a single message of type - * A. ''This function must be commutative and associative and - * ideally the size of A should not increase.'' - * - * @return the resulting graph at the end of the computation - * - */ - def apply[VD: ClassManifest, ED: ClassManifest, A: ClassManifest] - (graph: Graph[VD, ED], initialMsg: A, numIter: Int, sendDir: EdgeDirection)( - vprog: (Vid, VD, A) => VD, - sendMsg: (Vid, EdgeTriplet[VD, ED]) => Option[A], - mergeMsg: (A, A) => A) - : Graph[VD, ED] = { - - def mapF(vid: Vid, edge: EdgeTriplet[VD,ED]) = sendMsg(edge.otherVertexId(vid), edge) // Receive the first set of messages var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg)) - + var i = 0 while (i < numIter) { // compute the messages - val messages = g.aggregateNeighbors(mapF, mergeMsg, sendDir.reverse) + val messages = g.mapReduceTriplets(sendMsg, mergeMsg) // receive the messages g = g.joinVertices(messages)(vprog) // count the iteration @@ -195,7 +134,7 @@ object Pregel { * @tparam ED the edge data type * @tparam A the Pregel message type * - * @param graph the input graph. + * @param graph the input graph. * * @param initialMsg the message each vertex will receive at the on * the first iteration. @@ -224,66 +163,7 @@ object Pregel { def apply[VD: ClassManifest, ED: ClassManifest, A: ClassManifest] (graph: Graph[VD, ED], initialMsg: A)( vprog: (Vid, VD, A) => VD, - sendMsg: (Vid, EdgeTriplet[VD, ED]) => Option[A], - mergeMsg: (A, A) => A) - : Graph[VD, ED] = { - apply(graph, initialMsg, EdgeDirection.Out)(vprog, sendMsg, mergeMsg) - } // end of apply - - - /** - * Execute a Pregel-like iterative vertex-parallel abstraction. The - * user-defined vertex-program `vprog` is executed in parallel on - * each vertex receiving any inbound messages and computing a new - * value for the vertex. The `sendMsg` function is then invoked on - * all out-edges and is used to compute an optional message to the - * destination vertex. The `mergeMsg` function is a commutative - * associative function used to combine messages destined to the - * same vertex. - * - * On the first iteration all vertices receive the `initialMsg` and - * on subsequent iterations if a vertex does not receive a message - * then the vertex-program is not invoked. - * - * This function iterates until there are no remaining messages. - * - * @tparam VD the vertex data type - * @tparam ED the edge data type - * @tparam A the Pregel message type - * - * @param graph the input graph. - * - * @param initialMsg the message each vertex will receive at the on - * the first iteration. - * - * @param numIter the number of iterations to run this computation. 
- * - * @param sendDir the edge direction along which the `sendMsg` - * function is invoked. - * - * @param vprog the user-defined vertex program which runs on each - * vertex and receives the inbound message and computes a new vertex - * value. On the first iteration the vertex program is invoked on - * all vertices and is passed the default message. On subsequent - * iterations the vertex program is only invoked on those vertices - * that receive messages. - * - * @param sendMsg a user supplied function that is applied to each - * edge in the direction `sendDir` adjacent to vertices that - * received messages in the current iteration. - * - * @param mergeMsg a user supplied function that takes two incoming - * messages of type A and merges them into a single message of type - * A. ''This function must be commutative and associative and - * ideally the size of A should not increase.'' - * - * @return the resulting graph at the end of the computation - * - */ - def apply[VD: ClassManifest, ED: ClassManifest, A: ClassManifest] - (graph: Graph[VD, ED], initialMsg: A, sendDir: EdgeDirection)( - vprog: (Vid, VD, A) => VD, - sendMsg: (Vid, EdgeTriplet[VD, ED]) => Option[A], + sendMsg: EdgeTriplet[VD, ED] => Array[(Vid,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { @@ -294,7 +174,7 @@ object Pregel { } } - def sendMsgFun(vid: Vid, edge: EdgeTriplet[(VD,Boolean), ED]): Option[A] = { + def sendMsgFun(edge: EdgeTriplet[(VD,Boolean), ED]): Array[(Vid, A)] = { if(edge.srcAttr._2) { val et = new EdgeTriplet[VD, ED] et.srcId = edge.srcId @@ -302,22 +182,22 @@ object Pregel { et.dstId = edge.dstId et.dstAttr = edge.dstAttr._1 et.attr = edge.attr - sendMsg(edge.otherVertexId(vid), et) - } else { None } + sendMsg(et) + } else { Array.empty[(Vid,A)] } } - var g = graph.mapVertices( (vid, vdata) => (vprog(vid, vdata, initialMsg), true) ) + var g = graph.mapVertices( (vid, vdata) => (vprog(vid, vdata, initialMsg), true) ) // compute the messages - var messages = g.aggregateNeighbors(sendMsgFun, mergeMsg, sendDir.reverse).cache + var messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache var activeMessages = messages.count - // Loop + // Loop var i = 0 while (activeMessages > 0) { // receive the messages g = g.outerJoinVertices(messages)(vprogFun) val oldMessages = messages // compute the messages - messages = g.aggregateNeighbors(sendMsgFun, mergeMsg, sendDir.reverse).cache + messages = g.mapReduceTriplets(sendMsgFun, mergeMsg).cache activeMessages = messages.count // after counting we can unpersist the old messages oldMessages.unpersist(blocking=false) diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index d0df35d4226f724b3b419653b7277894261d7015..a6c4cc4b66c32ff4d4a16f2824aed8b8927dc2e5 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -6,7 +6,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext._ -import org.apache.spark.HashPartitioner +import org.apache.spark.HashPartitioner import org.apache.spark.util.ClosureCleaner import org.apache.spark.graph._ @@ -27,7 +27,7 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( private var pos = 0 private val et = new EdgeTriplet[VD, ED] private val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray) - + override def hasNext: Boolean = pos < edgePartition.size override 
def next() = { et.srcId = edgePartition.srcIds(pos) @@ -105,16 +105,16 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( override def statistics: Map[String, Any] = { val numVertices = this.numVertices val numEdges = this.numEdges - val replicationRatio = + val replicationRatio = vid2pid.map(kv => kv._2.size).sum / vTable.count - val loadArray = + val loadArray = eTable.map{ case (pid, epart) => epart.data.size }.collect.map(x => x.toDouble / numEdges) val minLoad = loadArray.min val maxLoad = loadArray.max Map( "Num Vertices" -> numVertices, "Num Edges" -> numEdges, - "Replication" -> replicationRatio, "Load Array" -> loadArray, - "Min Load" -> minLoad, "Max Load" -> maxLoad) + "Replication" -> replicationRatio, "Load Array" -> loadArray, + "Min Load" -> minLoad, "Max Load" -> maxLoad) } /** @@ -136,10 +136,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( println(indent + name + ": " + cacheLevel.description + " (partitioner: " + partitioner + ", " + numparts +")") println(indent + " |---> Deps: " + deps.map(d => (d, d.rdd.id) ).toString) println(indent + " |---> PrefLoc: " + locs.map(x=> x.toString).mkString(", ")) - deps.foreach(d => traverseLineage(d.rdd, indent + " | ", visited)) + deps.foreach(d => traverseLineage(d.rdd, indent + " | ", visited)) } } - + println("eTable ------------------------------------------") traverseLineage(eTable, " ") var visited = Map(eTable.id -> "eTable") @@ -160,11 +160,11 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( traverseLineage(vid2pid, " ", visited) visited += (vid2pid.id -> "vid2pid") visited += (vid2pid.valuesRDD.id -> "vid2pid.values") - + println("\n\nlocalVidMap -------------------------------------") traverseLineage(localVidMap, " ", visited) visited += (localVidMap.id -> "localVidMap") - + println("\n\nvTableReplicatedValues --------------------------") traverseLineage(vTableReplicatedValues, " ", visited) visited += (vTableReplicatedValues.id -> "vTableReplicatedValues") @@ -175,7 +175,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( } // end of print lineage override def reverse: Graph[VD, ED] = { - val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) }, + val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) }, preservesPartitioning = true) new GraphImpl(vTable, vid2pid, localVidMap, newEtable) } @@ -194,7 +194,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = GraphImpl.mapTriplets(this, f) - override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = { /** @todo The following code behaves deterministically on each @@ -202,7 +202,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( * this version */ // val predGraph = mapVertices(v => (v.data, vpred(v))) - // val newETable = predGraph.triplets.filter(t => + // val newETable = predGraph.triplets.filter(t => // if(v.src.data._2 && v.dst.data._2) { // val src = Vertex(t.src.id, t.src.data._1) // val dst = Vertex(t.dst.id, t.dst.data._1) @@ -213,7 +213,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( // .map(v => (v.id, v.data._1)).indexed() // Reuse the partitioner (but not the index) from this graph - val newVTable = + val newVTable = 
VertexSetRDD(vertices.filter(v => vpred(v._1, v._2)).partitionBy(vTable.index.partitioner)) @@ -224,9 +224,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( ) .map( t => Edge(t.srcId, t.dstId, t.attr) )) - // Construct the Vid2Pid map. Here we assume that the filter operation - // behaves deterministically. - // @todo reindex the vertex and edge tables + // Construct the Vid2Pid map. Here we assume that the filter operation + // behaves deterministically. + // @todo reindex the vertex and edge tables val newVid2Pid = createVid2Pid(newETable, newVTable.index) val newVidMap = createLocalVidMap(newETable) @@ -281,7 +281,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( override def mapReduceTriplets[A: ClassManifest]( mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)], reduceFunc: (A, A) => A) - : VertexSetRDD[A] = + : VertexSetRDD[A] = GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc) override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest] @@ -298,29 +298,29 @@ object GraphImpl { def apply[VD: ClassManifest, ED: ClassManifest]( vertices: RDD[(Vid, VD)], edges: RDD[Edge[ED]], - defaultVertexAttr: VD): + defaultVertexAttr: VD): GraphImpl[VD,ED] = { apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a) } def apply[VD: ClassManifest, ED: ClassManifest]( - vertices: RDD[(Vid, VD)], + vertices: RDD[(Vid, VD)], edges: RDD[Edge[ED]], defaultVertexAttr: VD, mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = { - val vtable = VertexSetRDD(vertices, mergeFunc) - /** - * @todo Verify that there are no edges that contain vertices + val vtable = VertexSetRDD(vertices, mergeFunc) + /** + * @todo Verify that there are no edges that contain vertices * that are not in vTable. This should probably be resolved: * * edges.flatMap{ e => Array((e.srcId, null), (e.dstId, null)) } * .cogroup(vertices).map{ - * case (vid, _, attr) => + * case (vid, _, attr) => * if (attr.isEmpty) (vid, defaultValue) * else (vid, attr) * } - * + * */ val etable = createETable(edges) val vid2pid = createVid2Pid(etable, vtable.index) @@ -340,7 +340,7 @@ object GraphImpl { : RDD[(Pid, EdgePartition[ED])] = { // Get the number of partitions val numPartitions = edges.partitions.size - val ceilSqrt: Pid = math.ceil(math.sqrt(numPartitions)).toInt + val ceilSqrt: Pid = math.ceil(math.sqrt(numPartitions)).toInt edges.map { e => // Random partitioning based on the source vertex id. 
// val part: Pid = edgePartitionFunction1D(e.srcId, e.dstId, numPartitions) @@ -372,18 +372,18 @@ object GraphImpl { edgePartition.foreach(e => {vSet.add(e.srcId); vSet.add(e.dstId)}) vSet.iterator.map { vid => (vid.toLong, pid) } } - VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex, + VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex, (p: Pid) => ArrayBuffer(p), (ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab}, (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b) .mapValues(a => a.toArray).cache() } - protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]): + protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]): RDD[(Pid, VertexIdToIndexMap)] = { eTable.mapPartitions( _.map{ case (pid, epart) => val vidToIndex = new VertexIdToIndexMap - epart.foreach{ e => + epart.foreach{ e => vidToIndex.add(e.srcId) vidToIndex.add(e.dstId) } @@ -392,17 +392,17 @@ object GraphImpl { } protected def createVTableReplicated[VD: ClassManifest]( - vTable: VertexSetRDD[VD], + vTable: VertexSetRDD[VD], vid2pid: VertexSetRDD[Array[Pid]], - replicationMap: RDD[(Pid, VertexIdToIndexMap)]): + replicationMap: RDD[(Pid, VertexIdToIndexMap)]): RDD[(Pid, Array[VD])] = { - // Join vid2pid and vTable, generate a shuffle dependency on the joined + // Join vid2pid and vTable, generate a shuffle dependency on the joined // result, and get the shuffle id so we can use it on the slave. val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) => pids.iterator.map { pid => new VertexMessage[VD](pid, vid, vdata) } }.partitionBy(replicationMap.partitioner.get).cache() - replicationMap.zipPartitions(msgsByPartition){ + replicationMap.zipPartitions(msgsByPartition){ (mapIter, msgsIter) => val (pid, vidToIndex) = mapIter.next() assert(!mapIter.hasNext) @@ -418,12 +418,12 @@ object GraphImpl { // @todo assert edge table has partitioner } - def makeTriplets[VD: ClassManifest, ED: ClassManifest]( + def makeTriplets[VD: ClassManifest, ED: ClassManifest]( localVidMap: RDD[(Pid, VertexIdToIndexMap)], vTableReplicatedValues: RDD[(Pid, Array[VD]) ], eTable: RDD[(Pid, EdgePartition[ED])]): RDD[EdgeTriplet[VD, ED]] = { - localVidMap.zipPartitions(vTableReplicatedValues, eTable) { - (vidMapIter, replicatedValuesIter, eTableIter) => + eTable.zipPartitions(localVidMap, vTableReplicatedValues) { + (eTableIter, vidMapIter, replicatedValuesIter) => val (_, vidToIndex) = vidMapIter.next() val (_, vertexArray) = replicatedValuesIter.next() val (_, edgePartition) = eTableIter.next() @@ -432,9 +432,9 @@ object GraphImpl { } def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest]( - g: GraphImpl[VD, ED], + g: GraphImpl[VD, ED], f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ + val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ (edgePartitionIter, vidToIndexIter, vertexArrayIter) => val (pid, edgePartition) = edgePartitionIter.next() val (_, vidToIndex) = vidToIndexIter.next() @@ -460,8 +460,8 @@ object GraphImpl { ClosureCleaner.clean(mapFunc) ClosureCleaner.clean(reduceFunc) - // Map and preaggregate - val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ + // Map and preaggregate + val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ (edgePartitionIter, vidToIndexIter, vertexArrayIter) => val (_, edgePartition) = edgePartitionIter.next() val (_, vidToIndex) = 
vidToIndexIter.next() @@ -477,6 +477,7 @@ object GraphImpl { val msgBS = new BitSet(vertexArray.size) // Iterate over the partition val et = new EdgeTriplet[VD, ED] + edgePartition.foreach { e => et.set(e) et.srcAttr = vmap(e.srcId) @@ -484,7 +485,7 @@ object GraphImpl { // TODO(rxin): rewrite the foreach using a simple while loop to speed things up. // Also given we are only allowing zero, one, or two messages, we can completely unroll // the for loop. - mapFunc(et).foreach{ case (vid, msg) => + mapFunc(et).foreach { case (vid, msg) => // verify that the vid is valid assert(vid == et.srcId || vid == et.dstId) // Get the index of the key @@ -492,7 +493,7 @@ object GraphImpl { // Populate the aggregator map if (msgBS.get(ind)) { msgArray(ind) = reduceFunc(msgArray(ind), msg) - } else { + } else { msgArray(ind) = msg msgBS.set(ind) } @@ -506,64 +507,64 @@ object GraphImpl { } protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = { - val mixingPrime: Vid = 1125899906842597L + val mixingPrime: Vid = 1125899906842597L (math.abs(src) * mixingPrime).toInt % numParts } /** - * This function implements a classic 2D-Partitioning of a sparse matrix. - * Suppose we have a graph with 11 vertices that we want to partition + * This function implements a classic 2D-Partitioning of a sparse matrix. + * Suppose we have a graph with 11 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: * * __________________________________ - * v0 | P0 * | P1 | P2 * | + * v0 | P0 * | P1 | P2 * | * v1 | **** | * | | * v2 | ******* | ** | **** | - * v3 | ***** | * * | * | + * v3 | ***** | * * | * | * ---------------------------------- - * v4 | P3 * | P4 *** | P5 ** * | + * v4 | P3 * | P4 *** | P5 ** * | * v5 | * * | * | | * v6 | * | ** | **** | - * v7 | * * * | * * | * | + * v7 | * * * | * * | * | * ---------------------------------- - * v8 | P6 * | P7 * | P8 * *| + * v8 | P6 * | P7 * | P8 * *| * v9 | * | * * | | * v10 | * | ** | * * | - * v11 | * <-E | *** | ** | + * v11 | * <-E | *** | ** | * ---------------------------------- * - * The edge denoted by E connects v11 with v1 and is assigned to + * The edge denoted by E connects v11 with v1 and is assigned to * processor P6. To get the processor number we divide the matrix * into sqrt(numProc) by sqrt(numProc) blocks. Notice that edges - * adjacent to v11 can only be in the first colum of - * blocks (P0, P3, P6) or the last row of blocks (P6, P7, P8). - * As a consequence we can guarantee that v11 will need to be + * adjacent to v11 can only be in the first colum of + * blocks (P0, P3, P6) or the last row of blocks (P6, P7, P8). + * As a consequence we can guarantee that v11 will need to be * replicated to at most 2 * sqrt(numProc) machines. * - * Notice that P0 has many edges and as a consequence this + * Notice that P0 has many edges and as a consequence this * partitioning would lead to poor work balance. To improve - * balance we first multiply each vertex id by a large prime - * to effectively shuffle the vertex locations. + * balance we first multiply each vertex id by a large prime + * to effectively shuffle the vertex locations. * * One of the limitations of this approach is that the number of * machines must either be a perfect square. We partially address - * this limitation by computing the machine assignment to the next - * largest perfect square and then mapping back down to the actual - * number of machines. 
Unfortunately, this can also lead to work - * imbalance and so it is suggested that a perfect square is used. - * + * this limitation by computing the machine assignment to the next + * largest perfect square and then mapping back down to the actual + * number of machines. Unfortunately, this can also lead to work + * imbalance and so it is suggested that a perfect square is used. + * * */ - protected def edgePartitionFunction2D(src: Vid, dst: Vid, + protected def edgePartitionFunction2D(src: Vid, dst: Vid, numParts: Pid, ceilSqrtNumParts: Pid): Pid = { - val mixingPrime: Vid = 1125899906842597L + val mixingPrime: Vid = 1125899906842597L val col: Pid = ((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt val row: Pid = ((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt (col * ceilSqrtNumParts + row) % numParts } /** - * Assign edges to an aribtrary machine corresponding to a + * Assign edges to an aribtrary machine corresponding to a * random vertex cut. */ protected def randomVertexCut(src: Vid, dst: Vid, numParts: Pid): Pid = { @@ -574,9 +575,9 @@ object GraphImpl { * @todo This will only partition edges to the upper diagonal * of the 2D processor space. */ - protected def canonicalEdgePartitionFunction2D(srcOrig: Vid, dstOrig: Vid, + protected def canonicalEdgePartitionFunction2D(srcOrig: Vid, dstOrig: Vid, numParts: Pid, ceilSqrtNumParts: Pid): Pid = { - val mixingPrime: Vid = 1125899906842597L + val mixingPrime: Vid = 1125899906842597L // Partitions by canonical edge direction val src = math.min(srcOrig, dstOrig) val dst = math.max(srcOrig, dstOrig) diff --git a/graphx-shell b/graphx-shell new file mode 100755 index 0000000000000000000000000000000000000000..4dd6c68ace888d3996b0ee578057eb922507d38a --- /dev/null +++ b/graphx-shell @@ -0,0 +1,124 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark Shell REPL +# Note that it will set MASTER to spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT} +# if those two env vars are set in spark-env.sh but MASTER is not. 
+# Options: +# -c <cores> Set the number of cores for REPL to use +# + +# Enter posix mode for bash +set -o posix + + +# Update the the banner logo +export SPARK_BANNER_TEXT="Welcome to + ______ __ _ __ + / ____/________ _____ / /_ | |/ / + / / __/ ___/ __ \`/ __ \/ __ \| / + / /_/ / / / /_/ / /_/ / / / / | + \____/_/ \__,_/ .___/_/ /_/_/|_| + /_/ Alpha Release + +Powered by: + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ \`/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ + /_/ version 0.9.0 + +Example: + + scala> val graph = GraphLoader.textFile(sc, \"hdfs://links\") + scala> graph.numVertices + scala> graph.numEdges + scala> val pageRankGraph = Analytics.pagerank(graph, 10) // 10 iterations + scala> val maxPr = pageRankGraph.vertices.map{ case (vid, pr) => pr }.max + scala> println(maxPr) + +" + +export SPARK_SHELL_INIT_BLOCK="import org.apache.spark.graph._;" + +# Set the serializer to use Kryo for graphx objects +SPARK_JAVA_OPTS+=" -Dspark.serializer=org.apache.spark.serializer.KryoSerializer " +SPARK_JAVA_OPTS+="-Dspark.kryo.registrator=org.apache.spark.graph.GraphKryoRegistrator " +SPARK_JAVA_OPTS+="-Dspark.kryoserializer.buffer.mb=10 " + + + +FWDIR="`dirname $0`" + +for o in "$@"; do + if [ "$1" = "-c" -o "$1" = "--cores" ]; then + shift + if [ -n "$1" ]; then + OPTIONS="-Dspark.cores.max=$1" + shift + fi + fi +done + +# Set MASTER from spark-env if possible +if [ -z "$MASTER" ]; then + if [ -e "$FWDIR/conf/spark-env.sh" ]; then + . "$FWDIR/conf/spark-env.sh" + fi + if [[ "x" != "x$SPARK_MASTER_IP" && "y" != "y$SPARK_MASTER_PORT" ]]; then + MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" + export MASTER + fi +fi + +# Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in +# binary distribution of Spark where Scala is not installed +exit_status=127 +saved_stty="" + +# restore stty settings (echo in particular) +function restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +function onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +# to reenable echo if we are interrupted before completing. +trap onExit INT + +# save terminal settings +saved_stty=$(stty -g 2>/dev/null) +# clear on error so we don't later try to restore them +if [[ ! $? ]]; then + saved_stty="" +fi + +$FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@" + +# record the exit status lest it be overwritten: +# then reenable echo and propagate the code. +exit_status=$? 
+onExit diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 0ced284da68f50bc24a4305dd43668268f7f09a5..efdd90c47f7c84ce5225c64f9e20e5997fa64d83 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -45,7 +45,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None) def this() = this(None, new PrintWriter(Console.out, true), None) - + var in: InteractiveReader = _ // the input stream from which commands come var settings: Settings = _ var intp: SparkIMain = _ @@ -56,16 +56,16 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Power[g.type](this, g) } */ - + // TODO // object opt extends AestheticSettings - // + // @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp - + @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i - + def history = in.history /** The context class loader at the time this object was created */ @@ -75,7 +75,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private val signallable = /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand()) else*/ null - + // classpath entries added via :cp var addedClasspath: String = "" @@ -87,10 +87,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: /** Record a command for replay should the user request a :replay */ def addReplay(cmd: String) = replayCommandStack ::= cmd - + /** Try to install sigint handler: ignore failure. Signal handler * will interrupt current line execution if any is in progress. - * + * * Attempting to protect the repl from accidental exit, we only honor * a single ctrl-C if the current buffer is empty: otherwise we look * for a second one within a short time. @@ -124,7 +124,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Thread.currentThread.setContextClassLoader(originalClassLoader) } } - + class SparkILoopInterpreter extends SparkIMain(settings, out) { override lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt @@ -135,7 +135,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: |// She's gone rogue, captain! Have to take her out! |// Calling Thread.stop on runaway %s with offending code: |// scala> %s""".stripMargin - + echo(template.format(line.thread, line.code)) // XXX no way to suppress the deprecation warning line.thread.stop() @@ -151,7 +151,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def createInterpreter() { if (addedClasspath != "") settings.classpath append addedClasspath - + intp = new SparkILoopInterpreter intp.setContextClassLoader() installSigIntHandler() @@ -168,10 +168,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private def helpSummary() = { val usageWidth = commands map (_.usageMsg.length) max val formatStr = "%-" + usageWidth + "s %s %s" - + echo("All commands can be abbreviated, e.g. :he instead of :help.") echo("Those marked with a * have more detailed help, e.g. 
:help imports.\n") - + commands foreach { cmd => val star = if (cmd.hasLongHelp) "*" else " " echo(formatStr.format(cmd.usageMsg, star, cmd.help)) @@ -182,7 +182,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case Nil => echo(cmd + ": no such command. Type :help for help.") case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") } - Result(true, None) + Result(true, None) } private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) private def uniqueCommand(cmd: String): Option[LoopCommand] = { @@ -193,31 +193,35 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case xs => xs find (_.name == cmd) } } - + /** Print a welcome message */ def printWelcome() { - echo("""Welcome to - ____ __ + val prop = System.getenv("SPARK_BANNER_TEXT") + val bannerText = + if (prop != null) prop else + """Welcome to + ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT - /_/ -""") + /_/ + """ + echo(bannerText) import Properties._ val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) + versionString, javaVmName, javaVersion) echo(welcomeMsg) } - + /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 - + def apply(line: String): Result = { if (history eq NoHistory) return "No history available." - + val xs = words(line) val current = history.index val count = try xs.head.toInt catch { case _: Exception => defaultLines } @@ -237,21 +241,21 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: out print msg out.flush() } - + /** Search the history */ def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 - + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) echo("%d %s".format(index + offset, line)) } - + private var currentPrompt = Properties.shellPromptString def setPrompt(prompt: String) = currentPrompt = prompt /** Prompt to print when awaiting input */ def prompt = currentPrompt - + import LoopCommand.{ cmd, nullary } /** Standard commands **/ @@ -273,7 +277,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: nullary("silent", "disable/enable automatic printing of results", verbosity), cmd("type", "<expr>", "display the type of an expression without evaluating it", typeCommand) ) - + /** Power user commands */ lazy val powerCommands: List[LoopCommand] = List( //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand), @@ -298,10 +302,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: |An argument of clear will remove the wrapper if any is active. |Note that wrappers do not compose (a new one replaces the old |one) and also that the :phase command uses the same machinery, - |so setting :wrap will clear any :phase setting. + |so setting :wrap will clear any :phase setting. """.stripMargin.trim) ) - + /* private def dumpCommand(): Result = { echo("" + power) @@ -309,7 +313,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: in.redrawLine() } */ - + private val typeTransforms = List( "scala.collection.immutable." -> "immutable.", "scala.collection.mutable." 
-> "mutable.", @@ -317,7 +321,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: "java.lang." -> "jl.", "scala.runtime." -> "runtime." ) - + private def importsCommand(line: String): Result = { val tokens = words(line) val handlers = intp.languageWildcardHandlers ++ intp.importHandlers @@ -333,7 +337,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - + intp.reporter.printMessage("%2d) %-30s %s%s".format( idx + 1, handler.importString, @@ -342,12 +346,12 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: )) } } - + private def implicitsCommand(line: String): Result = { val intp = SparkILoop.this.intp import intp._ import global.Symbol - + def p(x: Any) = intp.reporter.printMessage("" + x) // If an argument is given, only show a source with that @@ -360,14 +364,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else (args exists (source.name.toString contains _)) } } - + if (filtered.isEmpty) return "No implicits have been imported other than those in Predef." - + filtered foreach { case (source, syms) => p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - + // This groups the members by where the symbol is defined val byOwner = syms groupBy (_.owner) val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) } @@ -388,10 +392,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: xss map (xs => xs sortBy (_.name.toString)) } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " + + val ownerMessage = if (owner == source) " defined in " else " inherited from " p(" /* " + members.size + ownerMessage + owner.fullName + " */") - + memberGroups foreach { group => group foreach (s => p(" " + intp.symbolDefString(s))) p("") @@ -400,7 +404,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: p("") } } - + protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) { override def tryClass(path: String): Array[Byte] = { // Look for Foo first, then Foo$, but if Foo$ is given explicitly, @@ -417,20 +421,20 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private lazy val javap = try newJavap() catch { case _: Exception => null } - + private def typeCommand(line: String): Result = { intp.typeOfExpression(line) match { case Some(tp) => tp.toString case _ => "Failed to determine type." } } - + private def javapCommand(line: String): Result = { if (javap == null) return ":javap unavailable on this platform." if (line == "") return ":javap [-lcsvp] [path1 path2 ...]" - + javap(words(line)) foreach { res => if (res.isError) return "Failed: " + res.value else res.show() @@ -504,25 +508,25 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } else { val what = phased.parse(name) - if (what.isEmpty || !phased.set(what)) + if (what.isEmpty || !phased.set(what)) "'" + name + "' does not appear to represent a valid phase." 
else { intp.setExecutionWrapper(pathToPhaseWrapper) val activeMessage = if (what.toString.length == name.length) "" + what else "%s (%s)".format(what, name) - + "Active phase is now: " + activeMessage } } } */ - + /** Available commands */ def commands: List[LoopCommand] = standardCommands /* ++ ( if (isReplPower) powerCommands else Nil )*/ - + val replayQuestionMessage = """|The repl compiler has crashed spectacularly. Shall I replay your |session? I can re-run all lines except the last one. @@ -579,10 +583,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { + def interpretAllFrom(file: File) { val oldIn = in val oldReplay = replayCommandStack - + try file applyReader { reader => in = SimpleReader(reader, out, false) echo("Loading " + file + "...") @@ -604,26 +608,26 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: echo("") } } - + /** fork a shell and run a command */ lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "<command line>" def apply(line: String): Result = line match { case "" => showUsage() - case _ => + case _ => val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" intp interpret toRun () } } - + def withFile(filename: String)(action: File => Unit) { val f = File(filename) - + if (f.exists) action(f) else echo("That file does not exist") } - + def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { @@ -657,7 +661,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } else echo("The path '" + f + "' doesn't seem to exist.") } - + def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode() @@ -667,13 +671,13 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: //power.unleash() //echo(power.banner) } - + def verbosity() = { val old = intp.printResults intp.printResults = !old echo("Switched " + (if (old) "off" else "on") + " result printing.") } - + /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ @@ -688,11 +692,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else if (intp.global == null) Result(false, None) // Notice failure to create compiler else Result(true, interpretStartingWith(line)) } - + private def readWhile(cond: String => Boolean) = { Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - + def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" @@ -700,17 +704,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: intp interpret code () } - + private object paste extends Pasted { val ContinueString = " | " val PromptString = "scala> " - + def interpret(line: String): Unit = { echo(line.trim) intp interpret line echo("") } - + def transcript(start: String) = { // Printing this message doesn't work very well because it's buried in the // transcript they just pasted. 
Todo: a short timer goes off when @@ -731,7 +735,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() - + def reallyInterpret = { val reallyResult = intp.interpret(code) (reallyResult, reallyResult match { @@ -741,7 +745,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (in.interactive && code.endsWith("\n\n")) { echo("You typed two blank lines. Starting a new command.") None - } + } else in.readLine(ContinueString) match { case null => // we know compilation is going to fail since we're at EOF and the @@ -755,10 +759,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } }) } - + /** Here we place ourselves between the user and the interpreter and examine * the input they are ostensibly submitting. We intervene in several cases: - * + * * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation * on the previous result. @@ -787,7 +791,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val (code, result) = reallyInterpret //if (power != null && code == IR.Error) // runCompletion - + result } else runCompletion match { @@ -808,7 +812,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } case _ => } - + /** Tries to create a JLineReader, falling back to SimpleReader: * unless settings or properties are such that it should start * with SimpleReader. @@ -837,6 +841,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: org.apache.spark.repl.Main.interp.out.flush(); """) command("import org.apache.spark.SparkContext._") + val prop = System.getenv("SPARK_SHELL_INIT_BLOCK") + if (prop != null) { + command(prop) + } } echo("Type in expressions to have them evaluated.") echo("Type :help for more information.") @@ -884,7 +892,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: this.settings = settings createInterpreter() - + // sets in to some kind of reader depending on environmental cues in = in0 match { case Some(reader) => SimpleReader(reader, out, true) @@ -895,10 +903,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: // it is broken on startup; go ahead and exit if (intp.reporter.hasErrors) return false - - try { + + try { // this is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try + // which spins off a separate thread, then print the prompt and try // our best to look ready. Ideally the user will spend a // couple seconds saying "wow, it starts so fast!" and by the time // they type a command the compiler is ready to roll. 
@@ -920,19 +928,19 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def neededHelp(): String = (if (command.settings.help.value) command.usageMsg + "\n" else "") + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - + // if they asked for no help and command is valid, we call the real main neededHelp() match { case "" => command.ok && process(command.settings) case help => echoNoNL(help) ; true } } - + @deprecated("Use `process` instead", "2.9.0") def main(args: Array[String]): Unit = { if (isReplDebug) System.out.println(new java.util.Date) - + process(args) } @deprecated("Use `process` instead", "2.9.0") @@ -948,7 +956,7 @@ object SparkILoop { // like if you'd just typed it into the repl. def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { val output = new PrintWriter(new OutputStreamWriter(ostream), true) { @@ -977,19 +985,19 @@ object SparkILoop { } } } - + /** Creates an interpreter loop with default settings and feeds * the given code to it as input. */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { val input = new BufferedReader(new StringReader(code)) val output = new PrintWriter(new OutputStreamWriter(ostream), true) val repl = new SparkILoop(input, output) - + if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") @@ -1017,7 +1025,7 @@ object SparkILoop { repl.settings.embeddedDefaults[T] repl.createInterpreter() repl.in = SparkJLineReader(repl) - + // rebind exit so people don't accidentally call sys.exit by way of predef repl.quietRun("""def exit = println("Type :quit to resume program execution.")""") args foreach (p => repl.bind(p.name, p.tpe, p.value)) @@ -1025,5 +1033,5 @@ object SparkILoop { echo("\nDebug repl exiting.") repl.closeInterpreter() - } + } }
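
The algorithms touched by this patch (PageRank, delta PageRank, connected components) all use the new `sendMsg: EdgeTriplet[VD, ED] => Array[(Vid, A)]` signature in place of the old per-vertex `Option[A]` form. The sketch below is not part of the patch: `shortestPaths` and its parameters are purely illustrative, and it assumes only the `Graph`, `EdgeTriplet`, and dynamic `Pregel` definitions shown in this diff. It shows how an independent algorithm, single-source shortest paths, would be written against the revised API.

    import org.apache.spark.graph._

    /**
     * Illustrative only: single-source shortest paths over non-negative
     * Double edge weights, using the new sendMsg signature.
     */
    def shortestPaths[VD: Manifest](graph: Graph[VD, Double], sourceId: Vid):
      Graph[Double, Double] = {
      // Start the source at distance 0.0 and every other vertex at infinity.
      val spGraph = graph.mapVertices { (vid, _) =>
        if (vid == sourceId) 0.0 else Double.PositiveInfinity
      }

      // Keep the smaller of the current distance and the best incoming offer.
      def vertexProgram(id: Vid, dist: Double, msg: Double): Double =
        math.min(dist, msg)

      // New-style sendMsg: one call per triplet, returning explicit
      // (target, message) pairs instead of an Option keyed by vertex id.
      def sendMessage(edge: EdgeTriplet[Double, Double]): Array[(Vid, Double)] = {
        if (edge.srcAttr + edge.attr < edge.dstAttr) {
          Array((edge.dstId, edge.srcAttr + edge.attr))
        } else {
          Array.empty[(Vid, Double)]
        }
      }

      def messageCombiner(a: Double, b: Double): Double = math.min(a, b)

      // The initial message is absorbed by min(), so infinity is a safe default.
      val initialMessage = Double.PositiveInfinity

      // Run the dynamic variant of Pregel until no more messages are sent.
      Pregel(spGraph, initialMessage)(vertexProgram, sendMessage, messageCombiner)
    }

Because messages are emitted only for edges that can still improve the destination's distance, the dynamic Pregel loop above terminates as soon as `activeMessages` reaches zero.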