From 9b6de6fbc00b184d81fc28ac160d03451fad80ec Mon Sep 17 00:00:00 2001
From: Bill Bejeck <bbejeck@gmail.com>
Date: Tue, 14 Oct 2014 12:12:38 -0700
Subject: [PATCH] SPARK-3178  setting SPARK_WORKER_MEMORY to a value without a
 label (m or g) sets the worker memory limit to zero

Validate the memory is greater than zero when set from the SPARK_WORKER_MEMORY environment variable or command line without a g or m label.  Added unit tests. If memory is 0 an IllegalStateException is thrown. Updated unit tests to mock environment variables by subclassing SparkConf (tip provided by Josh Rosen).   Updated WorkerArguments to use SparkConf.getenv instead of System.getenv for reading the SPARK_WORKER_MEMORY environment variable.

Author: Bill Bejeck <bbejeck@gmail.com>

Closes #2309 from bbejeck/spark-memory-worker and squashes the following commits:

51cf915 [Bill Bejeck] SPARK-3178 - Validate the memory is greater than zero when set from the SPARK_WORKER_MEMORY environment variable or command line without a g or m label.  Added unit tests. If memory is 0 an IllegalStateException is thrown. Updated unit tests to mock environment variables by subclassing SparkConf (tip provided by Josh Rosen).   Updated WorkerArguments to use SparkConf.getenv instead of System.getenv for reading the SPARK_WORKER_MEMORY environment variable.
---
 .../spark/deploy/worker/WorkerArguments.scala | 13 ++-
 .../deploy/worker/WorkerArgumentsTest.scala   | 82 +++++++++++++++++++
 2 files changed, 93 insertions(+), 2 deletions(-)
 create mode 100644 core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala

diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 1e295aaa48..54e3937edd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -41,8 +41,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
   if (System.getenv("SPARK_WORKER_CORES") != null) {
     cores = System.getenv("SPARK_WORKER_CORES").toInt
   }
-  if (System.getenv("SPARK_WORKER_MEMORY") != null) {
-    memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
+  if (conf.getenv("SPARK_WORKER_MEMORY") != null) {
+    memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY"))
   }
   if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
     webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
@@ -56,6 +56,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
 
   parse(args.toList)
 
+  checkWorkerMemory()
+
   def parse(args: List[String]): Unit = args match {
     case ("--ip" | "-i") :: value :: tail =>
       Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -153,4 +155,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
     // Leave out 1 GB for the operating system, but don't return a negative memory size
     math.max(totalMb - 1024, 512)
   }
+
+  def checkWorkerMemory(): Unit = {
+    if (memory <= 0) {
+      val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+      throw new IllegalStateException(message)
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
new file mode 100644
index 0000000000..1a28a9a187
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+
+class WorkerArgumentsTest extends FunSuite {
+
+  test("Memory can't be set to 0 when cmd line args leave off M or G") {
+    val conf = new SparkConf
+    val args = Array("-m", "10000", "spark://localhost:0000  ")
+    intercept[IllegalStateException] {
+      new WorkerArguments(args, conf)
+    }
+  }
+
+
+  test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") {
+    val args = Array("spark://localhost:0000  ")
+
+    class MySparkConf extends SparkConf(false) {
+      override def getenv(name: String) = {
+        if (name == "SPARK_WORKER_MEMORY") "50000"
+        else super.getenv(name)
+      }
+
+      override def clone: SparkConf = {
+        new MySparkConf().setAll(settings)
+      }
+    }
+    val conf = new MySparkConf()
+    intercept[IllegalStateException] {
+      new WorkerArguments(args, conf)
+    }
+  }
+
+  test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") {
+    val args = Array("spark://localhost:0000  ")
+
+    class MySparkConf extends SparkConf(false) {
+      override def getenv(name: String) = {
+        if (name == "SPARK_WORKER_MEMORY") "5G"
+        else super.getenv(name)
+      }
+
+      override def clone: SparkConf = {
+        new MySparkConf().setAll(settings)
+      }
+    }
+    val conf = new MySparkConf()
+    val workerArgs =  new WorkerArguments(args, conf)
+    assert(workerArgs.memory === 5120)
+  }
+
+  test("Memory correctly set from args with M appended to memory value") {
+    val conf = new SparkConf
+    val args = Array("-m", "10000M", "spark://localhost:0000  ")
+
+    val workerArgs = new WorkerArguments(args, conf)
+    assert(workerArgs.memory === 10000)
+
+  }
+
+}
-- 
GitLab