diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e53723c1765693d1ed648327bb0e2fe84c7c0dd8..16ca4be5587c4ef46836041bc13b0fb8c0c767ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,7 +70,7 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: override def output = child.output // TODO: How to pick seed? - override def execute() = child.execute().sample(withReplacement, fraction, seed) + override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5d0fb7237011f6c907125e35a44e3022216614d9..c1c3683f84ab279e5c687506c722b22e373bcd2d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -202,4 +203,15 @@ class SQLQuerySuite extends QueryTest { checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), sql("SELECT distinct key FROM src order by key").collect().toSeq) } + + test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") { + sql("SELECT * FROM src WHERE key % 2 = 0") + .sample(withReplacement = false, fraction = 0.3) + .registerTempTable("sampled") + (1 to 10).foreach { i => + checkAnswer( + sql("SELECT * FROM sampled WHERE key % 2 = 1"), + Seq.empty[Row]) + } + } }