diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 900def59d23a50d9d04157ef5455baa16dfa3880..320451c52c70650501c284c3b78e7df790500351 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -368,12 +368,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
-        +-----+--------------------------------------------------------+
-        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
-        +-----+--------------------------------------------------------+
-        |Alice|                                                      -1|
-        |  Bob|                                                       1|
-        +-----+--------------------------------------------------------+
+        +-----+------------------------------------------------------------+
+        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+        +-----+------------------------------------------------------------+
+        |Alice|                                                          -1|
+        |  Bob|                                                           1|
+        +-----+------------------------------------------------------------+
         """
         if not isinstance(condition, Column):
             raise TypeError("condition should be a Column")
@@ -393,12 +393,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
-        +-----+---------------------------------+
-        | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
-        +-----+---------------------------------+
-        |Alice|                                0|
-        |  Bob|                                1|
-        +-----+---------------------------------+
+        +-----+-------------------------------------+
+        | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+        +-----+-------------------------------------+
+        |Alice|                                    0|
+        |  Bob|                                    1|
+        +-----+-------------------------------------+
         """
         v = value._jc if isinstance(value, Column) else value
         jc = self._jc.otherwise(v)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index c87b6c8e9543605149a4ca7e25b4c3d5a3199924..d0fbdacf6eafdc8a0320f58920c3141f657612d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* Case statements */
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
-      CaseWhen(branches.map(nodeToExpr))
+      CaseWhen.createFromParser(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
       val keyExpr = nodeToExpr(branches.head)
       CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 6ec408a673c796883f2410ab5ff14add52acb65f..85ff4ea0c946b696b888d0bb90b4bd8a3257407b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
           throw new AnalysisException(s"invalid function approximate($s) $udfName")
         }
       }
-    | CASE ~> whenThenElse ^^ CaseWhen
+    | CASE ~> whenThenElse ^^
+      { case branches => CaseWhen.createFromParser(branches) }
     | CASE ~> expression ~ whenThenElse ^^
       { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
     )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 980b5d52fa8f77de53ba75dea2c83368a1ffc0d2..2737fe32cd086ad5ef458c2cc91d556e7bd6f5eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -621,14 +621,24 @@ object HiveTypeCoercion {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
-          val castedBranches = c.branches.grouped(2).map {
-            case Seq(when, value) if value.dataType != commonType =>
-              Seq(when, Cast(value, commonType))
-            case Seq(elseVal) if elseVal.dataType != commonType =>
-              Seq(Cast(elseVal, commonType))
-            case other => other
-          }.reduce(_ ++ _)
-          CaseWhen(castedBranches)
+          var changed = false
+          val newBranches = c.branches.map { case (condition, value) =>
+            if (value.dataType.sameType(commonType)) {
+              (condition, value)
+            } else {
+              changed = true
+              (condition, Cast(value, commonType))
+            }
+          }
+          val newElseValue = c.elseValue.map { value =>
+            if (value.dataType.sameType(commonType)) {
+              value
+            } else {
+              changed = true
+              Cast(value, commonType)
+            }
+          }
+          if (changed) CaseWhen(newBranches, newElseValue) else c
         }.getOrElse(c)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 5a1462433d583a46d71c17801f8d2f8d05ba3b51..8cc7bc1da2fc3735cd7b11f3279a72deda5876f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 /**
  * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
  * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
  */
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = branches
-
-  @transient lazy val whenList =
-    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
-
-  @transient lazy val thenList =
-    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
+  extends Expression {
 
-  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+  override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
   // both then and else expressions should be considered.
-  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
+
   def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
     case Seq(dt1, dt2) => dt1.sameType(dt2)
   }
 
-  override def dataType: DataType = thenList.head.dataType
+  override def dataType: DataType = branches.head._2.dataType
 
   override def nullable: Boolean = {
-    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
-    thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
+    // Result is nullable if any of the branch is nullable, or if the else value is nullable
+    branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
+    // Make sure all branch conditions are boolean types.
     if (valueTypesEqual) {
-      if (whenList.forall(_.dataType == BooleanType)) {
+      if (branches.forall(_._1.dataType == BooleanType)) {
         TypeCheckResult.TypeCheckSuccess
       } else {
-        val index = whenList.indexWhere(_.dataType != BooleanType)
+        val index = branches.indexWhere(_._1.dataType != BooleanType)
         TypeCheckResult.TypeCheckFailure(
           s"WHEN expressions in CaseWhen should all be boolean type, " +
-            s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+            s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
       }
     } else {
       TypeCheckResult.TypeCheckFailure(
@@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
   }
 
   override def eval(input: InternalRow): Any = {
-    // Written in imperative fashion for performance considerations
-    val len = branchesArr.length
     var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == true) {
-        return branchesArr(i + 1).eval(input)
+    while (i < branches.size) {
+      if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
+        return branches(i)._2.eval(input)
       }
-      i += 2
+      i += 1
     }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
+    if (elseValue.isDefined) {
+      return elseValue.get.eval(input)
+    } else {
+      return null
     }
-    return res
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val len = branchesArr.length
     val got = ctx.freshName("got")
 
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
+    val cases = branches.map { case (condition, value) =>
+      val cond = condition.gen(ctx)
+      val res = value.gen(ctx)
       s"""
         if (!$got) {
           ${cond.code}
@@ -165,17 +155,19 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       """
     }.mkString("\n")
 
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
+    val elseCase = {
+      if (elseValue.isDefined) {
+        val res = elseValue.get.gen(ctx)
+        s"""
         if (!$got) {
           ${res.code}
           ${ev.isNull} = ${res.isNull};
           ${ev.value} = ${res.value};
         }
-      """
-    } else {
-      ""
+        """
+      } else {
+        ""
+      }
     }
 
     s"""
@@ -183,32 +175,42 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
       $cases
-      $other
+      $elseCase
     """
   }
 
   override def toString: String = {
-    "CASE" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
+    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+    "CASE" + cases + elseCase + " END"
   }
 
   override def sql: String = {
-    val branchesSQL = branches.map(_.sql)
-    val (cases, maybeElse) = if (branches.length % 2 == 0) {
-      (branchesSQL, None)
-    } else {
-      (branchesSQL.init, Some(branchesSQL.last))
-    }
+    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+}
 
-    val head = s"CASE "
-    val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
-    val body = cases.grouped(2).map {
-      case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
-    }.mkString(" ")
+/** Factory methods for CaseWhen. */
+object CaseWhen {
 
-    head + body + tail
+  def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
+    CaseWhen(branches, Option(elseValue))
+  }
+
+  /**
+   * A factory method to faciliate the creation of this expression when used in parsers.
+   * @param branches Expressions at even position are the branch conditions, and expressions at odd
+   *                 position are branch values.
+   */
+  def createFromParser(branches: Seq[Expression]): CaseWhen = {
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((cond, value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
 
@@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
  */
 object CaseKeyWhen {
   def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
-    val newBranches = branches.zipWithIndex.map { case (expr, i) =>
-      if (i % 2 == 0 && i != branches.size - 1) {
-        // If this expression is at even position, then it is either a branch condition, or
-        // the very last value that is the "else value". The "i != branches.size - 1" makes
-        // sure we are not adding an EqualTo to the "else value".
-        EqualTo(key, expr)
-      } else {
-        expr
-      }
-    }
-    CaseWhen(newBranches)
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((EqualTo(key, cond), value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d4be545a35ab2e650519a99bfc6529f17f4e85f9..d0b29aa01f640f3477347cda80d256c76f35aff2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
           } else {
             arg
           }
+        case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
+          val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
+          val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+            changed = true
+            (newChild1, newChild2)
+          } else {
+            tuple
+          }
         case other => other
       }
       case nonChild: AnyRef => nonChild
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index cf84855885a37082f171cd5c0c7aeaf11638d03e..975cd87d090e4c8dac38b46bf67ef231ac6eaaf4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest {
 
   test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
     val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
-    val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
+    val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
     assertAnalysisSuccess(plan)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 0521ed848c793fe87bc6c688a7943b19e3716aa2..59549e3998e7eb1327fdc3a0b94d3fe570bfc727 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
 
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
       CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))),
       "WHEN expressions in CaseWhen should all be boolean type")
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 40378c6727667c2418e29fb677c2f6a6e791e565..b1f6c0b802d8e958993be75c371c9f6bcd7286f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest {
       CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)))
+      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(1.2))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Cast(Literal(100L), DecimalType(22, 2)),
-        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))))
+      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
     )
   }
 
@@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest {
     val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
       DecimalType(25, 5), DoubleType, DoubleType)
 
-    rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+    rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
       val plan2 = LocalRelation(
         AttributeReference("r", rType)())
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 4029da5925580f5b3bf14801a530e69d829fdfab..3c581ecdaf068f559552fd8e543e076274acd8cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -80,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     val c5 = 'a.string.at(4)
     val c6 = 'a.string.at(5)
 
-    checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row)
-
-    checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
-
-    assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+    checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row)
+
+    checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row)
+
+    assert(CaseWhen(Seq((c2, c4)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true)
 
     val c4_notNull = 'a.boolean.notNull.at(3)
     val c5_notNull = 'a.boolean.notNull.at(4)
     val c6_notNull = 'a.boolean.notNull.at(5)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
   }
 
   test("case key when") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e8c61d6e01dc31e977bda09930de1b9a7e539144..6a020f9f2883e5c174ccad06338880e8fba01622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "when() cannot be applied once otherwise() is applied")
     case _ =>
       throw new IllegalArgumentException(
         "when() can only be applied on a Column previously generated by when() function")
@@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def otherwise(value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      if (branches.size % 2 == 0) {
-        withExpr { CaseWhen(branches :+ lit(value).expr) }
-      } else {
-        throw new IllegalArgumentException(
-          "otherwise() can only be applied once on a Column previously generated by when()")
-      }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches, Option(lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied once on a Column previously generated by when()")
     case _ =>
       throw new IllegalArgumentException(
         "otherwise() can only be applied on a Column previously generated by when()")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 71fea2716bd9f9ccc2a1f5840cc6f79c473658ac..b8ea2261e94e263644bedd819be37fd1930eae76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = withExpr {
-    CaseWhen(Seq(condition.expr, lit(value).expr))
+    CaseWhen(Seq((condition.expr, lit(value).expr)))
   }
 
   /**