diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 326933ec9e5ab646bb93140c017396eb340e92ed..a5ab390c76efeeb041cde229f5d4ed3197dac42d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * Null value propagation from bottom to top of the expression tree.
  */
 object NullPropagation extends Rule[LogicalPlan] {
-  def nonNullLiteral(e: Expression): Boolean = e match {
+  private def nonNullLiteral(e: Expression): Boolean = e match {
     case Literal(null, _) => false
     case _ => true
   }
@@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
  * Simplifies conditional expressions (if / case).
  */
 object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+  private def falseOrNullLiteral(e: Expression): Boolean = e match {
+    case FalseLiteral => true
+    case Literal(null, _) => true
+    case _ => false
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
+      case If(Literal(null, _), _, falseValue) => falseValue
 
-      case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+      case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
         // If there are no more branches left, just use the else value.
         // Note that these two are handled together here in a single case statement because
         // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
-        val newBranches = branches.filter(_._1 != FalseLiteral)
+        val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
         if (newBranches.isEmpty) {
           elseValue.getOrElse(Literal.create(null, e.dataType))
         } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index d436b627f6bd2e8af6dbce051e62c894b9205e8e..33239c00845ff5638d810639d89f67fa658894d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, NullType}
 
 
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   private val trueBranch = (TrueLiteral, Literal(5))
   private val normalBranch = (NonFoldableLiteral(true), Literal(10))
   private val unreachableBranch = (FalseLiteral, Literal(20))
+  private val nullBranch = (Literal(null, NullType), Literal(30))
 
   test("simplify if") {
     assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
     assertEquivalent(
       If(FalseLiteral, Literal(10), Literal(20)),
       Literal(20))
+
+    assertEquivalent(
+      If(Literal(null, NullType), Literal(10), Literal(20)),
+      Literal(20))
   }
 
   test("remove unreachable branches") {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(
-      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
       CaseWhen(normalBranch :: Nil, None))
   }
 
   test("remove entire CaseWhen if only the else branch is reachable") {
     assertEquivalent(
-      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
       Literal(30))
 
     assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
 
   test("remove entire CaseWhen if the first branch is always true") {
     assertEquivalent(
-      CaseWhen(trueBranch :: normalBranch :: Nil, None),
+      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
       Literal(5))
 
     // Test branch elimination and simplification in combination
     assertEquivalent(
-      CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
+        :: Nil, None),
       Literal(5))
 
     // Make sure this doesn't trigger if there is a non-foldable branch before the true branch