diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 247f10173f1e983f4f36dae8e872c92d1021a0ec..32c5fdad75e582db661b3a1083bdcb375eaa753a 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -54,17 +54,17 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable */ @DeveloperApi class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) - (implicit random: Random = new XORShiftRandom) extends RandomSampler[T, T] { - def this(ratio: Double)(implicit random: Random = new XORShiftRandom) - = this(0.0d, ratio)(random) + private[random] var rng: Random = new XORShiftRandom - override def setSeed(seed: Long) = random.setSeed(seed) + def this(ratio: Double) = this(0.0d, ratio) + + override def setSeed(seed: Long) = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { items.filter { item => - val x = random.nextDouble() + val x = rng.nextDouble() (x >= lb && x < ub) ^ complement } } @@ -72,7 +72,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) /** * Return a sampler that is the complement of the range specified of the current sampler. */ - def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) + def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) override def clone = new BernoulliSampler[T](lb, ub, complement) } @@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) * :: DeveloperApi :: * A sampler based on values drawn from Poisson distribution. * - * @param poisson a Poisson random number generator + * @param mean Poisson mean * @tparam T item type */ @DeveloperApi -class PoissonSampler[T](mean: Double) - (implicit var poisson: Poisson = new Poisson(mean, new DRand)) - extends RandomSampler[T, T] { +class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] { + + private[random] var rng = new Poisson(mean, new DRand) override def setSeed(seed: Long) { - poisson = new Poisson(mean, new DRand(seed.toInt)) + rng = new Poisson(mean, new DRand(seed.toInt)) } override def sample(items: Iterator[T]): Iterator[T] = { items.flatMap { item => - val count = poisson.nextInt() + val count = rng.nextInt() if (count == 0) { Iterator.empty } else { diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 00c273df63b29f45f827bd6cc225e42793e7572b..5dd8de319a654443cf0e18b837a03aa188185855 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.rdd import org.scalatest.FunSuite import org.apache.spark.SharedSparkContext -import org.apache.spark.util.random.RandomSampler +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ class MockSampler extends RandomSampler[Long, Long] { @@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] { } override def sample(items: Iterator[Long]): Iterator[Long] = { - return Iterator(s) + Iterator(s) } override def clone = new MockSampler @@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] { class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { - test("seedDistribution") { + test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) val sampler = new MockSampler val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L) - assert(sample.distinct.count == 2, "Seeds must be different.") + assert(sample.distinct().count == 2, "Seeds must be different.") + } + + test("concurrency") { + // SPARK-2251: zip with self computes each partition twice. + // We want to make sure there are no concurrency issues. + val rdd = sc.parallelize(0 until 111, 10) + for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) { + val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler) + sampled.zip(sampled).count() + } } } diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index e166787f17544d8b9ecf3539951313fae6aaf35f..36877476e708e72800ca67cdf424c4606d74ee47 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55)(random) + val sampler = new BernoulliSampler[Int](0.25, 0.55) + sampler.rng = random assert(sampler.sample(a.iterator).toList == List(3, 4, 5)) } } @@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) + val sampler = new BernoulliSampler[Int](0.25, 0.55, true) + sampler.rng = random assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9)) } } @@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.35)(random) + val sampler = new BernoulliSampler[Int](0.35) + sampler.rng = random assert(sampler.sample(a.iterator).toList == List(1, 2, 3)) } } @@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) + val sampler = new BernoulliSampler[Int](0.25, 0.55, true) + sampler.rng = random assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9)) } } @@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar random.setSeed(10L) } whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.2)(random) + val sampler = new BernoulliSampler[Int](0.2) + sampler.rng = random sampler.setSeed(10L) } } @@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } whenExecuting(poisson) { - val sampler = new PoissonSampler[Int](0.2)(poisson) + val sampler = new PoissonSampler[Int](0.2) + sampler.rng = poisson assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6)) } }