Skip to content
Snippets Groups Projects
Commit 051785c7 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Several fixes to sampling issues pointed out by Henry Milner:

- takeSample was biased towards earlier partitions
- There were some range errors in takeSample
- SampledRDDs with replacement didn't produce appropriate counts
  across partitions (we took exactly frac of each one)
parent 56c90485
No related branches found
No related tags found
No related merge requests found
......@@ -145,8 +145,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE) {
maxSelected = Integer.MAX_VALUE
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
......@@ -161,15 +161,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num
}
var samples = this.sample(withReplacement, fraction, seed).collect()
val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, seed).collect()
samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
val arr = samples.take(total)
return arr
Utils.randomizeInPlace(samples, rand).take(total)
}
/** Return an RDD containing the elements of both this RDD and `other`, built as a UnionRDD over the two. */
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
......
package spark
import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
// Split (partition) of a SampledRDD: wraps the parent RDD's split together with the
// per-partition RNG seed used when sampling that partition's elements.
// Serializable because splits are shipped to the workers that compute them.
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
// Reuse the parent split's index so this split occupies the same position.
override val index: Int = prev.index
......@@ -28,19 +30,21 @@ class SampledRDD[T: ClassManifest](
override def compute(splitIn: Split) = {
  val split = splitIn.asInstanceOf[SampledRDDSplit]
  if (withReplacement) {
    // For large datasets, the expected number of occurrences of each element in a sample
    // with replacement is Poisson(frac). We draw a per-element count from that
    // distribution instead of materializing the whole partition and picking indices,
    // which keeps this a single streaming pass and gives the right counts across
    // partitions (taking exactly frac of each partition would not).
    val poisson = new Poisson(frac, new DRand(split.seed))
    prev.iterator(split.prev).flatMap { element =>
      val count = poisson.nextInt()
      if (count == 0) {
        Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
      } else {
        Iterator.fill(count)(element)
      }
    }
  } else {
    // Sampling without replacement: keep each element independently with probability frac,
    // seeded per split so recomputation of a partition is deterministic.
    val rand = new Random(split.seed)
    prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
  }
}
}
......@@ -2,12 +2,11 @@ package spark
import java.io._
import java.net.{InetAddress, URL, URI}
import java.util.{Locale, UUID}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import scala.io.Source
/**
......@@ -172,17 +171,22 @@ object Utils extends Logging {
/**
 * Shuffle the elements of a collection into a random order, returning the
 * result in a new collection. Unlike scala.util.Random.shuffle, this method
 * uses a local random number generator, avoiding inter-thread contention.
 */
def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
  // Copy the input into a fresh array (so the caller's collection is untouched)
  // and delegate to the in-place Fisher-Yates shuffle.
  randomizeInPlace(seq.toArray)
}
/**
 * Shuffle the elements of an array into a random order, modifying the
 * original array. Returns the original array.
 *
 * @param arr the array to shuffle in place
 * @param rand source of randomness; defaults to a locally constructed generator
 *             to avoid inter-thread contention on a shared RNG
 */
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
  // Fisher-Yates shuffle: walk from the last index down, swapping each position
  // with a uniformly chosen position at or below it.
  for (i <- (arr.length - 1) to 1 by -1) {
    // The bound must be i + 1 so that j can equal i (an element may stay put);
    // nextInt(i) would never leave arr(i) in place, biasing the permutation.
    val j = rand.nextInt(i + 1)
    val tmp = arr(j)
    arr(j) = arr(i)
    arr(i) = tmp
  }
  arr
}
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment