diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index bfe43373d9534f3a992fdf5f5e11ee5822d09d16..47305571e579e04515ee23c4dd909a99eeb88750 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction(
 
   private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
 
-  // Cast required to avoid type inference selecting a deprecated Hive API.
   private val buffer =
-    function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
+    function.getNewAggregationBuffer
 
   override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index cb405f56bf53db0c63df2aa1c438d1ef828a5719..d7c5d1a25a82b1e727bf5aab923adbd6d54c1303 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -22,7 +22,7 @@ import java.util
 import java.util.Properties
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF}
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
@@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest {
     sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
   }
 
+  test("SPARK-6409 UDAFAverage test") {
+    sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
+    checkAnswer(
+      sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),
+      Seq(Row(1.0, 260.182)))
+    sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
+    TestHive.reset()
+  }
+  
   test("SPARK-2693 udaf aggregates test") {
     checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
       sql("SELECT max(key) FROM src").collect().toSeq)