diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ccd383396460254d5ba433d0c654004f8a5d5a23..6ca56b3af63f38054a27b0a2603a2c14e653b59d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -28,6 +28,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PipedRDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -270,6 +271,16 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
+  /**
+   * Returns the StorageLevel with the given string name.
+   * Throws an exception if the name is not a valid StorageLevel.
+   */
+  def getStorageLevel(name: String) : StorageLevel = {
+    // In Scala, "val MEMORY_ONLY" produces a public getter by the same name.
+    val storageLevelGetter = StorageLevel.getClass().getDeclaredMethod(name)
+    return storageLevelGetter.invoke(StorageLevel).asInstanceOf[StorageLevel]
+  }
+
   def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
     writeIteratorToPickleFile(items.asScala, filename)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8fbf2965097d925fb62d835e89945fa17bfeaa4f..49f9b4610d4223445c0a5c57163ffffe323b9628 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -279,6 +279,20 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+class StorageLevelReader:
+    """
+    Mimics the Scala StorageLevel by directing all attribute requests
+    (e.g., StorageLevel.DISK_ONLY) to the JVM for reflection.
+    """
+
+    def __init__(self, sc):
+        self.sc = sc
+
+    def __getattr__(self, name):
+        try:
+            return self.sc._jvm.PythonRDD.getStorageLevel(name)
+        except:
+            print "Failed to find StorageLevel:", name
 
 def _test():
     import atexit
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 914118ccdd92b7560514e13e118b18dbd947659c..332258f5d1064ce1d71e5f9c803079ff595cdeae 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -70,6 +70,24 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def persist(self, storageLevel):
+        """
+        Set this RDD's storage level to persist its values across operations after the first time
+        it is computed. This can only be used to assign a new storage level if the RDD does not
+        have a storage level set yet.
+        """
+        self.is_cached = True
+        self._jrdd.persist(storageLevel)
+        return self
+
+    def unpersist(self):
+        """
+        Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+        """
+        self.is_cached = False
+        self._jrdd.unpersist()
+        return self
+
     def checkpoint(self):
         """
         Mark this RDD for checkpointing. It will be saved to a file inside the
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 54823f80378787ff5c73c78f6b8ab8d336d52975..9acc176d550091816c6f845a63684672f1e9c81b 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -23,12 +23,13 @@ This file is designed to be launched as a PYTHONSTARTUP script.
 import os
 import platform
 import pyspark
-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, StorageLevelReader
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
 
 sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
+StorageLevel = StorageLevelReader(sc)
 
 print """Welcome to
       ____              __
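
For context, a minimal usage sketch of how the pieces above fit together in the pyspark shell started by shell.py. This is not part of the patch: it assumes `sc` and the StorageLevelReader-backed `StorageLevel` defined above are already in scope (Python 2 era), and the RDD contents and storage level choice are purely illustrative.

    # Hypothetical example, not part of the diff above.
    rdd = sc.parallelize(range(100))

    # StorageLevel.DISK_ONLY goes through StorageLevelReader.__getattr__, which
    # calls PythonRDD.getStorageLevel("DISK_ONLY") on the JVM via reflection and
    # returns the Scala StorageLevel object for Py4J to pass to persist().
    rdd.persist(StorageLevel.DISK_ONLY)

    rdd.count()      # first action computes the RDD and stores its blocks on disk
    rdd.count()      # later actions reuse the persisted blocks

    rdd.unpersist()  # drops the blocks and marks the RDD as not cached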