diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index efe248896a16b8163abd166985d77f6850342d25..5fac955286ff7a1cdb2647be136ac6ef08728cb1 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -145,8 +145,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var initialCount = count()
     var maxSelected = 0
     
-    if (initialCount > Integer.MAX_VALUE) {
-      maxSelected = Integer.MAX_VALUE
+    if (initialCount > Integer.MAX_VALUE - 1) {
+      maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
@@ -161,15 +161,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       total = num
     }
   
-    var samples = this.sample(withReplacement, fraction, seed).collect()
+    val rand = new Random(seed)
+    var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
   
     while (samples.length < total) {
-      samples = this.sample(withReplacement, fraction, seed).collect()
+      samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
   
-    val arr = samples.take(total)
-  
-    return arr
+    Utils.randomizeInPlace(samples, rand).take(total)
   }
 
   def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
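
The change above has two effects: each retry of the undersized-sample loop now draws a fresh seed from `rand` (with the old fixed `seed`, a too-small sample was reproduced identically on every retry, so the `while` loop could spin forever), and the final `take(total)` is preceded by a shuffle so the kept elements are not biased toward partition order. A minimal local sketch of that flow, with a plain `Seq` and a hypothetical `localSample` standing in for `RDD.sample(...).collect()`:

```scala
import scala.util.Random

// Hypothetical stand-in for RDD.sample(withReplacement = false, ...).collect()
def localSample[T](data: Seq[T], fraction: Double, seed: Int): Seq[T] = {
  val rng = new Random(seed)
  data.filter(_ => rng.nextDouble() <= fraction)
}

def takeSampleSketch[T](data: Seq[T], num: Int, seed: Int): Seq[T] = {
  require(num <= data.size)
  val fraction = math.min(num.toDouble / data.size * 1.1, 1.0) // oversample ~10%
  val rand = new Random(seed)
  // Fresh per-attempt seed: with a fixed seed, every retry would reproduce
  // the same undersized sample and this loop would never terminate.
  var samples = localSample(data, fraction, rand.nextInt())
  while (samples.length < num) {
    samples = localSample(data, fraction, rand.nextInt())
  }
  // Shuffle before take(num), mirroring Utils.randomizeInPlace(samples, rand).take(total),
  // so the result is not biased toward elements that happen to come first.
  rand.shuffle(samples).take(num)
}
```
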
diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/SampledRDD.scala
index 8ef40d8d9e6777ff812cd7c04c5c3c48cdd50049..c066017e892c140803cfe673afca3f9fb77c7799 100644
--- a/core/src/main/scala/spark/SampledRDD.scala
+++ b/core/src/main/scala/spark/SampledRDD.scala
@@ -1,6 +1,8 @@
 package spark
 
 import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
 
 class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
   override val index: Int = prev.index
@@ -28,19 +30,21 @@ class SampledRDD[T: ClassManifest](
 
   override def compute(splitIn: Split) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
-    val rg = new Random(split.seed)
-    // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
     if (withReplacement) {
-      val oldData = prev.iterator(split.prev).toArray
-      val sampleSize = (oldData.size * frac).ceil.toInt
-      val sampledData = { 
-        // all of oldData's indices are candidates, even if sampleSize < oldData.size
-        for (i <- 1 to sampleSize)
-          yield oldData(rg.nextInt(oldData.size)) 
+      // For large datasets, the number of times each element appears in a sample with
+      // replacement is approximately Poisson(frac), so we draw a per-element count from it.
+      val poisson = new Poisson(frac, new DRand(split.seed))
+      prev.iterator(split.prev).flatMap { element =>
+        val count = poisson.nextInt()
+        if (count == 0) {
+          Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
+        } else {
+          Iterator.fill(count)(element)
+        }
       }
-      sampledData.iterator
     } else { // Sampling without replacement
-      prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+      val rand = new Random(split.seed)
+      prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
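
The Poisson rewrite turns with-replacement sampling into a single streaming pass: rather than materializing the whole partition and indexing into it `sampleSize` times, each element's multiplicity is drawn directly, since for a large partition the count of each element in a with-replacement sample of rate `frac` is approximately Poisson(frac). A self-contained sketch of the same idea, with a simple Knuth-style sampler standing in for `cern.jet.random.Poisson` (adequate for small means such as a sampling fraction):

```scala
import java.util.Random

// Knuth's Poisson sampler: multiply uniforms until the product drops
// below e^(-mean). Fine for small means like a sampling fraction.
def poissonDraw(mean: Double, rng: Random): Int = {
  val l = math.exp(-mean)
  var k = 0
  var p = 1.0
  do {
    k += 1
    p *= rng.nextDouble()
  } while (p > l)
  k - 1
}

// Streaming with-replacement sampling: one Poisson draw per element,
// no intermediate array holding the whole partition.
def sampleWithReplacement[T](it: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
  val rng = new Random(seed)
  it.flatMap { element =>
    val count = poissonDraw(frac, rng)
    if (count == 0) Iterator.empty else Iterator.fill(count)(element)
  }
}

// e.g. sampleWithReplacement((1 to 1000).iterator, 0.01, 42).toList
```
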
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5a3f8bde4372ab922711f3e7282cda09ba1eb91b..eb7d69e816de60e0de7e298f9a6778bf624594c0 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -2,12 +2,11 @@ package spark
 
 import java.io._
 import java.net.{InetAddress, URL, URI}
-import java.util.{Locale, UUID}
+import java.util.{Locale, Random, UUID}
 import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 import scala.io.Source
 
 /**
@@ -172,17 +171,22 @@ object Utils extends Logging {
    * result in a new collection. Unlike scala.util.Random.shuffle, this method
    * uses a local random number generator, avoiding inter-thread contention.
    */
-  def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
-    val buf = new ArrayBuffer[T]()
-    buf ++= seq
-    val rand = new Random()
-    for (i <- (buf.size - 1) to 1 by -1) {
+  def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+    randomizeInPlace(seq.toArray)
+  }
+
+  /**
+   * Shuffle the elements of an array into a random order, modifying the
+   * original array. Returns the original array.
+   */
+  def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+    for (i <- (arr.length - 1) to 1 by -1) {
-      val j = rand.nextInt(i)
-      val tmp = buf(j)
-      buf(j) = buf(i)
-      buf(i) = tmp
+      val j = rand.nextInt(i + 1)  // i + 1 keeps "swap with self" possible; nextInt(i) biases the shuffle
+      val tmp = arr(j)
+      arr(j) = arr(i)
+      arr(i) = tmp
     }
-    buf
+    arr
   }
 
   /**
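
With the `nextInt(i + 1)` bound, `randomizeInPlace` is a standard Fisher-Yates shuffle: at each step, index `i` itself must remain a swap candidate, otherwise (with `nextInt(i)`) the procedure degenerates into Sattolo's algorithm and can only ever produce cyclic permutations. A usage sketch, assuming the patched `spark.Utils` is on the classpath; the uniformity check at the end is illustrative only:

```scala
import java.util.Random
import spark.Utils

val arr = Array(1, 2, 3, 4, 5)
Utils.randomizeInPlace(arr, new Random(42)) // shuffles arr in place and returns it

// Rough uniformity check: each of the 6 permutations of 3 elements should
// show up about 1/6 of the time; with nextInt(i) instead of nextInt(i + 1),
// only the two 3-cycles would ever appear.
val counts = scala.collection.mutable.Map[Seq[Int], Int]()
for (_ <- 1 to 60000) {
  val a = Array(0, 1, 2)
  Utils.randomizeInPlace(a)
  counts(a.toSeq) = counts.getOrElse(a.toSeq, 0) + 1
}
counts.foreach { case (perm, n) => println(perm.mkString(",") + " -> " + n) }
```
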