diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 2921b939bc0cf49c8acbadf4c1438c51ed7fd7ec..d397cca4b444d9473db9bbc198f9d2b8e3ce7bf8 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -186,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
     fraction: Double,
     useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
 
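Note: the `ClassTag` context bound is dropped here because `PoissonSampler` is now instantiated from generated Java source (see the `Sample` changes in basicOperators.scala below), where no implicit `ClassTag` evidence can be supplied. A minimal sketch (not part of the patch) of the per-row protocol the generated code relies on:

```scala
import org.apache.spark.util.random.PoissonSampler

// sample() reports how many copies of the current row to emit; for Poisson
// (with-replacement) sampling the count can be 0, 1, or more.
val sampler = new PoissonSampler[AnyRef](0.01, false) // fraction, useGapSamplingIfPossible
sampler.setSeed(42L)
val copies: Int = sampler.sample() // 0 => drop the row, k => emit it k times
```
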
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ff11775412d0860f8aed8ef28b9ce33ce7e7b14c..2be490b94264aba8751bf7a59506c95a8f090e80 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -597,6 +597,10 @@ object MimaExcludes {
         // for multilayer perceptron.
         // This class is marked as `private`.
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
+      ) ++ Seq(
+        // [SPARK-13674][SQL] Add wholestage codegen support to Sample
+        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
       )
     case v if v.startsWith("1.6") =>
       Seq(
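Both filters cover the same source change: removing the context bound deletes the implicit `ClassTag` evidence parameter from `PoissonSampler`'s constructor, which alters its binary signature. An illustrative snippet (the class names are stand-ins, not Spark classes) showing the desugaring MiMa reacts to:

```scala
import scala.reflect.ClassTag

// A context bound desugars into an extra implicit constructor parameter:
class WithBound[T: ClassTag](fraction: Double) // bytecode: this(double, ClassTag)
class WithoutBound[T](fraction: Double)        // bytecode: this(double)

val a = new WithBound[String](0.1)   // the compiler supplies the ClassTag
val b = new WithoutBound[String](0.1)
```
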
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index dbea8521be2064893f68500434e16682145beff5..c2633a9f8cd48bb34055367ff725ee290492de5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
   private long startTimeNs = System.nanoTime();
 
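+  // Index of the RDD partition being processed; -1 until init() is called.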
+  protected int partitionIndex = -1;
+
   public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
@@ -58,7 +60,7 @@ public abstract class BufferedRowIterator {
   /**
-   * Initializes from array of iterators of InternalRow.
+   * Initializes from the partition index and an array of iterators of InternalRow.
    */
-  public abstract void init(Iterator<InternalRow> iters[]);
+  public abstract void init(int index, Iterator<InternalRow> iters[]);
 
   /**
    * Append a row to currentRows.
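With the new signature, every generated class receives its partition index through `init` before producing any rows. A hypothetical hand-written subclass (name and body are illustrative, not generated code) showing the contract:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

// Pass-through iterator obeying the new init() contract: the partition index
// arrives before the first call to processNext().
class PassThrough extends BufferedRowIterator {
  private var input: Iterator[InternalRow] = _

  override def init(index: Int, iters: Array[Iterator[InternalRow]]): Unit = {
    partitionIndex = index // operators such as Sample and Range read this field
    input = iters(0)
  }

  override protected def processNext(): Unit = {
    if (input.hasNext) append(input.next())
  }
}
```
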
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6a779abd40a3c7c17131ffba25b662cf7f64e000..9bdf611f6e536db602883b9984d3faf14e673df0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
           this.references = references;
         }
 
-        public void init(scala.collection.Iterator inputs[]) {
+        public void init(int index, scala.collection.Iterator inputs[]) {
+          partitionIndex = index;
           ${ctx.initMutableStates()}
         }
 
@@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
-      rdds.head.mapPartitions { iter =>
+      rdds.head.mapPartitionsWithIndex { (index, iter) =>
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(iter))
+        buffer.init(index, Array(iter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
@@ -367,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     } else {
       // Right now, we support up to two upstreams.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
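+        // zipPartitions has no *WithIndex variant, so fetch the index from the TaskContext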
+        val partitionIndex = TaskContext.getPartitionId()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(leftIter, rightIter))
+        buffer.init(partitionIndex, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
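`TaskContext.getPartitionId()` returns the index of the partition the running task is computing, so both branches hand the same value to `init`. A small demonstration (assuming `sc` is a live `SparkContext`):

```scala
import org.apache.spark.TaskContext

// Inside a task, the index supplied by mapPartitionsWithIndex and the value
// returned by TaskContext.getPartitionId() agree.
sc.parallelize(1 to 100, 4).mapPartitionsWithIndex { (index, iter) =>
  assert(TaskContext.getPartitionId() == index)
  iter
}.count()
```
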
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fca662760dcc42fd484bb310311cff353b1ba4f0..a6a14df6a33ea0801c75d54b2512e86f92402174 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   extends UnaryNode with CodegenSupport {
@@ -223,9 +223,12 @@ case class Sample(
     upperBound: Double,
     withReplacement: Boolean,
     seed: Long,
-    child: SparkPlan) extends UnaryNode {
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
   override def output: Seq[Attribute] = child.output
 
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
       // Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -239,6 +242,63 @@ case class Sample(
       child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val sampler = ctx.freshName("sampler")
+
+    if (withReplacement) {
+      val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
+      val initSampler = ctx.freshName("initSampler")
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"$initSampler();")
+
+      ctx.addNewFunction(initSampler,
+        s"""
+          | private void $initSampler() {
+          |   $sampler = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
+          |   java.util.Random random = new java.util.Random(${seed}L);
+          |   long randomSeed = random.nextLong();
+          |   int loopCount = 0;
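+          |   // partition k uses the (k + 1)-th nextLong() of the base RNG, giving each partition a distinct, reproducible seed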
+          |   while (loopCount < partitionIndex) {
+          |     randomSeed = random.nextLong();
+          |     loopCount += 1;
+          |   }
+          |   $sampler.setSeed(randomSeed);
+          | }
+         """.stripMargin.trim)
+
+      val samplingCount = ctx.freshName("samplingCount")
+      s"""
+         | int $samplingCount = $sampler.sample();
+         | while ($samplingCount-- > 0) {
+         |   $numOutput.add(1);
+         |   ${consume(ctx, input)}
+         | }
+       """.stripMargin.trim
+    } else {
+      val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"""
+          | $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
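+          | // give each partition its own seed so partitions do not repeat the same sample pattern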
+          | $sampler.setSeed(${seed}L + partitionIndex);
+         """.stripMargin.trim)
+
+      s"""
+         | if ($sampler.sample() == 0) continue;
+         | $numOutput.add(1);
+         | ${consume(ctx, input)}
+       """.stripMargin.trim
+    }
+  }
 }
 
 case class Range(
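The generated `initSampler` above mirrors the seeding of the non-codegen path, where `PartitionwiseSampledRDD` draws one `nextLong()` per partition from an RNG seeded with the user's seed, so sampled output stays consistent whether codegen is on or off. A standalone sketch of that derivation:

```scala
// Partition k receives the (k + 1)-th nextLong() drawn from a Random seeded
// with the user-supplied seed -- exactly what the generated loop computes.
def partitionSeed(baseSeed: Long, partitionIndex: Int): Long = {
  val random = new java.util.Random(baseSeed)
  var seed = random.nextLong()
  var i = 0
  while (i < partitionIndex) {
    seed = random.nextLong()
    i += 1
  }
  seed
}
```
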
@@ -320,11 +380,7 @@ case class Range(
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
-      |   if ($input.hasNext()) {
-      |     initRange(((InternalRow) $input.next()).getInt(0));
-      |   } else {
-      |     return;
-      |   }
+      |   initRange(partitionIndex);
       | }
       |
       | while (!$overflow && $checkEnd) {
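With `partitionIndex` now populated by `init`, `Range` no longer has to fish its partition index out of the first row of its upstream RDD, and the defensive early return for an empty partition disappears. Roughly, the pre-patch upstream carried one row per partition holding that partition's own index (a sketch, assuming it matches the old code):

```scala
import org.apache.spark.sql.catalyst.InternalRow

val numSlices = 8 // illustrative
sqlContext.sparkContext
  .parallelize(0 until numSlices, numSlices)
  .map(i => InternalRow(i)) // one row per partition: its own index, fed to initRange()
```
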
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 003d3e062e0459e9f8151ed11fa597b156f715ce..55906793c0b81785c0d42a0d4f280d68e57a81d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -85,6 +85,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
   }
 
+  ignore("range/sample/sum") {
+    val N = 500 << 20
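+    // withReplacement = true: exercises the Poisson sampling path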
+    runBenchmark("range/sample/sum", N) {
+      sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()
+    }
+    /*
+    Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    range/sample/sum:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    range/sample/sum codegen=false         53888 / 56592          9.7         102.8       1.0X
+    range/sample/sum codegen=true          41614 / 42607         12.6          79.4       1.3X
+    */
+
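+    // withReplacement = false: exercises the Bernoulli sampling path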
+    runBenchmark("range/sample/sum", N) {
+      sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect()
+    }
+    /*
+    Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    range/sample/sum:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    range/sample/sum codegen=false         12982 / 13384         40.4          24.8       1.0X
+    range/sample/sum codegen=true            7074 / 7383         74.1          13.5       1.8X
+    */
+  }
+
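The `codegen=false` / `codegen=true` rows come from the `runBenchmark` helper defined earlier in this suite, which times the body once per setting of the whole-stage codegen flag. A minimal sketch of that toggle (assuming the Spark 2.0 conf key):

```scala
// runBenchmark flips this between the two timed runs.
sqlContext.setConf("spark.sql.codegen.wholeStage", "false") // then "true"
```
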
   ignore("stat functions") {
     val N = 100L << 20