diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 755f1a760ee051dfd9c94ce6efb5d7765609060d..632ff047d10428dbcc1ab055e2d23bf1a5753b66 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -23,9 +23,9 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
  * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
  * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
  * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
- * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels. To create your own storage level object, use the factor method
- * of the singleton object (`StorageLevel(...)`).
+ * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants
+ * for commonly useful storage levels. To create your own storage level object, use the
+ * factory method of the singleton object (`StorageLevel(...)`).
  */
 class StorageLevel private(
     private var useDisk_ : Boolean,
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index fd5972d381a80c7a118ede5201c06eda513318b5..1f35f6f939d8e70c16ec3ebaff5d0b70df59c7bc 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -30,6 +30,8 @@ Public classes:
       An "add-only" shared variable that tasks can only add values to.
     - L{SparkFiles<pyspark.files.SparkFiles>}
       Access files shipped with jobs.
+    - L{StorageLevel<pyspark.storagelevel.StorageLevel>}
+      Finer-grained cache persistence levels.
 """
 import sys
 import os
@@ -39,6 +41,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
+from pyspark.storagelevel import StorageLevel
 
 
-__all__ = ["SparkContext", "RDD", "SparkFiles"]
+__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8fbf2965097d925fb62d835e89945fa17bfeaa4f..597110321a86370c29063052eb892f9213a4bfb3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,6 +27,7 @@ from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -279,6 +280,16 @@ class SparkContext(object):
         """
         self._jsc.sc().setCheckpointDir(dirName, useExisting)
 
+    def _getJavaStorageLevel(self, storageLevel):
+        """
+        Returns a Java StorageLevel based on a pyspark.StorageLevel.
+        """
+        if not isinstance(storageLevel, StorageLevel):
+            raise Exception("storageLevel must be of type pyspark.StorageLevel")
+
+        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
+        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
+                               storageLevel.deserialized, storageLevel.replication)
 
 def _test():
     import atexit
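For reviewers, a minimal sketch of how the new helper behaves (illustrative only: `_getJavaStorageLevel` is a private method that `RDD.persist()` calls on your behalf, and this assumes a working local Spark install). The four Python flags are forwarded through Py4J to build the JVM-side `org.apache.spark.storage.StorageLevel`:

```python
from pyspark import SparkContext
from pyspark.storagelevel import StorageLevel

sc = SparkContext("local", "storage-level-demo")

# Internal helper; RDD.persist() normally invokes this for you.
java_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK)
print(java_level.toString())  # prints something like: StorageLevel(true, true, true, 1)
```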
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 914118ccdd92b7560514e13e118b18dbd947659c..58e1849cadac8611c0a4e85681cc67ea6bb7ca25 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -70,6 +70,25 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def persist(self, storageLevel):
+        """
+        Set this RDD's storage level to persist its values across operations after the first time
+        it is computed. This can only be used to assign a new storage level if the RDD does not
+        have a storage level set yet.
+        """
+        self.is_cached = True
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jrdd.persist(javaStorageLevel)
+        return self
+
+    def unpersist(self):
+        """
+        Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+        """
+        self.is_cached = False
+        self._jrdd.unpersist()
+        return self
+
     def checkpoint(self):
         """
         Mark this RDD for checkpointing. It will be saved to a file inside the
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 54823f80378787ff5c73c78f6b8ab8d336d52975..dc205b306f0a93d398d4b18c95da6f888a93ec6e 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -24,6 +24,7 @@ import os
 import platform
 import pyspark
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 
 # this is the equivalent of ADD_JARS
 add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
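A short usage sketch of the new `persist`/`unpersist` pair (hypothetical script; assumes a local master). The first action computes and stores the partitions; later actions read the persisted blocks:

```python
from pyspark import SparkContext
from pyspark.storagelevel import StorageLevel

sc = SparkContext("local", "persist-demo")
squares = sc.parallelize(range(1000)).map(lambda x: x * x)

# Pick an explicit level; cache() is equivalent to a memory-only persist.
squares.persist(StorageLevel.MEMORY_AND_DISK)
print(squares.count())  # computes the RDD and stores its blocks
print(squares.count())  # answered from the persisted blocks

squares.unpersist()     # drops the blocks from memory and disk
```

As the docstring notes, `persist()` only assigns a level to an RDD that does not have one yet; to change levels, call `unpersist()` first.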
diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31f4762e69bc42cad009183004c0349fcf5e798
--- /dev/null
+++ b/python/pyspark/storagelevel.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ["StorageLevel"]
+
+class StorageLevel:
+    """
+    Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
+    whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
+    in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+    Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY.
+    """
+
+    def __init__(self, useDisk, useMemory, deserialized, replication = 1):
+        self.useDisk = useDisk
+        self.useMemory = useMemory
+        self.deserialized = deserialized
+        self.replication = replication
+
+StorageLevel.DISK_ONLY = StorageLevel(True, False, False)
+StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2)
+StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True)
+StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, True, 2)
+StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False)
+StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, 2)
+StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, True)
+StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, True, 2)
+StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False)
+StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, 2)
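Finally, an illustrative check of how the named constants relate to the constructor flags (the `_SER` variants clear `deserialized`, the `_2` variants set `replication=2`):

```python
from pyspark.storagelevel import StorageLevel

# Hand-built level equivalent to MEMORY_AND_DISK_SER_2: spill to disk, keep
# the in-memory copy serialized, and replicate each partition on two nodes.
custom = StorageLevel(True, True, False, replication=2)
preset = StorageLevel.MEMORY_AND_DISK_SER_2

# StorageLevel defines no __eq__, so compare the flags field by field.
assert (custom.useDisk, custom.useMemory, custom.deserialized, custom.replication) == \
       (preset.useDisk, preset.useMemory, preset.deserialized, preset.replication)
```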