From c1cc8c4da239965e8ad478089b27e9c694088978 Mon Sep 17 00:00:00 2001
From: Aaron Davidson <aaron@databricks.com>
Date: Sat, 7 Sep 2013 14:41:31 -0700
Subject: [PATCH] Export StorageLevel and refactor

---
 .../apache/spark/storage/StorageLevel.scala   |  6 +--
 python/pyspark/__init__.py                    |  5 ++-
 python/pyspark/context.py                     | 35 ++++++---------
 python/pyspark/rdd.py                         |  3 +-
 python/pyspark/shell.py                       |  2 +-
 python/pyspark/storagelevel.py                | 43 +++++++++++++++++++
 6 files changed, 65 insertions(+), 29 deletions(-)
 create mode 100644 python/pyspark/storagelevel.py

diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 755f1a760e..632ff047d1 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -23,9 +23,9 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
  * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
  * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
  * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
- * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels. To create your own storage level object, use the factor method
- * of the singleton object (`StorageLevel(...)`).
+ * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants
+ * for commonly useful storage levels. To create your own storage level object, use the
+ * factory method of the singleton object (`StorageLevel(...)`).
  */
 class StorageLevel private(
     private var useDisk_ : Boolean,
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index fd5972d381..1f35f6f939 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -30,6 +30,8 @@ Public classes:
         An "add-only" shared variable that tasks can only add values to.
     - L{SparkFiles<pyspark.files.SparkFiles>}
         Access files shipped with jobs.
+    - L{StorageLevel<pyspark.storagelevel.StorageLevel>}
+        Finer-grained cache persistence levels.
 """
 import sys
 import os
@@ -39,6 +41,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
+from pyspark.storagelevel import StorageLevel
 
 
-__all__ = ["SparkContext", "RDD", "SparkFiles"]
+__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4c48cd3f37..efd7828df6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,6 +27,7 @@ from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -119,29 +120,6 @@ class SparkContext(object):
         self._temp_dir = \
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
-        self._initStorageLevel()
-
-    def _initStorageLevel(self):
-        """
-        Initializes the StorageLevel object, which mimics the behavior of the scala object
-        by the same name. e.g., StorageLevel.DISK_ONLY returns the equivalent Java StorageLevel.
-        """
-        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
-        levels = {
-            'NONE': newStorageLevel(False, False, False, 1),
-            'DISK_ONLY': newStorageLevel(True, False, False, 1),
-            'DISK_ONLY_2': newStorageLevel(True, False, False, 2),
-            'MEMORY_ONLY': newStorageLevel(False, True, True, 1),
-            'MEMORY_ONLY_2': newStorageLevel(False, True, True, 2),
-            'MEMORY_ONLY_SER': newStorageLevel(False, True, False, 1),
-            'MEMORY_ONLY_SER_2': newStorageLevel(False, True, False, 2),
-            'MEMORY_AND_DISK': newStorageLevel(True, True, True, 1),
-            'MEMORY_AND_DISK_2': newStorageLevel(True, True, True, 2),
-            'MEMORY_AND_DISK_SER': newStorageLevel(True, True, False, 1),
-            'MEMORY_AND_DISK_SER_2': newStorageLevel(True, True, False, 2),
-        }
-        self.StorageLevel = type('StorageLevel', (), levels)
-
     @property
     def defaultParallelism(self):
         """
@@ -303,6 +281,17 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+    def _getJavaStorageLevel(self, storageLevel):
+        """
+        Returns a Java StorageLevel based on a pyspark.StorageLevel.
+        """
+        if not isinstance(storageLevel, StorageLevel):
+            raise Exception("storageLevel must be of type pyspark.StorageLevel")
+
+        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
+        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
+            storageLevel.deserialized, storageLevel.replication)
+
 def _test():
     import atexit
     import doctest
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 332258f5d1..58e1849cad 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -77,7 +77,8 @@ class RDD(object):
         have a storage level set yet.
         """
         self.is_cached = True
-        self._jrdd.persist(storageLevel)
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jrdd.persist(javaStorageLevel)
         return self
 
     def unpersist(self):
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index e374ca4ee4..dc205b306f 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -24,12 +24,12 @@ import os
 import platform
 import pyspark
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
 
 sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
-StorageLevel = sc.StorageLevel # alias StorageLevel to global scope
 
 print """Welcome to
       ____              __
diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
new file mode 100644
index 0000000000..b31f4762e6
--- /dev/null
+++ b/python/pyspark/storagelevel.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ["StorageLevel"]
+
+class StorageLevel:
+    """
+    Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
+    whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
+    in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+    Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY.
+    """
+
+    def __init__(self, useDisk, useMemory, deserialized, replication = 1):
+        self.useDisk = useDisk
+        self.useMemory = useMemory
+        self.deserialized = deserialized
+        self.replication = replication
+
+StorageLevel.DISK_ONLY = StorageLevel(True, False, False)
+StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2)
+StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True)
+StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, True, 2)
+StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False)
+StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, 2)
+StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, True)
+StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, True, 2)
+StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False)
+StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, 2)
-- 
GitLab