Commit 6987c067 authored by Davies Liu, committed by Yin Huai

[SPARK-11009] [SQL] fix wrong result of Window function in cluster mode

Currently, all window functions can intermittently produce wrong results in cluster mode.

The root cause is that AttributeReference instances are created on executors, and the expression IDs generated there are not guaranteed to be unique with respect to the ones created on the driver, so attributes can end up resolved against the wrong columns.
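
By way of illustration, here is a minimal sketch (plain Scala, not Spark's actual implementation) of why per-JVM monotonic IDs are unsafe: every JVM starts its counter from zero, so an ID minted on an executor can equal an ID already minted on the driver. The `ExprIds` object below is a toy stand-in for the counter behind `NamedExpression.newExprId`.
```
import java.util.concurrent.atomic.AtomicLong

// Toy model of expression ID allocation: a JVM-global counter,
// analogous to the counter behind NamedExpression.newExprId.
object ExprIds {
  private val curId = new AtomicLong(0)
  def newId(): Long = curId.getAndIncrement()
}

object IdCollision extends App {
  // IDs handed out on the driver JVM: 0, 1, 2
  val driverIds = Seq.fill(3)(ExprIds.newId())
  // A fresh executor JVM starts its own counter at 0 again, so
  // AttributeReferences created there reuse already-taken IDs.
  val executorIds = Seq(0L, 1L, 2L)
  assert(driverIds.intersect(executorIds).nonEmpty) // collision
}
```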

Here is a script that reproduces the problem (run with a local-cluster master):
```
from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql.window import Window
from pyspark.sql.functions import rowNumber

sqlContext = HiveContext(SparkContext())
sqlContext.setConf("spark.sql.shuffle.partitions", "3")
df = sqlContext.range(1 << 20)
df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias("B"))
ws = Window.partitionBy(df2.A).orderBy(df2.B)
df3 = df2.select("A", "B", rowNumber().over(ws).alias("rn")).filter("rn < 0")
assert df3.count() == 0
```

Since rowNumber() numbers rows starting at 1, `rn < 0` can never match and df3 must be empty; under the bug, the window result was bound to the wrong attribute and the filter let rows through.

Author: Davies Liu <davies@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #9050 from davies/wrong_window.
parent 626aab79
@@ -145,11 +145,10 @@ case class Window(
         // Construct the ordering. This is used to compare the result of current value projection
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
-        val (sortExprs, schema) = exprs.map { case e =>
-          val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-          (SortOrder(ref, e.direction), ref)
-        }.unzip
-        val ordering = newOrdering(sortExprs, schema)
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
       case RowFrame => RowBoundOrdering(offset)
@@ -205,14 +204,15 @@ case class Window(
    */
   private[this] def createResultProjection(
       expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+    val references = expressions.zipWithIndex.map { case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
     }
-    val unboundToAttrMap = unboundToAttr.toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output)()
   }

   protected override def doExecute(): RDD[InternalRow] = {
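
Both hunks apply the same idea: address values by ordinal position (`BoundReference`) rather than by a freshly minted attribute, so nothing depends on expression IDs generated at runtime. A minimal stand-alone analogy follows (plain Scala, no Spark dependencies; `BoundRef` and `orderingFor` are illustrative names, not Spark APIs):
```
object PositionalOrdering extends App {
  // A bound reference carries only an ordinal; it means the same thing
  // in every JVM, so there is no generated ID that could collide.
  case class BoundRef(ordinal: Int)

  // Build a row ordering from ordinal references, mirroring the
  // SortOrder(BoundReference(i, ...), e.direction) construction above.
  def orderingFor(refs: Seq[BoundRef]): Ordering[Seq[Int]] =
    new Ordering[Seq[Int]] {
      def compare(x: Seq[Int], y: Seq[Int]): Int =
        refs.iterator
          .map(r => x(r.ordinal).compareTo(y(r.ordinal)))
          .find(_ != 0)
          .getOrElse(0)
    }

  val ord = orderingFor(Seq(BoundRef(0), BoundRef(1)))
  assert(ord.compare(Seq(1, 2), Seq(1, 3)) < 0) // differs on the second column
}
```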
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark._
 import org.apache.spark.sql.{SQLContext, QueryTest}
 import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }

+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1 << 20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      if (df3.rdd.count() != 0) {
+        throw new Exception("df3 should have 0 output row.")
+      }
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}