diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 247f10173f1e983f4f36dae8e872c92d1021a0ec..32c5fdad75e582db661b3a1083bdcb375eaa753a 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -54,17 +54,17 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
  */
 @DeveloperApi
 class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
-    (implicit random: Random = new XORShiftRandom)
   extends RandomSampler[T, T] {
 
-  def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
-    = this(0.0d, ratio)(random)
+  private[random] var rng: Random = new XORShiftRandom
 
-  override def setSeed(seed: Long) = random.setSeed(seed)
+  def this(ratio: Double) = this(0.0d, ratio)
+
+  override def setSeed(seed: Long) = rng.setSeed(seed)
 
   override def sample(items: Iterator[T]): Iterator[T] = {
     items.filter { item =>
-      val x = random.nextDouble()
+      val x = rng.nextDouble()
       (x >= lb && x < ub) ^ complement
     }
   }
@@ -72,7 +72,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
   /**
    *  Return a sampler that is the complement of the range specified by the current sampler.
    */
-  def cloneComplement():  BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+  def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
 
   override def clone = new BernoulliSampler[T](lb, ub, complement)
 }
@@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
  * :: DeveloperApi ::
  * A sampler based on values drawn from the Poisson distribution.
  *
- * @param poisson a Poisson random number generator
+ * @param mean Poisson mean
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T](mean: Double)
-    (implicit var poisson: Poisson = new Poisson(mean, new DRand))
-  extends RandomSampler[T, T] {
+class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
+
+  private[random] var rng = new Poisson(mean, new DRand)
 
   override def setSeed(seed: Long) {
-    poisson = new Poisson(mean, new DRand(seed.toInt))
+    rng = new Poisson(mean, new DRand(seed.toInt))
   }
 
   override def sample(items: Iterator[T]): Iterator[T] = {
     items.flatMap { item =>
-      val count = poisson.nextInt()
+      val count = rng.nextInt()
       if (count == 0) {
         Iterator.empty
       } else {
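
For orientation (not part of the patch): a minimal sketch of how the samplers are constructed and seeded once the implicit Random/Poisson constructor parameters are removed. It assumes only spark-core on the classpath; the value names are local to this example.

import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}

// Samplers now own their RNG in an internal rng field; there is no
// implicit Random/Poisson constructor argument to thread through.
val bernoulli = new BernoulliSampler[Int](0.25, 0.55)  // keep items whose draw falls in [0.25, 0.55)
val poisson = new PoissonSampler[Int](0.5)             // replicate each item Poisson(0.5) times

// Seeding goes through setSeed, which reseeds (Bernoulli) or rebuilds (Poisson) the rng field.
bernoulli.setSeed(42L)
poisson.setSeed(42L)

val kept = bernoulli.sample((1 to 100).iterator).toList
val replicated = poisson.sample((1 to 100).iterator).toList
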
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 00c273df63b29f45f827bd6cc225e42793e7572b..5dd8de319a654443cf0e18b837a03aa188185855 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.rdd
 import org.scalatest.FunSuite
 
 import org.apache.spark.SharedSparkContext
-import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}
 
 /** a sampler that outputs its seed */
 class MockSampler extends RandomSampler[Long, Long] {
@@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] {
   }
 
   override def sample(items: Iterator[Long]): Iterator[Long] = {
-    return Iterator(s)
+    Iterator(s)
   }
 
   override def clone = new MockSampler
@@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] {
 
 class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
 
-  test("seedDistribution") {
+  test("seed distribution") {
     val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
     val sampler = new MockSampler
     val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
-    assert(sample.distinct.count == 2, "Seeds must be different.")
+    assert(sample.distinct().count == 2, "Seeds must be different.")
+  }
+
+  test("concurrency") {
+    // SPARK-2251: zip with self computes each partition twice.
+    // We want to make sure there are no concurrency issues.
+    val rdd = sc.parallelize(0 until 111, 10)
+    for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
+      val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+      sampled.zip(sampled).count()
+    }
   }
 }
 
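
The new concurrency test mirrors the SPARK-2251 report: zipping a sampled RDD with itself computes each partition twice within one task, so sampler RNG state must not be shared between the two computations. A rough standalone reproduction through the public RDD API (rdd.sample uses these samplers internally); the master, app name, and sizes are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}

// Zip a sampled RDD with itself: each partition is evaluated twice in the
// same task, exercising concurrent use of the per-partition sampler.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("spark-2251-repro"))
val sampled = sc.parallelize(0 until 111, 10).sample(true, 0.5, 1L)
sampled.zip(sampled).count()  // should complete without a zip-length mismatch
sc.stop()
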
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index e166787f17544d8b9ecf3539951313fae6aaf35f..36877476e708e72800ca67cdf424c4606d74ee47 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
     }
   }
@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
     }
   }
@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.35)(random)
+      val sampler = new BernoulliSampler[Int](0.35)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
     }
   }
@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
     }
   }
@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       random.setSeed(10L)
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.2)(random)
+      val sampler = new BernoulliSampler[Int](0.2)
+      sampler.rng = random
       sampler.setSeed(10L)
     }
   }
@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(poisson) {
-      val sampler = new PoissonSampler[Int](0.2)(poisson)
+      val sampler = new PoissonSampler[Int](0.2)
+      sampler.rng = poisson
       assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
     }
   }
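
With the implicit parameters gone, these tests inject their EasyMock RNGs by assigning to the package-private rng field directly. A hedged sketch of the same injection without a mocking framework, assuming it sits in the org.apache.spark.util.random test package (rng is private[random]); the anonymous Random subclass is illustrative:

import org.apache.spark.util.random.BernoulliSampler

// Swap in a deterministic java.util.Random through the new rng var.
val sampler = new BernoulliSampler[Int](0.25, 0.55)
sampler.rng = new java.util.Random() {
  override def nextDouble(): Double = 0.3  // always inside [0.25, 0.55), so every item is kept
}
assert(sampler.sample((1 to 9).iterator).toList == (1 to 9).toList)
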