Commit bd6dd1a3 authored by Reynold Xin

Added a partition preserving flag to MapPartitionsWithSplitRDD.

parent f24bfd2d
 package spark
 
 import java.io.EOFException
-import java.net.URL
 import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
+import java.net.URL
 import java.util.Random
 import java.util.Date
 import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
 
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -47,7 +47,7 @@ import spark.storage.StorageLevel
 import SparkContext._
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
  * partitioned collection of elements that can be operated on in parallel. This class contains the
  * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
  * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -86,28 +86,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   @transient val dependencies: List[Dependency[_]]
 
   // Methods available on all RDDs:
 
   /** Record user function generating this RDD. */
   private[spark] val origin = Utils.getSparkCallSite
 
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   val partitioner: Option[Partitioner] = None
 
   /** Optionally overridden by subclasses to specify placement preferences. */
   def preferredLocations(split: Split): Seq[String] = Nil
 
   /** The [[spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
   private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
 
   /** A unique ID for this RDD (within its SparkContext). */
   val id = sc.newRddId()
 
   // Variables relating to persistence
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /**
    * Set this RDD's storage level to persist its values across operations after the first time
    * it is computed. Can only be called once on each RDD.
    */
@@ -123,32 +123,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def cache(): RDD[T] = persist()
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel = storageLevel
 
   private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
     if (!level.useDisk && level.replication < 2) {
       throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
     }
 
     // This is a hack. Ideally this should re-use the code used by the CacheTracker
     // to generate the key.
     def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
 
     persist(level)
     sc.runJob(this, (iter: Iterator[T]) => {} )
 
     val p = this.partitioner
 
     new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
       override val partitioner = p
     }
   }
 
   /**
    * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
    * This should ''not'' be called by users directly, but is available for implementors of custom
@@ -161,9 +161,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       compute(split)
     }
   }
 
   // Transformations (return a new RDD)
 
   /**
    * Return a new RDD by applying a function to all elements of this RDD.
    */
@@ -199,13 +199,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var multiplier = 3.0
     var initialCount = count()
     var maxSelected = 0
 
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
 
     if (num > initialCount) {
       total = maxSelected
       fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -215,14 +215,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
       total = num
     }
 
     val rand = new Random(seed)
     var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
 
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
 
     Utils.randomizeInPlace(samples, rand).take(total)
   }
@@ -290,8 +290,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
   * of the original partition.
   */
-  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
-    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+  def mapPartitionsWithSplit[U: ClassManifest](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
 
   // Actions (launch a job to return a value to the user program)
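The new flag defaults to false, so existing callers of mapPartitionsWithSplit are unaffected and, as before, get a result RDD with no partitioner. A minimal usage sketch of the opt-in path, assuming the pair-RDD implicits from `import SparkContext._`; the data, the HashPartitioner, and the variable names are illustrative, not part of this commit:

  // Hash-partition some key-value data, then tag each element with its
  // partition index without touching the keys.
  val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3))).partitionBy(new HashPartitioner(4))

  val tagged = pairs.mapPartitionsWithSplit(
    (index, iter) => iter.map { case (k, v) => (k, (index, v)) },
    preservesPartitioning = true)

  // Keys are unchanged and the flag is set, so tagged.partitioner is
  // Some(HashPartitioner) rather than None, and a later join or reduceByKey
  // on the same keys can avoid a shuffle.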
@@ -342,7 +344,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
    * modify t2.
    */
@@ -443,7 +445,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
     sc.runApproximateJob(this, countPartition, evaluator, timeout)
   }
 
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
    * it will be slow if a lot of partitions are required. In that case, use collect() to get the
...
@@ -12,9 +12,11 @@ import spark.Split
 private[spark]
 class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U])
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean)
   extends RDD[U](prev.context) {
 
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(split.index, prev.iterator(split))
...
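Note that preservesPartitioning is a promise the caller makes, not something MapPartitionsWithSplitRDD can verify: the partitioner is inherited whenever the flag is set, even if f actually rewrites the keys. Continuing the illustrative sketch above, this would be a misuse:

  // WRONG: keys are rewritten, yet the parent's partitioner is still claimed
  // because preservesPartitioning = true was passed. Anything that trusts the
  // partitioner (e.g. a shuffle-free join) may then look for a key in the
  // wrong partition.
  val broken = pairs.mapPartitionsWithSplit(
    (index, iter) => iter.map { case (k, v) => (k + index, v) },
    preservesPartitioning = true)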