Commit f86960cb authored by Matei Zaharia
Merge pull request #313 from rxin/pde_size_compress

Added a partition preserving flag to MapPartitionsWithSplitRDD.
parents 3ebd8e18 bd6dd1a3
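For context, a minimal sketch of the new call shape from driver code (assuming an existing SparkContext `sc`; the data and variable names are illustrative):

// Pair every element with the index of the partition that produced it.
// The new flag defaults to false, so existing callers are unaffected;
// pass true only when the function leaves the existing partitioning valid.
val rdd = sc.parallelize(1 to 8, 4)
val tagged = rdd.mapPartitionsWithSplit(
  (index, iter) => iter.map(x => (index, x)),
  preservesPartitioning = false)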
package spark
import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
import java.util.Random
import java.util.Date
import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions.mapAsScalaMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
@@ -48,7 +48,7 @@ import spark.storage.StorageLevel
import SparkContext._
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
* partitioned collection of elements that can be operated on in parallel. This class contains the
* basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
* [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -87,28 +87,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
@transient val dependencies: List[Dependency[_]]
// Methods available on all RDDs:
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
/** Optionally overridden by subclasses to specify placement preferences. */
def preferredLocations(split: Split): Seq[String] = Nil
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
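As a quick illustration of the basic operations named in the class comment above, a driver-side sketch (assuming an existing SparkContext `sc`; the data is illustrative):

val nums    = sc.parallelize(1 to 10, 2)      // distributed over 2 splits
val squares = nums.map(x => x * x)            // transformation, evaluated lazily
val evens   = squares.filter(_ % 2 == 0)      // another transformation
println(evens.collect().mkString(", "))       // action: 4, 16, 36, 64, 100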
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -124,32 +124,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
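A minimal usage sketch of the persistence calls above (assuming a SparkContext `sc` and a hypothetical input path; StorageLevel comes from spark.storage):

val lines  = sc.textFile("hdfs://...")                   // hypothetical input
val errors = lines.filter(_.contains("ERROR"))
errors.persist(StorageLevel.MEMORY_ONLY)                 // equivalent to errors.cache()
val total  = errors.count()                              // first action computes and caches the RDD
val sample = errors.take(10)                             // later actions reuse the cached blocks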
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
  if (!level.useDisk && level.replication < 2) {
    throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
  }

  // This is a hack. Ideally this should re-use the code used by the CacheTracker
  // to generate the key.
  def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)

  persist(level)
  sc.runJob(this, (iter: Iterator[T]) => {} )

  val p = this.partitioner

  new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
    override val partitioner = p
  }
}
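To make the key scheme above concrete, a small sketch with hypothetical values (an RDD with id 5 and four splits):

// getSplitKey mirrors the CacheTracker's block naming, so the BlockRDD built
// above reads back exactly the blocks that persist() plus the no-op runJob wrote.
val keys = (0 until 4).map(i => "rdd_%d_%d".format(5, i))
// keys: rdd_5_0, rdd_5_1, rdd_5_2, rdd_5_3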
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
@@ -162,9 +162,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
compute(split)
}
}
// Transformations (return a new RDD)
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
@@ -200,13 +200,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var multiplier = 3.0
var initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
if (num > initialCount) {
total = maxSelected
fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -216,14 +216,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
total = num
}
val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
Utils.randomizeInPlace(samples, rand).take(total)
}
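A usage sketch of takeSample as implemented above (assuming a SparkContext `sc`): the sampling fraction is inflated by the 3.0 multiplier so that a single sample() pass usually yields at least num elements, and the while loop only re-samples on the rare shortfall.

val data   = sc.parallelize(1 to 1000, 8)
val picked = data.takeSample(false, 5, 42)    // withReplacement = false, num = 5, seed = 42
// picked is an Array[Int] of exactly 5 elements, in randomized order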
@@ -291,8 +291,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
-  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
-    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+  def mapPartitionsWithSplit[U: ClassManifest](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
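A hedged sketch of what the new flag buys (assuming `sc` and `import spark.SparkContext._` for the pair-RDD implicits): the function below only rewrites values, so the hash partitioning of `pairs` remains valid, and declaring that lets downstream key-based operations with the same partitioner avoid a shuffle.

val part   = new HashPartitioner(4)
val pairs  = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3))).partitionBy(part)
val tagged = pairs.mapPartitionsWithSplit(
  (index, iter) => iter.map { case (k, v) => (k, (index, v)) },
  preservesPartitioning = true)
// tagged.partitioner is still Some(part); without the flag it would be None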
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
@@ -351,7 +353,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
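A small sketch of the contract described above (assuming `sc`): the zero value is the neutral element of the operation, and the operation may mutate and return its first argument but must leave the second untouched.

val nums = sc.parallelize(1 to 100, 4)
val sum  = nums.fold(0)(_ + _)                // 5050; 0 is neutral for +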
@@ -452,7 +454,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
sc.runApproximateJob(this, countPartition, evaluator, timeout)
}
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
* it will be slow if a lot of partitions are required. In that case, use collect() to get the
...
@@ -12,9 +12,11 @@ import spark.Split
private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U])
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean)
  extends RDD[U](prev.context) {

+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(split.index, prev.iterator(split))
...
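For comparison, a hypothetical subclass written in the same style (the name is invented here; assumes spark.{RDD, OneToOneDependency, Split} are in scope): it reuses the parent's splits, declares a one-to-one dependency, and forwards the parent's partitioner unconditionally because it never moves data between partitions.

private[spark]
class PassThroughRDD[T: ClassManifest](prev: RDD[T])
  extends RDD[T](prev.context) {

  override val partitioner = prev.partitioner            // layout unchanged, so keep it
  override def splits = prev.splits                      // same partitions as the parent
  override val dependencies = List(new OneToOneDependency(prev))
  override def compute(split: Split) = prev.iterator(split)
}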