Commit 6987c067 authored by Davies Liu, committed by Yin Huai

[SPARK-11009] [SQL] fix wrong result of Window function in cluster mode

Currently, any window function can sometimes produce wrong results when running in cluster mode.

The root cause is that AttributeReference is created on executors, so its expression ID is not guaranteed to be unique with respect to the IDs of attributes created on the driver.
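
To make the failure mode concrete: Catalyst draws each new expression ID from a per-JVM global counter, so an attribute created on an executor can receive the same ID as an unrelated attribute created on the driver. Below is a minimal, self-contained sketch of the collision; `ExprIdCounter`, `IdCollisionDemo`, and the other names are stand-ins for illustration, not Spark's actual Catalyst types.
```
import java.util.concurrent.atomic.AtomicLong

// Stand-in for Catalyst's global expression-ID counter; every JVM
// (the driver and each executor) has its own instance starting at zero.
class ExprIdCounter {
  private val curId = new AtomicLong(0)
  def newExprId(): Long = curId.getAndIncrement()
}

object IdCollisionDemo extends App {
  val driverCounter = new ExprIdCounter    // lives in the driver JVM
  val executorCounter = new ExprIdCounter  // lives in an executor JVM

  // Attributes in the analyzed plan get their IDs on the driver: 0, 1, 2, ...
  val planAttrIds = Seq.fill(3)(driverCounter.newExprId())

  // Before this fix, Window created fresh AttributeReferences inside
  // doExecute(), i.e. on executors, drawing from the executor's own counter.
  val windowAttrId = executorCounter.newExprId()

  // Catalyst binds attributes by ID, so the executor-created attribute
  // can silently alias an unrelated column of the child's output.
  assert(planAttrIds.contains(windowAttrId))
  println(s"driver IDs: $planAttrIds, executor-created ID: $windowAttrId")
}
```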

Here is a script that reproduces the problem (run it against a local cluster):
```
from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql.window import Window
from pyspark.sql.functions import rowNumber

sqlContext = HiveContext(SparkContext())
sqlContext.setConf("spark.sql.shuffle.partitions", "3")
df = sqlContext.range(1 << 20)
df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias("B"))
ws = Window.partitionBy(df2.A).orderBy(df2.B)
df3 = df2.select("A", "B", rowNumber().over(ws).alias("rn")).filter("rn < 0")
assert df3.count() == 0
```
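
The patch below replaces those executor-created AttributeReferences with BoundReferences. A BoundReference carries no name and no ID, only the ordinal of the slot it reads from the input row, so nothing ID-bearing has to be allocated on executors. A rough sketch of the idea, again with stand-in types (`Row`, `BoundRef`) rather than Catalyst's actual classes:
```
// Stand-in for Catalyst's InternalRow: just an ordered sequence of values.
case class Row(values: IndexedSeq[Any])

// Stand-in for BoundReference(ordinal, dataType, nullable): identified purely
// by position, so it is safe to construct anywhere, including on executors.
case class BoundRef(ordinal: Int) {
  def eval(row: Row): Any = row.values(ordinal)
}

object BoundRefDemo extends App {
  // As in createResultProjection below: window results are appended to the
  // right of the child's output, so the i-th result sits at childOutputSize + i.
  val childOutputSize = 2
  val joined = Row(IndexedSeq("a", 1L, 42L)) // child columns ++ window results

  val firstWindowResult = BoundRef(childOutputSize + 0)
  println(firstWindowResult.eval(joined)) // prints 42
}
```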

Author: Davies Liu <davies@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #9050 from davies/wrong_window.
parent 626aab79
```
@@ -145,11 +145,10 @@ case class Window(
           // Construct the ordering. This is used to compare the result of current value projection
           // to the result of bound value projection. This is done manually because we want to use
           // Code Generation (if it is enabled).
-          val (sortExprs, schema) = exprs.map { case e =>
-            val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-            (SortOrder(ref, e.direction), ref)
-          }.unzip
-          val ordering = newOrdering(sortExprs, schema)
+          val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+            SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+          }
+          val ordering = newOrdering(sortExprs, Nil)
           RangeBoundOrdering(ordering, current, bound)
         case RowFrame => RowBoundOrdering(offset)
       }
@@ -205,14 +204,15 @@ case class Window(
    */
   private[this] def createResultProjection(
       expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+    val references = expressions.zipWithIndex.map{ case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
     }
-    val unboundToAttrMap = unboundToAttr.toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output)()
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
```
```
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1 << 20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      if (df3.rdd.count() != 0) {
+        throw new Exception("df3 should have 0 output row.")
+      }
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}
```
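
A note on the regression test: it goes through spark-submit with `--master local-cluster[2,1,1024]` (two workers, one core and 1024 MB each) rather than plain `local` mode, because local-cluster launches real executor JVMs, each with its own expression-ID counter. In a single-JVM local master the driver and "executor" share one counter, so the ID collision could not occur and the bug would go unnoticed.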