diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index abba8668216f4613f6221fb2ff734e350389b830..0efe3c4d456ee4bda8784add214914d6177ec25b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -29,6 +29,7 @@ trait CatalystConf {
   def groupByOrdinal: Boolean
 
   def optimizerMaxIterations: Int
+  def maxCaseBranchesForCodegen: Int
 
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
@@ -45,6 +46,7 @@ case class SimpleCatalystConf(
     caseSensitiveAnalysis: Boolean,
     orderByOrdinal: Boolean = true,
     groupByOrdinal: Boolean = true,
-    optimizerMaxIterations: Int = 100)
+    optimizerMaxIterations: Int = 100,
+    maxCaseBranchesForCodegen: Int = 20)
   extends CatalystConf {
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 336649c0fd7632d068917424220fb76cae9f4df4..e97e08947a5007a2d3860729551fce2c491c54be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 }
 
 /**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
+ * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
  *
  * @param branches seq of (branch condition, branch value)
  * @param elseValue optional value for the else branch
  */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
-// scalastyle:on line.size.limit
-case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
-  extends Expression with CodegenFallback {
+abstract class CaseWhenBase(
+    branches: Seq[(Expression, Expression)],
+    elseValue: Option[Expression])
+  extends Expression with Serializable {
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
@@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
     }
   }
 
-  def shouldCodegen: Boolean = {
-    branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
+  override def toString: String = {
+    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+
+  override def sql: String = {
+    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+}
+
+
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
+// scalastyle:on line.size.limit
+case class CaseWhen(
+    val branches: Seq[(Expression, Expression)],
+    val elseValue: Option[Expression] = None)
+  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    super[CodegenFallback].doGenCode(ctx, ev)
+  }
+
+  def toCodegen(): CaseWhenCodegen = {
+    CaseWhenCodegen(branches, elseValue)
   }
+}
+
+/**
+ * CaseWhen expression used when code generation condition is satisfied.
+ * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+case class CaseWhenCodegen(
+    val branches: Seq[(Expression, Expression)],
+    val elseValue: Option[Expression] = None)
+  extends CaseWhenBase(branches, elseValue) with Serializable {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (!shouldCodegen) {
-      // Fallback to interpreted mode if there are too many branches, as it may reach the
-      // 64K limit (limit on bytecode size for a single function).
-      return super[CodegenFallback].doGenCode(ctx, ev)
-    }
     // Generate code that looks like:
     //
     // condA = ...
@@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
       $generatedCode""")
   }
-
-  override def toString: String = {
-    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
-    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
-    "CASE" + cases + elseCase + " END"
-  }
-
-  override def sql: String = {
-    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
-    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
-    "CASE" + cases + elseCase + " END"
-  }
 }
 
 /** Factory methods for CaseWhen. */
 object CaseWhen {
-
-  // The maximum number of switches supported with codegen.
-  val MAX_NUM_CASES_FOR_CODEGEN = 20
-
   def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
     CaseWhen(branches, Option(elseValue))
   }
@@ -242,7 +265,6 @@ object CaseWhen {
   }
 }
 
-
 /**
  * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
  * When a = b, returns c; when a = d, returns e; else returns f.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c46bdfb2b506a785aeb37164a29528faabf2f9f2..b806b725a8d0c8dee9ddddcdb77f4ae3eb5be040 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("Subquery", Once,
-      OptimizeSubqueries) :: Nil
+      OptimizeSubqueries) ::
+    Batch("OptimizeCodegen", Once,
+      OptimizeCodegen(conf)) :: Nil
   }
 
   /**
@@ -863,6 +865,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Optimizes expressions by replacing according to CodeGen configuration.
+ */
+case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen =>
+      e.toCodegen()
+  }
+}
+
 /**
  * Combines all adjacent [[Union]] operators into a single [[Union]].
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4385b0e019f25a54571c51ca27828bdbf44ebc25
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class OptimizeCodegenSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
+  }
+
+  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+    comparePlans(actual, correctAnswer)
+  }
+
+  test("Codegen only when the number of branches is small.") {
+    assertEquivalent(
+      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())
+
+    assertEquivalent(
+      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)),
+      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)))
+  }
+
+  test("Nested CaseWhen Codegen.") {
+    assertEquivalent(
+      CaseWhen(
+        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
+        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
+      CaseWhen(
+        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
+        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
+  }
+
+  test("Multiple CaseWhen in one operator.") {
+    val plan = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
+    val correctAnswer = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
+    val optimized = Optimize.execute(plan)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Multiple CaseWhen in different operators") {
+    val plan = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      .where(
+        LessThan(
+          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
+          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      ).analyze
+    val correctAnswer = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      .where(
+        LessThan(
+          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
+          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      ).analyze
+    val optimized = Optimize.execute(plan)
+    comparePlans(optimized, correctAnswer)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 29b66e3dee3e7a858e4202f770430079aff6631c..46eaede5e717b6112897e7f134b397dfebea022e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
 
   private def supportCodegen(e: Expression): Boolean = e match {
     case e: LeafExpression => true
-    case e: CaseWhen => e.shouldCodegen
     // CodegenFallback requires the input to be an InternalRow
     case e: CodegenFallback => false
     case _ => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f206bdb9b24076f9577cb748cd654b59af11a0a..4ae8278a9d767e0d00026dfea59d7ded9224b3e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -402,6 +402,12 @@ object SQLConf {
     .intConf
     .createWithDefault(200)
 
+  val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches")
+    .internal()
+    .doc("The maximum number of switches supported with codegen.")
+    .intConf
+    .createWithDefault(20)
+
   val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files.")
     .longConf
@@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
 
+  def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
+
   def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
 
   def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)