Skip to content
Snippets Groups Projects
Commit c23f5db3 authored by Xiangrui Meng, committed by Patrick Wendell
Browse files

[SPARK-2251] fix concurrency issues in random sampler

The following code is very likely to throw an exception:

~~~
val rdd = sc.parallelize(0 until 111, 10).sample(false, 0.1)
rdd.zip(rdd).count()
~~~

because the same random number generator is shared when computing partitions.

Author: Xiangrui Meng <meng@databricks.com>

Closes #1229 from mengxr/fix-sample and squashes the following commits:

f1ee3d7 [Xiangrui Meng] fix concurrency issues in random sampler
parent d1636dd7
No related branches found
No related tags found
No related merge requests found
...@@ -54,17 +54,17 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable ...@@ -54,17 +54,17 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
*/ */
@DeveloperApi @DeveloperApi
class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
(implicit random: Random = new XORShiftRandom)
extends RandomSampler[T, T] { extends RandomSampler[T, T] {
def this(ratio: Double)(implicit random: Random = new XORShiftRandom) private[random] var rng: Random = new XORShiftRandom
= this(0.0d, ratio)(random)
override def setSeed(seed: Long) = random.setSeed(seed) def this(ratio: Double) = this(0.0d, ratio)
override def setSeed(seed: Long) = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = { override def sample(items: Iterator[T]): Iterator[T] = {
items.filter { item => items.filter { item =>
val x = random.nextDouble() val x = rng.nextDouble()
(x >= lb && x < ub) ^ complement (x >= lb && x < ub) ^ complement
} }
} }
...@@ -72,7 +72,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) ...@@ -72,7 +72,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
/** /**
* Return a sampler that is the complement of the range specified of the current sampler. * Return a sampler that is the complement of the range specified of the current sampler.
*/ */
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
override def clone = new BernoulliSampler[T](lb, ub, complement) override def clone = new BernoulliSampler[T](lb, ub, complement)
} }
...@@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) ...@@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
* :: DeveloperApi :: * :: DeveloperApi ::
* A sampler based on values drawn from Poisson distribution. * A sampler based on values drawn from Poisson distribution.
* *
* @param poisson a Poisson random number generator * @param mean Poisson mean
* @tparam T item type * @tparam T item type
*/ */
@DeveloperApi @DeveloperApi
class PoissonSampler[T](mean: Double) class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
(implicit var poisson: Poisson = new Poisson(mean, new DRand))
extends RandomSampler[T, T] { private[random] var rng = new Poisson(mean, new DRand)
override def setSeed(seed: Long) { override def setSeed(seed: Long) {
poisson = new Poisson(mean, new DRand(seed.toInt)) rng = new Poisson(mean, new DRand(seed.toInt))
} }
override def sample(items: Iterator[T]): Iterator[T] = { override def sample(items: Iterator[T]): Iterator[T] = {
items.flatMap { item => items.flatMap { item =>
val count = poisson.nextInt() val count = rng.nextInt()
if (count == 0) { if (count == 0) {
Iterator.empty Iterator.empty
} else { } else {
......
...@@ -20,7 +20,7 @@ package org.apache.spark.rdd ...@@ -20,7 +20,7 @@ package org.apache.spark.rdd
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.apache.spark.SharedSparkContext import org.apache.spark.SharedSparkContext
import org.apache.spark.util.random.RandomSampler import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}
/** a sampler that outputs its seed */ /** a sampler that outputs its seed */
class MockSampler extends RandomSampler[Long, Long] { class MockSampler extends RandomSampler[Long, Long] {
...@@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] { ...@@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] {
} }
override def sample(items: Iterator[Long]): Iterator[Long] = { override def sample(items: Iterator[Long]): Iterator[Long] = {
return Iterator(s) Iterator(s)
} }
override def clone = new MockSampler override def clone = new MockSampler
...@@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] { ...@@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] {
class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
test("seedDistribution") { test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler val sampler = new MockSampler
val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L) val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
assert(sample.distinct.count == 2, "Seeds must be different.") assert(sample.distinct().count == 2, "Seeds must be different.")
}
test("concurrency") {
// SPARK-2251: zip with self computes each partition twice.
// We want to make sure there are no concurrency issues.
val rdd = sc.parallelize(0 until 111, 10)
for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
sampled.zip(sampled).count()
}
} }
} }
...@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
} }
} }
whenExecuting(random) { whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55)(random) val sampler = new BernoulliSampler[Int](0.25, 0.55)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(3, 4, 5)) assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
} }
} }
...@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
} }
} }
whenExecuting(random) { whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
sampler.rng = random
assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9)) assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
} }
} }
...@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
} }
} }
whenExecuting(random) { whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.35)(random) val sampler = new BernoulliSampler[Int](0.35)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 3)) assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
} }
} }
...@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
} }
} }
whenExecuting(random) { whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9)) assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
} }
} }
...@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
random.setSeed(10L) random.setSeed(10L)
} }
whenExecuting(random) { whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.2)(random) val sampler = new BernoulliSampler[Int](0.2)
sampler.rng = random
sampler.setSeed(10L) sampler.setSeed(10L)
} }
} }
...@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar ...@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
} }
} }
whenExecuting(poisson) { whenExecuting(poisson) {
val sampler = new PoissonSampler[Int](0.2)(poisson) val sampler = new PoissonSampler[Int](0.2)
sampler.rng = poisson
assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6)) assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment