From e601b3b9e56d6ce978c09506ca07fd3e252e4673 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Tue, 17 Apr 2012 16:40:56 -0700
Subject: [PATCH] Added the ability to set environment variables in piped
 rdd.

---
 core/src/main/scala/spark/PipedRDD.scala      | 14 ++++++-
 core/src/main/scala/spark/RDD.scala           |  5 ++-
 core/src/test/scala/spark/PipedRDDSuite.scala | 37 +++++++++++++++++++
 3 files changed, 52 insertions(+), 4 deletions(-)
 create mode 100644 core/src/test/scala/spark/PipedRDDSuite.scala

diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 3f993d895a..8a5de3d7e9 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
 import java.io.PrintWriter
 import java.util.StringTokenizer
 
+import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
@@ -10,8 +11,12 @@ import scala.io.Source
  * An RDD that pipes the contents of each parent partition through an external command
  * (printing them one per line) and returns the output as a collection of strings.
  */
-class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
+class PipedRDD[T: ClassManifest](
+  parent: RDD[T], command: Seq[String], envVars: Map[String, String])
   extends RDD[String](parent.context) {
+
+  def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
@@ -21,7 +26,12 @@ class PipedRDD[T: ClassManifest](parent: RDD[T], command: Seq[String])
   override val dependencies = List(new OneToOneDependency(parent))
 
   override def compute(split: Split): Iterator[String] = {
-    val proc = Runtime.getRuntime.exec(command.toArray)
+    val pb = new ProcessBuilder(command)
+    // Add the environment variables to the process.
+    val currentEnvVars = pb.environment()
+    envVars.foreach { case(variable, value) => currentEnvVars.put(variable, value) }
+    
+    val proc = pb.start()
     val env = SparkEnv.get
 
     // Start a thread to print the process's stderr to ours
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1160de5fd1..7fe6633f1b 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -9,8 +9,6 @@ import java.util.Random
 import java.util.Date
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -146,6 +144,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
 
+  def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
+    new PipedRDD(this, command, env)
+
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
 
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
new file mode 100644
index 0000000000..d5dc2efd91
--- /dev/null
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -0,0 +1,37 @@
+package spark
+
+import org.scalatest.FunSuite
+import SparkContext._
+
+class PipedRDDSuite extends FunSuite {
+
+  test("basic pipe") {
+    val sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val c = piped.collect()
+    println(c.toSeq)
+    assert(c.size === 4)
+    assert(c(0) === "1")
+    assert(c(1) === "2")
+    assert(c(2) === "3")
+    assert(c(3) === "4")
+    sc.stop()
+  }
+
+  test("pipe with env variable") {
+    val sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+    val c = piped.collect()
+    assert(c.size === 2)
+    assert(c(0) === "LALALA")
+    assert(c(1) === "LALALA")
+    sc.stop()
+  }
+
+}
+
+
-- 
GitLab