diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 854b5b461bdc8b0f3e688e366eeaf6f00a0dd0e3..4662f585cfe151e8522973e9bdab4ae55bd28bd0 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val BETWEEN = Keyword("BETWEEN")
   protected val BY = Keyword("BY")
   protected val CACHE = Keyword("CACHE")
+  protected val CASE = Keyword("CASE")
   protected val CAST = Keyword("CAST")
   protected val COUNT = Keyword("COUNT")
   protected val DESC = Keyword("DESC")
   protected val DISTINCT = Keyword("DISTINCT")
+  protected val ELSE = Keyword("ELSE")
+  protected val END = Keyword("END")
   protected val EXCEPT = Keyword("EXCEPT")
   protected val FALSE = Keyword("FALSE")
   protected val FIRST = Keyword("FIRST")
@@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val SUBSTRING = Keyword("SUBSTRING")
   protected val SUM = Keyword("SUM")
   protected val TABLE = Keyword("TABLE")
+  protected val THEN = Keyword("THEN")
   protected val TIMESTAMP = Keyword("TIMESTAMP")
   protected val TRUE = Keyword("TRUE")
   protected val UNCACHE = Keyword("UNCACHE")
   protected val UNION = Keyword("UNION")
   protected val UPPER = Keyword("UPPER")
+  protected val WHEN = Keyword("WHEN")
   protected val WHERE = Keyword("WHERE")
 
   // Use reflection to find the reserved words defined in this class.
@@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
       case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
     } |
+    CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+      (ELSE ~> expression).? <~ END ^^ {
+       case casePart ~ altPart ~ elsePart =>
+         val altExprs = altPart.flatMap {
+           case we ~ te =>
+             Seq(casePart.fold(we)(EqualTo(_, we)), te)
+        }
+        CaseWhen(altExprs ++ elsePart.toList)
+    } |
     (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
       case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
     } |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b9b196ea5a46a84fe0c2a2daaf1b8a2b1ded4664..79de1bb855dbe9ed809885ad6e9b16bf65e49e6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
       ("true", "false") :: Nil)
   }
-  
+
   test("SPARK-3371 Renaming a function expression with group by gives error") {
     registerFunction("len", (s: String) => s.length)
     checkAnswer(
-      sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}    
+      sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+  }
+
+  test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
+    checkAnswer(
+      sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+  }
+
+  test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
+    checkAnswer(
+      sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+  }
 }