diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 3222a9cdc2c4ed42d6e4b967a38ca7d15612de3c..c95c1f5d67dbeb06868e185d3608ef48b8f28dcb 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -530,8 +530,8 @@ valueExpression
 
 primaryExpression
     : name=(CURRENT_DATE | CURRENT_TIMESTAMP)                                                  #timeFunctionCall
-    | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END                  #simpleCase
     | CASE whenClause+ (ELSE elseExpression=expression)? END                                   #searchedCase
+    | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END                  #simpleCase
     | CAST '(' expression AS dataType ')'                                                      #cast
     | constant                                                                                 #constantDefault
     | ASTERISK                                                                                 #star
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 17cfc8158803b233987d3e6986e249c42234f613..2fecb8dc4a60ee729522e10448cb8b44e5fbfb5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -298,6 +298,8 @@ class ExpressionParserSuite extends PlanTest {
       CaseKeyWhen("a" ===  "a", Seq(true, 1)))
     assertEqual("case when a = 1 then b when a = 2 then c else d end",
       CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
+    assertEqual("case when (1) + case when a > b then c else d end then f else g end",
+      CaseWhen(Seq((Literal(1) + CaseWhen(Seq(('a > 'b, 'c.expr)), 'd.expr), 'f.expr)), 'g))
   }
 
   test("dereference") {