diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 4f3081433a542bb6c8c9f43b86ede7c5d924faaa..31bf8dced26388666784a9149260b36d751468e0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Map => JMap}
 import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConversions._
@@ -129,6 +129,73 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample whose size is approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample whose size is exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean,
+      seed: Long): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample whose size is approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
+   * the RDD to create a sample whose size is exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      exact: Boolean): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample whose size is approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values, using one pass over the RDD via
+   * simple random sampling.
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: JMap[K, Double],
+      seed: Long): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, false, seed)
+
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * Produce a sample whose size is approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values, using one pass over the RDD via
+   * simple random sampling.
+   *
+   * Use Utils.random.nextLong as the default seed for the random number generator.
+   */
+  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
+
   /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
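For orientation, here is a minimal sketch of how the new per-key sampling is typically driven from Scala; the data, key function, object name, and `local[2]` master are illustrative and not part of this patch. The Java overloads above expose the same parameter combinations, defaulting `exact` to false and the seed to Utils.random.nextLong.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // brings in the pair-RDD implicits

object SampleByKeyExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sampleByKey"))
    // Stratify 1..1000 into two keys.
    val pairs = sc.parallelize(1 to 1000).keyBy(i => if (i % 2 == 0) "even" else "odd")
    // Per-key sampling rates: keep ~10% of the evens and ~50% of the odds.
    val fractions = Map("even" -> 0.1, "odd" -> 0.5)
    // One pass, approximate per-key sample sizes (exact defaults to false).
    val approx = pairs.sampleByKey(withReplacement = false, fractions = fractions)
    // Extra pass(es), per-key sizes equal to math.ceil(numItems * rate) with high confidence.
    val exact = pairs.sampleByKey(withReplacement = false, fractions = fractions,
      exact = true, seed = 42L)
    println(approx.countByKey())
    println(exact.countByKey())
    sc.stop()
  }
}
```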
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c04d162a39616c8a6ec0952e1677b25980c3b97a..1af4e5f0b6d08e66a09749a485b330002c97541b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -19,12 +19,10 @@ package org.apache.spark.rdd
 
 import java.nio.ByteBuffer
 import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.{HashMap => JHashMap}
+import java.util.{Date, HashMap => JHashMap}
 
+import scala.collection.{Map, mutable}
 import scala.collection.JavaConversions._
-import scala.collection.Map
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
@@ -34,19 +32,19 @@ import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
+import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
 RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
 
 import org.apache.spark._
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.SparkHadoopWriter
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.random.StratifiedSamplingUtils
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -195,6 +193,41 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     foldByKey(zeroValue, defaultPartitioner(self))(func)
   }
 
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   *
+   * Create a sample of this RDD using variable sampling rates for different keys as specified by
+   * `fractions`, a key to sampling rate map.
+   *
+   * If `exact` is set to false, create the sample via simple random sampling, with one pass
+   * over the RDD, to produce a sample whose size is approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values; otherwise, use
+   * additional passes over the RDD to create a sample whose size is exactly equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values with 99.99% confidence. When sampling
+   * without replacement, we need one additional pass over the RDD to guarantee sample size;
+   * when sampling with replacement, we need two additional passes.
+   *
+   * @param withReplacement whether to sample with or without replacement
+   * @param fractions map of specific keys to sampling rates
+   * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
+   * @param seed seed for the random number generator
+   * @return RDD containing the sampled subset
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fractions: Map[K, Double],
+      exact: Boolean = false,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+    require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+    val samplingFunc = if (withReplacement) {
+      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
+    } else {
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
+    }
+    self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+  }
+
   /**
    * Merge the values for each key using an associative reduce function. This will also perform
    * the merging locally on each mapper before sending results to a reducer, similarly to a
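Concretely, the `exact` flag trades extra passes for a per-key size guarantee. A hedged sketch of what that guarantee looks like, assuming an existing `rdd: RDD[(String, Int)]` whose keys are "a" and "b" (the assertion holds with the 99.99% confidence described above):

```scala
import org.apache.spark.SparkContext._  // pair-RDD implicits

val fractions = Map("a" -> 0.1, "b" -> 0.5)
val expected = rdd.countByKey().map { case (k, n) => k -> math.ceil(n * fractions(k)).toLong }
val sampled = rdd.sampleByKey(withReplacement = false, fractions = fractions, exact = true)
// With exact = true the per-key counts equal ceil(numItems * rate) (with 99.99% confidence);
// with exact = false they are only approximately these values.
assert(sampled.countByKey() == expected)
```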
@@ -531,6 +564,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
+   *
+   * Warning: this doesn't return a multimap (so if you have multiple values for the same key, only
+   *          one value per key is preserved in the resulting map).
    */
   def collectAsMap(): Map[K, V] = {
     val data = self.collect()
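The collectAsMap caveat is easy to trip over in practice. A small sketch of the difference, assuming an existing SparkContext `sc` (which duplicate survives is unspecified):

```scala
import org.apache.spark.SparkContext._  // pair-RDD implicits

val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
// Only one value per key survives, e.g. Map("a" -> 2, "b" -> 3).
val asMap = pairs.collectAsMap()
// To keep every value, group first: Map("a" -> [1, 2], "b" -> [3]).
val allValues = pairs.groupByKey().collectAsMap()
```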
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index d10141b90e621f2bb1346842ec3487a919cefc95..c9a864ae62778706b1c6a54111339f67acd01fff 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -81,6 +81,9 @@ private[spark] object SamplingUtils {
    *     ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
    *     rate, where success rate is defined the same as in sampling with replacement.
    *
+   * The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the
+   * RNG's resolution).
+   *
    * @param sampleSizeLowerBound sample size
    * @param total size of RDD
    * @param withReplacement whether sampling with replacement
@@ -88,14 +91,73 @@ private[spark] object SamplingUtils {
    */
   def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
       withReplacement: Boolean): Double = {
-    val fraction = sampleSizeLowerBound.toDouble / total
     if (withReplacement) {
-      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
-      fraction + numStDev * math.sqrt(fraction / total)
+      PoissonBounds.getUpperBound(sampleSizeLowerBound) / total
     } else {
-      val delta = 1e-4
-      val gamma = - math.log(delta) / total
-      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+      val fraction = sampleSizeLowerBound.toDouble / total
+      BinomialBounds.getUpperBound(1e-4, total, fraction)
     }
   }
 }
+
+/**
+ * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
+ * sample sizes with high confidence when sampling with replacement.
+ */
+private[spark] object PoissonBounds {
+
+  /**
+   * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda).
+   */
+  def getLowerBound(s: Double): Double = {
+    math.max(s - numStd(s) * math.sqrt(s), 1e-15)
+  }
+
+  /**
+   * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda).
+   *
+   * @param s sample size
+   */
+  def getUpperBound(s: Double): Double = {
+    math.max(s + numStd(s) * math.sqrt(s), 1e-10)
+  }
+
+  private def numStd(s: Double): Double = {
+    // TODO: Make it tighter.
+    if (s < 6.0) {
+      12.0
+    } else if (s < 16.0) {
+      9.0
+    } else {
+      6.0
+    }
+  }
+}
+
+/**
+ * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact
+ * sample size with high confidence when sampling without replacement.
+ */
+private[spark] object BinomialBounds {
+
+  val minSamplingRate = 1e-10
+
+  /**
+   * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
+   * it is very unlikely to have more than `fraction * n` successes.
+   */
+  def getLowerBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n * (2.0 / 3.0)
+    fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction)
+  }
+
+  /**
+   * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`,
+   * it is very unlikely to have less than `fraction * n` successes.
+   */
+  def getUpperBound(delta: Double, n: Long, fraction: Double): Double = {
+    val gamma = - math.log(delta) / n
+    math.min(1,
+      math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
+  }
+}
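For intuition on the new bounds, the following standalone sketch mirrors the Bernoulli formula above and shows how much the naive rate gets inflated. It re-derives `getUpperBound` locally rather than calling the `private[spark]` object; the object name and the numbers are illustrative. The same idea drives PoissonBounds for the with-replacement case: pad the target by a few standard deviations so that undershooting the requested size is very unlikely.

```scala
object BoundSketch {
  // Same form as BinomialBounds.getUpperBound above.
  def upperBound(delta: Double, n: Long, fraction: Double): Double = {
    val gamma = -math.log(delta) / n
    math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
  }

  def main(args: Array[String]): Unit = {
    // To draw at least ~100 items without replacement from 1,000,000, sample at a rate slightly
    // above 100 / 1e6 so that falling short is very unlikely (delta = 1e-4, as above).
    val n = 1000000L
    val target = 100
    val naive = target.toDouble / n
    val adjusted = upperBound(1e-4, n, naive)
    println(f"naive rate = $naive%.8f, adjusted rate = $adjusted%.8f")
  }
}
```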
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8f95d7c6b799b776a42f7c9b59aee9d2e9fd2528
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import scala.collection.Map
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions.
+ *
+ * Essentially, when exact sample size is necessary, we make additional passes over the RDD to
+ * compute the exact threshold value to use for each stratum to guarantee exact sample size with
+ * high probability. This is achieved by maintaining a waitlist of size O(sqrt(s)), where s is the
+ * desired sample size for each stratum.
+ *
+ * Like in simple random sampling, we generate a random value for each item from the
+ * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist)
+ * are accepted into the sample instantly. The threshold for instant accept is designed so that
+ * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a
+ * waitlist of size O(sqrt(s)), we will be able to create a sample of the exact size s by adding
+ * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold
+ * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted).
+ *
+ * Note that since we use the same seed for the RNG when computing the thresholds and the actual
+ * sample, our computed thresholds are guaranteed to produce the desired sample size.
+ *
+ * For more theoretical background on the sampling techniques used here, please refer to
+ * http://jmlr.org/proceedings/papers/v28/meng13a.html
+ */
+
+private[spark] object StratifiedSamplingUtils extends Logging {
+
+  /**
+   * Count the number of items instantly accepted and generate the waitlist for each stratum.
+   *
+   * This is only invoked when exact sample size is required.
+   */
+  def getAcceptanceResults[K, V](rdd: RDD[(K, V)],
+      withReplacement: Boolean,
+      fractions: Map[K, Double],
+      counts: Option[Map[K, Long]],
+      seed: Long): mutable.Map[K, AcceptanceResult] = {
+    val combOp = getCombOp[K]
+    val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) =>
+      val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]()
+      val rng = new RandomDataGenerator()
+      rng.reSeed(seed + partition)
+      val seqOp = getSeqOp(withReplacement, fractions, rng, counts)
+      Iterator(iter.aggregate(zeroU)(seqOp, combOp))
+    }
+    mappedPartitionRDD.reduce(combOp)
+  }
+
+  /**
+   * Returns the function used by aggregate to collect sampling statistics for each partition.
+   */
+  def getSeqOp[K, V](withReplacement: Boolean,
+      fractions: Map[K, Double],
+      rng: RandomDataGenerator,
+      counts: Option[Map[K, Long]]):
+    (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = {
+    val delta = 5e-5
+    (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => {
+      val key = item._1
+      val fraction = fractions(key)
+      if (!result.contains(key)) {
+        result += (key -> new AcceptanceResult())
+      }
+      val acceptResult = result(key)
+
+      if (withReplacement) {
+        // compute acceptBound and waitListBound only if they haven't been computed already
+        // since they don't change from iteration to iteration.
+        // TODO change this to the streaming version
+        if (acceptResult.areBoundsEmpty) {
+          val n = counts.get(key)
+          val sampleSize = math.ceil(n * fraction).toLong
+          val lmbd1 = PoissonBounds.getLowerBound(sampleSize)
+          val lmbd2 = PoissonBounds.getUpperBound(sampleSize)
+          acceptResult.acceptBound = lmbd1 / n
+          acceptResult.waitListBound = (lmbd2 - lmbd1) / n
+        }
+        val acceptBound = acceptResult.acceptBound
+        val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound)
+        if (copiesAccepted > 0) {
+          acceptResult.numAccepted += copiesAccepted
+        }
+        val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound)
+        if (copiesWaitlisted > 0) {
+          acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform())
+        }
+      } else {
+        // We use the streaming version of the algorithm for sampling without replacement to avoid
+        // using an extra pass over the RDD for computing the count.
+        // Hence, acceptBound and waitListBound change on every iteration.
+        acceptResult.acceptBound =
+          BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction)
+        acceptResult.waitListBound =
+          BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction)
+
+        val x = rng.nextUniform()
+        if (x < acceptResult.acceptBound) {
+          acceptResult.numAccepted += 1
+        } else if (x < acceptResult.waitListBound) {
+          acceptResult.waitList += x
+        }
+      }
+      acceptResult.numItems += 1
+      result
+    }
+  }
+
+  /**
+   * Returns the function used to combine results returned by seqOp from different partitions.
+   */
+  def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult])
+    => mutable.Map[K, AcceptanceResult] = {
+    (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => {
+      // take union of both key sets in case one partition doesn't contain all keys
+      result1.keySet.union(result2.keySet).foreach { key =>
+        // Use result2 to keep the combined result since result1 is usually empty
+        val entry1 = result1.get(key)
+        if (result2.contains(key)) {
+          result2(key).merge(entry1)
+        } else {
+          if (entry1.isDefined) {
+            result2 += (key -> entry1.get)
+          }
+        }
+      }
+      result2
+    }
+  }
+
+  /**
+   * Given the result returned by getAcceptanceResults, determine the threshold for accepting
+   * items to generate the exact sample size.
+   *
+   * To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare
+   * it to the number of items that were accepted instantly and the number of items in the waitlist
+   * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted),
+   * which means we need to sort the elements in the waitlist by their associated values in order
+   * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize.
+   * Note that all elements in the waitlist have values >= bound for instant accept, so a T value
+   * in the waitlist range would allow all elements that were instantly accepted on the first pass
+   * to be included in the sample.
+   */
+  def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult],
+      fractions: Map[K, Double]): Map[K, Double] = {
+    val thresholdByKey = new mutable.HashMap[K, Double]()
+    for ((key, acceptResult) <- finalResult) {
+      val sampleSize = math.ceil(acceptResult.numItems * fractions(key)).toLong
+      if (acceptResult.numAccepted > sampleSize) {
+        logWarning("Pre-accepted too many")
+        thresholdByKey += (key -> acceptResult.acceptBound)
+      } else {
+        val numWaitListAccepted = (sampleSize - acceptResult.numAccepted).toInt
+        if (numWaitListAccepted >= acceptResult.waitList.size) {
+          logWarning("WaitList too short")
+          thresholdByKey += (key -> acceptResult.waitListBound)
+        } else {
+          thresholdByKey += (key -> acceptResult.waitList.sorted.apply(numWaitListAccepted))
+        }
+      }
+    }
+    thresholdByKey
+  }
+
+  /**
+   * Return the per partition sampling function used for sampling without replacement.
+   *
+   * When exact sample size is required, we make an additional pass over the RDD to determine the
+   * exact sampling rate that guarantees sample size with high confidence.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)],
+      fractions: Map[K, Double],
+      exact: Boolean,
+      seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
+    var samplingRateByKey = fractions
+    if (exact) {
+      // determine threshold for each stratum and resample
+      val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed)
+      samplingRateByKey = computeThresholdByKey(finalResult, fractions)
+    }
+    (idx: Int, iter: Iterator[(K, V)]) => {
+      val rng = new RandomDataGenerator
+      rng.reSeed(seed + idx)
+      // Must use the same invocation pattern on the rng as in getSeqOp for sampling without
+      // replacement in order to generate the same sequence of random numbers for the sample
+      iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1))
+    }
+  }
+
+  /**
+   * Return the per partition sampling function used for sampling with replacement.
+   *
+   * When exact sample size is required, we make two additional passes over the RDD to determine
+   * the exact sampling rate that guarantees sample size with high confidence. The first pass
+   * counts the number of items in each stratum (group of items with the same key) in the RDD, and
+   * the second pass uses the counts to determine exact sampling rates.
+   *
+   * The sampling function has a unique seed per partition.
+   */
+  def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)],
+      fractions: Map[K, Double],
+      exact: Boolean,
+      seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = {
+    // TODO implement the streaming version of sampling w/ replacement that doesn't require counts
+    if (exact) {
+      val counts = Some(rdd.countByKey())
+      val finalResult = getAcceptanceResults(rdd, true, fractions, counts, seed)
+      val thresholdByKey = computeThresholdByKey(finalResult, fractions)
+      (idx: Int, iter: Iterator[(K, V)]) => {
+        val rng = new RandomDataGenerator()
+        rng.reSeed(seed + idx)
+        iter.flatMap { item =>
+          val key = item._1
+          val acceptBound = finalResult(key).acceptBound
+          // Must use the same invocation pattern on the rng as in getSeqOp for sampling with
+          // replacement in order to generate the same sequence of random numbers for the sample
+          val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
+          val copiesWaitlisted = rng.nextPoisson(finalResult(key).waitListBound)
+          val copiesInSample = copiesAccepted +
+            (0 until copiesWaitlisted).count(_ => rng.nextUniform() < thresholdByKey(key))
+          if (copiesInSample > 0) {
+            Iterator.fill(copiesInSample.toInt)(item)
+          } else {
+            Iterator.empty
+          }
+        }
+      }
+    } else {
+      (idx: Int, iter: Iterator[(K, V)]) => {
+        val rng = new RandomDataGenerator()
+        rng.reSeed(seed + idx)
+        iter.flatMap { item =>
+          val count = rng.nextPoisson(fractions(item._1))
+          if (count > 0) {
+            Iterator.fill(count)(item)
+          } else {
+            Iterator.empty
+          }
+        }
+      }
+    }
+  }
+
+  /** A random data generator that generates both uniform values and Poisson values. */
+  private class RandomDataGenerator {
+    val uniform = new XORShiftRandom()
+    var poisson = new Poisson(1.0, new DRand)
+
+    def reSeed(seed: Long) {
+      uniform.setSeed(seed)
+      poisson = new Poisson(1.0, new DRand(seed.toInt))
+    }
+
+    def nextPoisson(mean: Double): Int = {
+      poisson.nextInt(mean)
+    }
+
+    def nextUniform(): Double = {
+      uniform.nextDouble()
+    }
+  }
+}
+
+/**
+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
+ * stratum, as well as the bounds for accepting and waitlisting items.
+ *
+ * `private[random]` is necessary since this class appears in the return type of seqOp above.
+ */
+private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L)
+  extends Serializable {
+
+  val waitList = new ArrayBuffer[Double]
+  var acceptBound: Double = Double.NaN // upper bound for accepting item instantly
+  var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist
+
+  def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN
+
+  def merge(other: Option[AcceptanceResult]): Unit = {
+    if (other.isDefined) {
+      waitList ++= other.get.waitList
+      numAccepted += other.get.numAccepted
+      numItems += other.get.numItems
+    }
+  }
+}
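The accept/waitlist mechanism described in the header comment of this new file can be illustrated on a single stratum without Spark. This is a toy sketch: the object name is made up, the bounds are simplified stand-ins for BinomialBounds, and the threshold fallbacks mirror computeThresholdByKey.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object WaitlistSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    val n = 10000                                    // items in this stratum
    val fraction = 0.01                              // requested sampling rate
    val sampleSize = math.ceil(n * fraction).toLong  // the exact size we want (100)

    // Simplified accept/waitlist bounds around the requested rate (stand-ins for
    // BinomialBounds.getLowerBound / getUpperBound).
    val slack = 4 * math.sqrt(fraction * (1 - fraction) / n)
    val acceptBound = fraction - slack
    val waitListBound = fraction + slack

    // First pass: accept instantly below acceptBound, defer to the waitlist below waitListBound.
    val values = Array.fill(n)(rng.nextDouble())
    var numAccepted = 0L
    val waitList = new ArrayBuffer[Double]()
    values.foreach { x =>
      if (x < acceptBound) numAccepted += 1
      else if (x < waitListBound) waitList += x
    }

    // Pick the threshold so that exactly sampleSize items fall below it, falling back to the
    // bounds the same way computeThresholdByKey does.
    val threshold =
      if (numAccepted >= sampleSize) acceptBound
      else {
        val need = (sampleSize - numAccepted).toInt
        if (need >= waitList.size) waitListBound else waitList.sorted.apply(need)
      }

    // "Second pass": replaying the same values (the same seeded RNG sequence in Spark)
    // yields a sample of exactly the requested size.
    val finalSize = values.count(_ < threshold)
    println(s"accepted=$numAccepted waitlisted=${waitList.size} final=$finalSize target=$sampleSize")
  }
}
```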
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index f882a8623fd841778d728e4c2681c1bd83d90a7b..e8bd65f8e45070d2ae31e1f317c8b47f315bc0e1 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -29,6 +29,7 @@ import scala.Tuple4;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import com.google.common.base.Optional;
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
@@ -1208,4 +1209,40 @@ public class JavaAPISuite implements Serializable {
     pairRDD.collect();  // Works fine
     pairRDD.collectAsMap();  // Used to crash with ClassCastException
   }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void sampleByKey() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+      new PairFunction<Integer, Integer, Integer>() {
+        @Override
+        public Tuple2<Integer, Integer> call(Integer i) {
+          return new Tuple2<Integer, Integer>(i % 2, 1);
+        }
+      });
+    Map<Integer, Object> fractions = Maps.newHashMap();
+    fractions.put(0, 0.5);
+    fractions.put(1, 1.0);
+    JavaPairRDD<Integer, Integer> wr = rdd2.sampleByKey(true, fractions, 1L);
+    Map<Integer, Long> wrCounts = (Map<Integer, Long>) (Object) wr.countByKey();
+    Assert.assertTrue(wrCounts.size() == 2);
+    Assert.assertTrue(wrCounts.get(0) > 0);
+    Assert.assertTrue(wrCounts.get(1) > 0);
+    JavaPairRDD<Integer, Integer> wor = rdd2.sampleByKey(false, fractions, 1L);
+    Map<Integer, Long> worCounts = (Map<Integer, Long>) (Object) wor.countByKey();
+    Assert.assertTrue(worCounts.size() == 2);
+    Assert.assertTrue(worCounts.get(0) > 0);
+    Assert.assertTrue(worCounts.get(1) > 0);
+    JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
+    Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
+    Assert.assertTrue(wrExactCounts.size() == 2);
+    Assert.assertTrue(wrExactCounts.get(0) == 2);
+    Assert.assertTrue(wrExactCounts.get(1) == 4);
+    JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
+    Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
+    Assert.assertTrue(worExactCounts.size() == 2);
+    Assert.assertTrue(worExactCounts.get(0) == 2);
+    Assert.assertTrue(worExactCounts.get(1) == 4);
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 447e38ec9dbd006880a146ba0ebb2f64f4a7d9fe..4f49d4a1d4d34764738623256c1698d714afae3a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -83,6 +83,122 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     assert(valuesFor2.toList.sorted === List(1))
   }
 
+  test("sampleByKey") {
+    def stratifier(fractionPositive: Double) = {
+      (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
+    }
+
+    def checkSize(exact: Boolean,
+        withReplacement: Boolean,
+        expected: Long,
+        actual: Long,
+        p: Double): Boolean = {
+      if (exact) {
+        return expected == actual
+      }
+      val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p))
+      // Very forgiving margin since we're dealing with very small sample sizes most of the time
+      math.abs(actual - expected) <= 6 * stdev
+    }
+
+    // Without replacement validation
+    def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey()
+        .mapValues(count => math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case(k, v) =>
+        assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) }
+      assert(takeSample.size === takeSample.toSet.size)
+      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+    }
+
+    // With replacement validation
+    def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
+        math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case(k, v) =>
+        assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) }
+      val groupedByKey = takeSample.groupBy(_._1)
+      for ((key, v) <- groupedByKey) {
+        if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
+          // sample large enough for there to be repeats with high likelihood
+          assert(v.toSet.size < expectedSampleSize(key))
+        } else {
+          if (exact) {
+            assert(v.toSet.size <= expectedSampleSize(key))
+          } else {
+            assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate))
+          }
+        }
+      }
+      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
+    }
+
+    def checkAllCombos(stratifiedData: RDD[(String, Int)],
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n)
+      takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n)
+      takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)
+      takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n)
+    }
+
+    val defaultSeed = 1L
+
+    // vary RDD size
+    for (n <- List(100, 1000, 1000000)) {
+      val data = sc.parallelize(1 to n, 2)
+      val fractionPositive = 0.3
+      val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+
+    // vary fractionPositive
+    for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
+      val n = 100
+      val data = sc.parallelize(1 to n, 2)
+      val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+
+    // Use the same data for the rest of the tests
+    val fractionPositive = 0.3
+    val n = 100
+    val data = sc.parallelize(1 to n, 2)
+    val stratifiedData = data.keyBy(stratifier(fractionPositive))
+
+    // vary seed
+    for (seed <- defaultSeed to defaultSeed + 5L) {
+      val samplingRate = 0.1
+      checkAllCombos(stratifiedData, samplingRate, seed, n)
+    }
+
+    // vary sampling rate
+    for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
+      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+    }
+  }
+
   test("reduceByKey") {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
     val sums = pairs.reduceByKey(_+_).collect()
diff --git a/pom.xml b/pom.xml
index 8b1435cfe5d19ba209bdbbe132c00d73fa1db67e..39538f96606235c7d0a5253bb62c38eadc467e57 100644
--- a/pom.xml
+++ b/pom.xml
@@ -257,6 +257,12 @@
         <artifactId>commons-codec</artifactId>
         <version>1.5</version>
       </dependency>
+      <dependency>
+        <groupId>org.apache.commons</groupId>
+        <artifactId>commons-math3</artifactId>
+        <version>3.3</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>com.google.code.findbugs</groupId>
         <artifactId>jsr305</artifactId>