From e601b3b9e56d6ce978c09506ca07fd3e252e4673 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Tue, 17 Apr 2012 16:40:56 -0700
Subject: [PATCH] Added the ability to set environmental variables in piped
 rdd.

---
 core/src/main/scala/spark/PipedRDD.scala      | 14 ++++++-
 core/src/main/scala/spark/RDD.scala           |  5 ++-
 core/src/test/scala/spark/PipedRDDSuite.scala | 37 +++++++++++++++++++
 3 files changed, 52 insertions(+), 4 deletions(-)
 create mode 100644 core/src/test/scala/spark/PipedRDDSuite.scala

diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 3f993d895a..8a5de3d7e9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
 import java.io.PrintWriter
 import java.util.StringTokenizer
 
+import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
@@ -10,8 +11,12 @@ import scala.io.Source
  * An RDD that pipes the contents of each parent partition through an external command
  * (printing them one per line) and returns the output as a collection of strings.
  */
-class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
+class PipedRDD[T: ClassManifest](
+    parent: RDD[T], command: Seq[String], envVars: Map[String, String])
   extends RDD[String](parent.context) {
+
+  def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
   override val dependencies = List(new OneToOneDependency(parent))
 
   override def compute(split: Split): Iterator[String] = {
-    val proc = Runtime.getRuntime.exec(command.toArray)
+    val pb = new ProcessBuilder(command)
+    // Add the environmental variables to the process.
+    val currentEnvVars = pb.environment()
+    envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
+
+    val proc = pb.start()
     val env = SparkEnv.get
 
     // Start a thread to print the process's stderr to ours
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1160de5fd1..7fe6633f1b 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -9,8 +9,6 @@ import java.util.Random
 import java.util.Date
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def pipe(command: Seq[String]): RDD[String] =
     new PipedRDD(this, command)
 
+  def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
+    new PipedRDD(this, command, env)
+
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
 
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000000..d5dc2efd91
--- /dev/null
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -0,0 +1,37 @@
+package spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+class PipedRDDSuite extends FunSuite {
+
+  test("basic pipe") {
+    val sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val c = piped.collect()
+    println(c.toSeq)
+    assert(c.size === 4)
+    assert(c(0) === "1")
+    assert(c(1) === "2")
+    assert(c(2) === "3")
+    assert(c(3) === "4")
+    sc.stop()
+  }
+
+  test("pipe with env variable") {
+    val sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+    val c = piped.collect()
+    assert(c.size === 2)
+    assert(c(0) === "LALALA")
+    assert(c(1) === "LALALA")
+    sc.stop()
+  }
+
+}
+
+
--
GitLab
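
Usage note (not part of the original patch): the sketch below shows how the two-argument pipe overload introduced by this patch could be called from a driver program. It is a minimal, hypothetical example; the object name PipeWithEnvExample, the GREETING variable, and the sample data are illustrative only, and it assumes a POSIX printenv binary is available, mirroring the test suite above.

  import spark.SparkContext

  object PipeWithEnvExample {
    def main(args: Array[String]) {
      val sc = new SparkContext("local", "pipe-with-env-example")
      // Two partitions, so the external command is launched once per partition.
      val words = sc.makeRDD(Seq("a", "b", "c", "d"), 2)
      // Each element is written to the command's stdin, one per line;
      // GREETING is injected into the child process's environment via
      // ProcessBuilder.environment() inside PipedRDD.compute().
      val piped = words.pipe(Seq("printenv", "GREETING"), Map("GREETING" -> "hello"))
      // Expect one "hello" line per partition, since printenv ignores stdin.
      piped.collect().foreach(println)
      sc.stop()
    }
  }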