diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b49a4614eacab0f488e75c72f0da7aeeb5ae278e..c9024336889437004c375a813107c3a6c026f022 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -281,14 +281,17 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
   private val sum = MutableLiteral(zero.eval(EmptyRow))
   private val sumAsDouble = Cast(sum, DoubleType)
 
-  private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+  private def addFunction(value: Any) = Add(sum, Literal(value))
 
   override def eval(input: Row): Any =
     sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
 
   override def update(input: Row): Unit = {
-    count += 1
-    sum.update(addFunction, input)
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      count += 1
+      sum.update(addFunction(evaluatedExpr), input)
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 8197e8a18d447c4b1959cd0b0ce827898254403e..fb599e1e01e73cb5533aa3ef923a4cbc1d8095e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -115,6 +115,16 @@ class DslQuerySuite extends QueryTest {
       2.0)
   }
 
+  test("null average") {
+    checkAnswer(
+      testData3.groupBy()(Average('b)),
+      2.0)
+
+    checkAnswer(
+      testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
+      (2.0, 1) :: Nil)
+  }
+
   test("count") {
     assert(testData2.count() === testData2.map(_ => 1).count())
   }