diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index f8929530c50363d81f72fbcd0d9a5575d0857614..55035f4bc5f2a28225c32073e9e407fbcddc72d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -145,11 +145,10 @@ case class Window( // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val (sortExprs, schema) = exprs.map { case e => - val ref = AttributeReference("ordExpr", e.dataType, e.nullable)() - (SortOrder(ref, e.direction), ref) - }.unzip - val ordering = newOrdering(sortExprs, schema) + val sortExprs = exprs.zipWithIndex.map { case (e, i) => + SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) + } + val ordering = newOrdering(sortExprs, Nil) RangeBoundOrdering(ordering, current, bound) case RowFrame => RowBoundOrdering(offset) } @@ -205,14 +204,15 @@ case class Window( */ private[this] def createResultProjection( expressions: Seq[Expression]): MutableProjection = { - val unboundToAttr = expressions.map { - e => (e, AttributeReference("windowResult", e.dataType, e.nullable)()) + val references = expressions.zipWithIndex.map{ case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) } - val unboundToAttrMap = unboundToAttr.toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap)) + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) newMutableProjection( projectList ++ patchedWindowExpression, - child.output ++ unboundToAttr.map(_._2))() + child.output)() } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 5f1660b62d4186221b2614d5a6e22189c9290c59..10e4ae2c50308504cd5f20a535809c0e93398984 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.DecimalType @@ -107,6 +108,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-11009 fix wrong result of Window function in cluster mode") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_11009.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest { } } } + +object SPARK_11009 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + try { + val df = sqlContext.range(1 << 20) + val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) + val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) + val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + if (df3.rdd.count() != 0) { + throw new Exception("df3 should have 0 output row.") + } + } finally { + sparkContext.stop() + } + } +}