From 7db56244fa3dba92246bad6694f31bbf68ea47ec Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Tue, 5 Apr 2016 12:19:20 +0900
Subject: [PATCH] [SPARK-14368][PYSPARK] Support python.spark.worker.memory
 with upper-case unit.

## What changes were proposed in this pull request?

This fix tries to address the issue in PySpark where `spark.python.worker.memory`
could only be configured with a lower case unit (`k`, `m`, `g`, `t`). This fix
allows the upper case unit (`K`, `M`, `G`, `T`) to be used as well. This is to
conform to the JVM memory string as is specified in the documentation .

## How was this patch tested?

This fix adds additional test to cover the changes.

Author: Yong Tang <yong.tang.github@outlook.com>

Closes #12163 from yongtang/SPARK-14368.
---
 python/pyspark/rdd.py   |  2 +-
 python/pyspark/tests.py | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cd1f64e8aa..8978f028c5 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -115,7 +115,7 @@ def _parse_memory(s):
     2048
     """
     units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
-    if s[-1] not in units:
+    if s[-1].lower() not in units:
         raise ValueError("invalid format: " + s)
     return int(float(s[:-1]) * units[s[-1].lower()])
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a5a83c7e38..40fccb8c00 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1966,6 +1966,18 @@ class ContextTests(unittest.TestCase):
             self.assertGreater(sc.startTime, 0)
 
 
+class ConfTests(unittest.TestCase):
+    def test_memory_conf(self):
+        memoryList = ["1T", "1G", "1M", "1024K"]
+        for memory in memoryList:
+            sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
+            l = list(range(1024))
+            random.shuffle(l)
+            rdd = sc.parallelize(l, 4)
+            self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
+            sc.stop()
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
 
-- 
GitLab