diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ccb8245cc2e7dc31acd4f86d7f058d752d534274..e41fd2db7485814ca8a85d2f4546437871a6e3c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -104,8 +104,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
 object NullPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
-      case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
+      case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+      case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType)
       case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
       case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
       case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)