diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 786b97ad7b9ece4e5c2d1c646286db415ca26b8a..c156b03cdb7c4f2511cdab1ca84c4cac7d021ef4 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
  * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
  *
  * @param fraction the sampling fraction (with replacement)
+ * @param useGapSamplingIfPossible if true, use gap sampling when the sampling ratio is low.
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+class PoissonSampler[T: ClassTag](
+    fraction: Double,
+    useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
+
+  def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)
 
   /** Epsilon slop to avoid failure from floating point jitter. */
   require(
@@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
   override def sample(items: Iterator[T]): Iterator[T] = {
     if (fraction <= 0.0) {
       Iterator.empty
-    } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
-        new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+    } else if (useGapSamplingIfPossible &&
+               fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+      new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
     } else {
-      items.flatMap { item => {
+      items.flatMap { item =>
         val count = rng.sample()
         if (count == 0) Iterator.empty else Iterator.fill(count)(item)
-      }}
+      }
     }
   }
 
-  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
+  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible)
 }
 
 
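For readers of this patch: below is a minimal, hypothetical sketch (not part of the change) of how a caller could use the new two-argument constructor to opt out of gap sampling, for example when the input rows are backed by reused mutable buffers that must not be held across iterations. The constructor, `setSeed` (inherited from `Pseudorandom`), and `sample` signatures match the class defined above.

```scala
import org.apache.spark.util.random.PoissonSampler

// Hypothetical caller: the fraction is below the gap-sampling cutoff, but gap
// sampling is disabled explicitly because the inputs may be reused mutable
// objects that must not be buffered internally.
val sampler = new PoissonSampler[String](0.01, useGapSamplingIfPossible = false)
sampler.setSeed(13L)

// Sampling with replacement: each input item is emitted 0..n times according
// to a Poisson draw with mean 0.01.
val sampled: Seq[String] = sampler.sample(Iterator.fill(100000)("row")).toSeq
```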
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 0680f31d40f6db6509d520b8a8c3df1d8b65db32..c5d1ed0937b19de5ffc25e5b89be159bc9c08fca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
+import org.apache.spark.util.random.PoissonSampler
 import org.apache.spark.util.{CompletionIterator, MutablePair}
 import org.apache.spark.{HashPartitioner, SparkEnv}
 
@@ -130,12 +131,21 @@ case class Sample(
 {
   override def output: Seq[Attribute] = child.output
 
-  // TODO: How to pick seed?
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
-      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+      // Disable gap sampling: it buffers two rows internally and so requires copying each
+      // row, which costs more than the random-number calls that gap sampling would save.
+      new PartitionwiseSampledRDD[InternalRow, InternalRow](
+        child.execute(),
+        new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+        preservesPartitioning = true,
+        seed)
     } else {
-      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+      child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
     }
   }
 }
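At the DataFrame API level, the physical operator above backs `DataFrame.sample`. The following illustrative snippet (assuming a `SQLContext` named `sqlContext` is in scope, as in the test suites below) shows the two code paths this change touches:

```scala
// Assumes a SQLContext named sqlContext is available, as in the suites below.
val df = sqlContext.range(0, 1000)

// With replacement: now executed as a PartitionwiseSampledRDD over a
// PoissonSampler with gap sampling disabled, so rows are no longer copied
// defensively before sampling.
val withRepl = df.sample(withReplacement = true, fraction = 0.1, seed = 13)

// Without replacement: randomSampleWithRange on the child RDD, also without
// the per-row copy that the old code performed.
val withoutRepl = df.sample(withReplacement = false, fraction = 0.1, seed = 13)
```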
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0e7659f443ecd764619c9ff54ba9e5a481f46c9a..8f5984e4a8ce27e2dcd472eee9610f339eaa2f4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest {
 
   private def toLetter(i: Int): String = (i + 97).toChar.toString
 
+  test("sample with replacement") {
+    val n = 100
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    checkAnswer(
+      data.sample(withReplacement = true, 0.05, seed = 13),
+      Seq(5, 10, 52, 73).map(Row(_))
+    )
+  }
+
+  test("sample without replacement") {
+    val n = 100
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    checkAnswer(
+      data.sample(withReplacement = false, 0.05, seed = 13),
+      Seq(16, 23, 88, 100).map(Row(_))
+    )
+  }
+
+  test("randomSplit") {
+    val n = 600
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+        data.collect().toList, "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50) // std =  9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")
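A note on the tolerances in the `randomSplit` test above: each split's size is (marginally) Binomial(n, p), with p equal to that split's normalized weight, which appears to be how the quoted standard deviations were computed. A quick sketch of the arithmetic:

```scala
// Split sizes for randomSplit(Array(1, 2, 3)) over n = 600 rows are
// approximately Binomial(600, w/6); their standard deviations match the
// comments in the test (9.13, 11.55, 12.25).
val n = 600.0
val stds = Seq(1.0, 2.0, 3.0).map { w =>
  val p = w / 6.0
  math.sqrt(n * p * (1 - p))
}
```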
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f9cc6d1f3c250156d3cdc0b0104f0ab4f0f00d6b..0212637a829e5653a0446e3f1b59fc17bfdd6df6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
-  test("randomSplit") {
-    val n = 600
-    val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
-    for (seed <- 1 to 5) {
-      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
-      assert(splits.length == 3, "wrong number of splits")
-
-      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
-        data.collect().toList, "incomplete or wrong split")
-
-      val s = splits.map(_.count())
-      assert(math.abs(s(0) - 100) < 50) // std =  9.13
-      assert(math.abs(s(1) - 200) < 50) // std = 11.55
-      assert(math.abs(s(2) - 300) < 50) // std = 12.25
-    }
-  }
-
   test("describe") {
     val describeTestData = Seq(
       ("Bob", 16, 176),