diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 397eff05686b6e4a7e23ede2be5c7b9813c3f0bc..c0c960471a61add3d27185f2d74e57af7d989224 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       }
 
       // Setup unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
-      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
 
       // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
       val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
         case ((group, expressions), i) =>
           val id = Literal(i + 1)
@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
             val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrMap(x))
+              evalWithinGroup(id, distinctAggChildAttrLookup(x))
             }
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 6bf2c53440baf5739cee4c3449ddbdd18a4ec640..8253921563b3acac704cdab4cc3942d3473a176b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
   }
 }
 
+class LongProductSum extends UserDefinedAggregateFunction {
+  def inputSchema: StructType = new StructType()
+    .add("a", LongType)
+    .add("b", LongType)
+
+  def bufferSchema: StructType = new StructType()
+    .add("product", LongType)
+
+  def dataType: DataType = LongType
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    buffer(0) = 0L
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    if (!(input.isNullAt(0) || input.isNullAt(1))) {
+      buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
+    }
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+  }
+
+  def evaluate(buffer: Row): Any =
+    buffer.getLong(0)
+}
+
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     // Register UDAFs
     sqlContext.udf.register("mydoublesum", new MyDoubleSum)
     sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
+    sqlContext.udf.register("longProductSum", new LongProductSum)
   }
 
   override def afterAll(): Unit = {
@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
           |  count(distinct value2),
           |  sum(distinct value2),
           |  count(distinct value1, value2),
+          |  longProductSum(distinct value1, value2),
           |  count(value1),
           |  sum(value1),
           |  count(value2),
           |  sum(value2),
+          |  longProductSum(value1, value2),
           |  count(*),
           |  count(1)
           |FROM agg2
           |GROUP BY key
         """.stripMargin),
-      Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
-        Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
-        Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
-        Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
+      Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
   test("test count") {