Commit a63d4c7d authored by Aaron Davidson

SPARK-660: Add StorageLevel support in Python

It uses reflection... I am not proud of that fact, but it at least ensures
compatibility (sans refactoring of the StorageLevel stuff).
parent 714e7f9e
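
For orientation, a rough sketch of what this change enables from the PySpark shell once the pieces below are wired together (a hypothetical session; it assumes the running SparkContext `sc` and the reflective `StorageLevel` object created by the shell startup script at the bottom of this diff):

    # Hypothetical PySpark session; `sc` and `StorageLevel` come from the shell setup below.
    rdd = sc.parallelize(range(100))
    rdd.persist(StorageLevel.DISK_ONLY)   # attribute lookup is answered by JVM reflection
    rdd.count()                           # first action computes and persists the partitions
    rdd.unpersist()                       # drops the persisted blocks again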
@@ -28,6 +28,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PipedRDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -270,6 +271,16 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
+  /**
+   * Returns the StorageLevel with the given string name.
+   * Throws an exception if the name is not a valid StorageLevel.
+   */
+  def getStorageLevel(name: String): StorageLevel = {
+    // In Scala, "val MEMORY_ONLY" produces a public getter by the same name.
+    val storageLevelGetter = StorageLevel.getClass().getDeclaredMethod(name)
+    return storageLevelGetter.invoke(StorageLevel).asInstanceOf[StorageLevel]
+  }
+
   def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
     writeIteratorToPickleFile(items.asScala, filename)
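The Scala helper above is what the Python side reaches through the Py4J gateway; a minimal sketch of that call path, assuming an initialized SparkContext `sc` whose `_jvm` field exposes the gateway (the same access pattern `StorageLevelReader` uses below):

    # Sketch only: resolve a JVM StorageLevel by name through the gateway,
    # mirroring what StorageLevelReader.__getattr__ does further down.
    level = sc._jvm.PythonRDD.getStorageLevel("MEMORY_ONLY")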
@@ -279,6 +279,20 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+class StorageLevelReader:
+    """
+    Mimics the Scala StorageLevel by directing all attribute requests
+    (e.g., StorageLevel.DISK_ONLY) to the JVM for reflection.
+    """
+
+    def __init__(self, sc):
+        self.sc = sc
+
+    def __getattr__(self, name):
+        try:
+            return self.sc._jvm.PythonRDD.getStorageLevel(name)
+        except:
+            print "Failed to find StorageLevel:", name
 
 def _test():
     import atexit
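A short, hypothetical illustration of the reader's behaviour: every attribute access is forwarded by name, so any level defined on the Scala StorageLevel object resolves without the Python side having to enumerate the levels (the names used here are assumed to exist on the JVM side):

    levels = StorageLevelReader(sc)        # wrap the active SparkContext
    disk_only = levels.DISK_ONLY           # __getattr__ -> PythonRDD.getStorageLevel("DISK_ONLY")
    mem_and_disk = levels.MEMORY_AND_DISK  # resolved the same way, purely by name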
@@ -70,6 +70,24 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def persist(self, storageLevel):
+        """
+        Set this RDD's storage level to persist its values across operations after the first time
+        it is computed. This can only be used to assign a new storage level if the RDD does not
+        have a storage level set yet.
+        """
+        self.is_cached = True
+        self._jrdd.persist(storageLevel)
+        return self
+
+    def unpersist(self):
+        """
+        Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+        """
+        self.is_cached = False
+        self._jrdd.unpersist()
+        return self
+
     def checkpoint(self):
         """
         Mark this RDD for checkpointing. It will be saved to a file inside the
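The restriction noted in the persist() docstring is enforced on the JVM side; a hedged sketch of the resulting behaviour (the exact error surfaced through Py4J is an assumption, not something shown in this commit):

    r = sc.parallelize([1, 2, 3])
    r.persist(StorageLevel.DISK_ONLY)     # first assignment of a storage level is fine
    r.collect()                           # materializes and persists the partitions
    # Re-persisting with a different level is rejected on the Scala side and would
    # surface as a Py4J error here (assumed behaviour):
    # r.persist(StorageLevel.MEMORY_ONLY)
    r.unpersist()                         # removes the blocks from memory and disk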
@@ -23,12 +23,13 @@ This file is designed to be launched as a PYTHONSTARTUP script.
 import os
 import platform
 import pyspark
-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, StorageLevelReader
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
 
 sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
+StorageLevel = StorageLevelReader(sc)
 
 print """Welcome to
       ____              __