Commit e601b3b9 authored by Reynold Xin

Added the ability to set environment variables in PipedRDD.

parent 3b745176
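For context, here is how the new overload is used. This is a minimal sketch based on the test suite at the end of this commit; `sc` is assumed to be an existing SparkContext:

// Pipe each partition through printenv, injecting an extra variable into
// the child process's environment; printenv emits one line per partition.
val piped = sc.makeRDD(Array(1, 2, 3, 4), 2)
  .pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
println(piped.collect().toSeq)  // prints something like WrappedArray(LALALA, LALALA)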
PipedRDD.scala

@@ -3,6 +3,7 @@ package spark
 import java.io.PrintWriter
 import java.util.StringTokenizer

+import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
@@ -10,8 +11,12 @@ import scala.io.Source
  * An RDD that pipes the contents of each parent partition through an external command
  * (printing them one per line) and returns the output as a collection of strings.
  */
-class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
+class PipedRDD[T: ClassManifest](
+    parent: RDD[T], command: Seq[String], envVars: Map[String, String])
   extends RDD[String](parent.context) {
+
+  def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
   override val dependencies = List(new OneToOneDependency(parent))

   override def compute(split: Split): Iterator[String] = {
-    val proc = Runtime.getRuntime.exec(command.toArray)
+    val pb = new ProcessBuilder(command)
+    // Add the environmental variables to the process.
+    val currentEnvVars = pb.environment()
+    envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
+
+    val proc = pb.start()
     val env = SparkEnv.get

     // Start a thread to print the process's stderr to ours
...
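The heart of the change above is swapping Runtime.exec for java.lang.ProcessBuilder: its environment() method returns a mutable map, initialized from the parent's environment, that the spawned process inherits. A standalone sketch of the same pattern outside Spark (the implicit Seq-to-java.util.List conversion for the ProcessBuilder constructor comes from the JavaConversions import added above; "GREETING" is just an example variable):

import scala.collection.JavaConversions._

// Build the child process and extend the environment it will inherit.
val pb = new ProcessBuilder(Seq("printenv", "GREETING"))
pb.environment().put("GREETING", "hello")
val proc = pb.start()
// Read the child's stdout; this should print "hello".
scala.io.Source.fromInputStream(proc.getInputStream).getLines().foreach(println)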
RDD.scala

@@ -9,8 +9,6 @@ import java.util.Random
 import java.util.Date

 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap

 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)

+  def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
+    new PipedRDD(this, command, env)
+
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
...
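A note on the two removed imports: with scala.collection.mutable.Map and HashMap out of scope, the Map in the new pipe signature resolves to the default immutable scala.Predef.Map, which is what PipedRDD's envVars parameter expects; presumably that is why they are dropped here. In caller code this is simply (`nums` as in the test suite below):

// Map here is the default immutable Map, matching PipedRDD's envVars parameter.
val env: Map[String, String] = Map("MY_TEST_ENV" -> "LALALA")
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), env)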
PipedRDDSuite.scala (new file)

package spark

import org.scalatest.FunSuite
import SparkContext._

class PipedRDDSuite extends FunSuite {

  test("basic pipe") {
    val sc = new SparkContext("local", "test")
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // Pipe each partition's elements through `cat`, one element per line.
    val piped = nums.pipe(Seq("cat"))

    val c = piped.collect()
    println(c.toSeq)
    assert(c.size === 4)
    assert(c(0) === "1")
    assert(c(1) === "2")
    assert(c(2) === "3")
    assert(c(3) === "4")
    sc.stop()
  }

  test("pipe with env variable") {
    val sc = new SparkContext("local", "test")
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // printenv runs once per partition (two partitions here), so we expect
    // two output lines, each echoing the injected variable's value.
    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
    val c = piped.collect()
    assert(c.size === 2)
    assert(c(0) === "LALALA")
    assert(c(1) === "LALALA")
    sc.stop()
  }
}