diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 854b5b461bdc8b0f3e688e366eeaf6f00a0dd0e3..4662f585cfe151e8522973e9bdab4ae55bd28bd0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") + protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val ELSE = Keyword("ELSE") + protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") @@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") + protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. @@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers { IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | + CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { + case we ~ te => + Seq(casePart.fold(we)(EqualTo(_, we)), te) + } + CaseWhen(altExprs ++ elsePart.toList) + } | (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) } | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b9b196ea5a46a84fe0c2a2daaf1b8a2b1ded4664..79de1bb855dbe9ed809885ad6e9b16bf65e49e6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { + checkAnswer( + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { + checkAnswer( + sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + } }