diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 946b3d446a40f7093bff4ce7c51e35f4fab7b074..d702c08cfd342fc0e3ffc79a05e5a47155e26161 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -43,7 +43,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara throw new AnalysisException("The second argument of First should be a boolean literal.") } - override def children: Seq[Expression] = child :: Nil + override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -54,7 +54,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara override def dataType: DataType = child.dataType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) private lazy val first = AttributeReference("first", child.dataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 53b4b761ae514247d6b9f347ed27985cd2e7b971..af8840305805f89c5af9c442439763210ce8cc10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -40,7 +40,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat throw new AnalysisException("The second argument of First should be a boolean literal.") } - override def children: Seq[Expression] = child :: Nil + override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -51,7 +51,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat override def dataType: DataType = child.dataType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) private lazy val last = AttributeReference("last", child.dataType)() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index c6b7eb63662c56aedc4282c46cfb253ae2c9d536..0ff3511c87a4f5ab7f4079a74bc6ff4e0b4d332a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -247,4 +247,16 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto |from part """.stripMargin)) } + + test("SPARK-16646: LAST_VALUE(FALSE) OVER ()") { + checkAnswer(sql("SELECT LAST_VALUE(FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT LAST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT LAST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) + } + + test("SPARK-16646: FIRST_VALUE(FALSE) OVER ()") { + checkAnswer(sql("SELECT FIRST_VALUE(FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT FIRST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT FIRST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) + } }