From cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9 Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Date: Sat, 27 Aug 2016 08:42:41 +0100
Subject: [PATCH] [SPARK-15382][SQL] Fix a bug in sampling with replacement

## What changes were proposed in this pull request?
This pr to fix a bug below in sampling with replacement
```
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show
+---+
|  c|
+---+
|  0|
|  1|
|  1|
|  1|
|  2|
+---+
```

## How was this patch tested?
Added a test in `DataFrameSuite`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #14800 from maropu/FixSampleBug.
---
 .../spark/sql/execution/basicPhysicalOperators.scala       | 1 +
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala   | 7 +++++++
 2 files changed, 8 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3562083b06..dd78a78491 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,7 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
+      ctx.copyResult = true
       ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
         s"$initSampler();")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cd485770d2..ce0b92a461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("copy results for sampling with replacement") {
+    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+    val sampleDf = df.sample(true, 2.00)
+    val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+    assert(d.size == d.distinct.size)
+  }
 }
-- 
GitLab