Skip to content
Snippets Groups Projects
Commit 2bbecdec authored by Herman van Hovell's avatar Herman van Hovell
Browse files

[SPARK-17753][SQL] Allow a complex expression as the input a value based case statement

## What changes were proposed in this pull request?
We currently only allow relatively simple expressions as the input for a value based case statement. Expressions like `case (a > 1) or (b = 2) when true then 1 when false then 0 end` currently fail. This PR adds support for such expressions.

## How was this patch tested?
Added a test to the ExpressionParserSuite.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15322 from hvanhovell/SPARK-17753.
parent d8399b60
No related branches found
No related tags found
No related merge requests found
......@@ -527,16 +527,16 @@ valueExpression
;
primaryExpression
: constant #constantDefault
| name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall
: name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall
| CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CAST '(' expression AS dataType ')' #cast
| constant #constantDefault
| ASTERISK #star
| qualifiedName '.' ASTERISK #star
| '(' expression (',' expression)+ ')' #rowConstructor
| qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
| '(' query ')' #subqueryExpression
| CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CAST '(' expression AS dataType ')' #cast
| qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
| value=primaryExpression '[' index=valueExpression ']' #subscript
| identifier #columnReference
| base=primaryExpression '.' fieldName=identifier #dereference
......
......@@ -1138,7 +1138,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* }}}
*/
override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
val e = expression(ctx.valueExpression)
val e = expression(ctx.value)
val branches = ctx.whenClause.asScala.map { wCtx =>
(EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
}
......
......@@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest {
test("case when") {
assertEqual("case a when 1 then b when 2 then c else d end",
CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
assertEqual("case (a or b) when true then c when false then d else e end",
CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e)))
assertEqual("case 'a'='a' when true then 1 end",
CaseKeyWhen("a" === "a", Seq(true, 1)))
assertEqual("case when a = 1 then b when a = 2 then c else d end",
CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment