diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c2294a2c04acad94227de39cd3597fd399..c04bd6cd851877d49cf0f9ffab60f1543c404792 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
       }
     }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
-      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
       lexical.normalizeKeyword(udfName) match {
         case "sum" => SumDistinct(exprs.head)
         case "count" => CountDistinct(exprs)
+        case name => UnresolvedFunction(name, exprs, isDistinct = true)
-        case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
       }
     }
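
With the grammar change, a DISTINCT call on a function other than `sum`/`count` now survives parsing and carries the flag into analysis instead of failing. A minimal sketch (the function name `myAgg` is hypothetical):

```scala
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}

// myAgg(DISTINCT a)
val distinctCall = UnresolvedFunction("myAgg", UnresolvedAttribute("a") :: Nil, isDistinct = true)
// myAgg(a)
val plainCall = UnresolvedFunction("myAgg", UnresolvedAttribute("a") :: Nil, isDistinct = false)
```
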
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f64947f37980e3ed96cf76b4f6564c13071..8cadbc57e87e1b61480aa17061123a87757ede4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child.output, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child.output, resolver)
                 case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan =>
         q transformExpressions {
-          case u @ UnresolvedFunction(name, children) =>
+          case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
-              registry.lookupFunction(name, children)
+              registry.lookupFunction(name, children) match {
+                // We got an aggregate function built on the AggregateFunction2 interface,
+                // so we wrap it in an AggregateExpression2.
+                case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                // and COUNT(DISTINCT ...).
+                case sumDistinct: SumDistinct => sumDistinct
+                case countDistinct: CountDistinct => countDistinct
+                // DISTINCT is not meaningful with Max and Min.
+                case max: Max if isDistinct => max
+                case min: Min if isDistinct => min
+                // For other aggregate functions, the DISTINCT keyword is not supported for now.
+                // Once they are converted to the new code path, DISTINCT will be allowed.
+                case other if isDistinct =>
+                  failAnalysis(s"$name does not support DISTINCT keyword.")
+                // If it does not have the DISTINCT keyword, we return it as is.
+                case other => other
+              }
             }
         }
     }
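
For reference, the dispatch above in compact form (illustrative comments only; `ScalaUDAF` is the registry-provided case introduced later in this patch):

```scala
// Given what the registry returns for `name`:
//   AggregateFunction2 (e.g. a ScalaUDAF)  -> AggregateExpression2(f, Complete, isDistinct)
//   SumDistinct / CountDistinct            -> passed through (old-path DISTINCT support)
//   Max / Min with isDistinct = true       -> passed through (DISTINCT is a no-op there)
//   anything else with isDistinct = true   -> failAnalysis("<name> does not support DISTINCT keyword.")
//   anything else                          -> returned as is
```
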
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7f9713344c5054f5d7ad34b868703b18bb0e1a2..c203fcecf20fbd56205888a3a43d330b802646b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 0daee1990a6e0a04781b9ed769b94cf30f5fcb19..03da45b09f928360333677847936988d5a0f0713 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -73,7 +73,10 @@ object UnresolvedAttribute {
   def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
 }
 
-case class UnresolvedFunction(name: String, children: Seq[Expression])
+case class UnresolvedFunction(
+    name: String,
+    children: Seq[Expression],
+    isDistinct: Boolean)
   extends Expression with Unevaluable {
 
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index b09aea03318dabd6852caedb082f4d64b9f29a83..b10a3c877434b38c6233ac6f484730da646b4c24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {
 
-  override def toString: String = s"input[$ordinal]"
+  override def toString: String = s"input[$ordinal, $dataType]"
 
   override def eval(input: InternalRow): Any = input(ordinal)
 
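
A small sketch of the new `toString` (types chosen for illustration):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.IntegerType

// Before this patch the bound slot printed as "input[0]"; the data type is
// now included, which makes generated code and plan output easier to read.
assert(BoundReference(0, IntegerType, nullable = true).toString == "input[0, IntegerType]")
```
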
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index aada25276adb720a974e45fb0b4fa1f1f68c9a85..29ae47e842ddb46117686d499b377144ccf5bca3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
     val primitive = ctx.freshName("primitive")
     val ve = GeneratedExpressionCode("", isNull, primitive)
     ve.code = genCode(ctx, ve)
-    ve
+    // Prepend `this` to the generated code as a comment.
+    ve.copy(s"/* $this */\n" + ve.code)
   }
 
   /**
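
The effect is that codegen output becomes self-documenting; roughly (identifier names illustrative):

```scala
// Each generated snippet is now preceded by the expression it implements, e.g.
//
//   /* (input[0, IntegerType] + 1) */
//   boolean isNull1 = ...;
//   int primitive2 = ...;
```
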
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b924af4cc84d827348a159192a0da13434822518
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
+  // to the default data type of the NumericType.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentSum :: currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Cast(Literal(0), sumDataType),
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Add(
+      currentSum,
+      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentSum = */ currentSum.left + currentSum.right,
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  // If all inputs are null, currentCount will be 0 and we will get null after the division.
+  override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+}
+
+case class Count(child: Expression) extends AlgebraicAggregate {
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val currentCount = AttributeReference("currentCount", LongType)()
+
+  override val bufferAttributes = currentCount :: Nil
+
+  override val initialValues = Seq(
+    /* currentCount = */ Literal(0L)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+  )
+
+  override val mergeExpressions = Seq(
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  override val evaluateExpression = Cast(currentCount, LongType)
+}
+
+case class First(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // First is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val first = AttributeReference("first", child.dataType)()
+
+  override val bufferAttributes = first :: Nil
+
+  override val initialValues = Seq(
+    /* first = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* first = */ If(IsNull(first), child, first)
+  )
+
+  override val mergeExpressions = Seq(
+    /* first = */ If(IsNull(first.left), first.right, first.left)
+  )
+
+  override val evaluateExpression = first
+}
+
+case class Last(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Last is not a deterministic function.
+  override def deterministic: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val last = AttributeReference("last", child.dataType)()
+
+  override val bufferAttributes = last :: Nil
+
+  override val initialValues = Seq(
+    /* last = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* last = */ If(IsNull(child), last, child)
+  )
+
+  override val mergeExpressions = Seq(
+    /* last = */ If(IsNull(last.right), last.left, last.right)
+  )
+
+  override val evaluateExpression = last
+}
+
+case class Max(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val max = AttributeReference("max", child.dataType)()
+
+  override val bufferAttributes = max :: Nil
+
+  override val initialValues = Seq(
+    /* max = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
+  )
+
+  override val mergeExpressions = {
+    val greatest = Greatest(Seq(max.left, max.right))
+    Seq(
+      /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
+    )
+  }
+
+  override val evaluateExpression = max
+}
+
+case class Min(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = child.dataType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  private val min = AttributeReference("min", child.dataType)()
+
+  override val bufferAttributes = min :: Nil
+
+  override val initialValues = Seq(
+    /* min = */ Literal.create(null, child.dataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
+  )
+
+  override val mergeExpressions = {
+    val least = Least(Seq(min.left, min.right))
+    Seq(
+      /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
+    )
+  }
+
+  override val evaluateExpression = min
+}
+
+case class Sum(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+
+  private val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val sumDataType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => child.dataType
+  }
+
+  private val currentSum = AttributeReference("currentSum", sumDataType)()
+
+  private val zero = Cast(Literal(0), sumDataType)
+
+  override val bufferAttributes = currentSum :: Nil
+
+  override val initialValues = Seq(
+    /* currentSum = */ Literal.create(null, sumDataType)
+  )
+
+  override val updateExpressions = Seq(
+    /* currentSum = */
+    Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+  )
+
+  override val mergeExpressions = {
+    val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+    Seq(
+      /* currentSum = */
+      Coalesce(Seq(add, currentSum.left))
+    )
+  }
+
+  override val evaluateExpression = Cast(currentSum, resultType)
+}
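
A hand-traced sketch of `Average` over the inputs `1.0, null, 3.0` (values chosen for illustration) shows how the algebraic pieces fit together:

```scala
// Buffer = (currentSum, currentCount), initialized to (0.0, 0L).
//
//   update(1.0)  -> (1.0, 1L)  // Coalesce substitutes 0 only when the cast input is null
//   update(null) -> (1.0, 1L)  // sum gains 0; count is guarded by If(IsNull(child), ...)
//   update(3.0)  -> (4.0, 2L)
//
// merge adds buffers component-wise: (4.0, 2L) merge (6.0, 3L) -> (10.0, 5L)
// evaluate: Cast(currentSum) / Cast(currentCount) = 4.0 / 2 = 2.0
// If every input was null, currentCount stays 0 and the division yields null.
```
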
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
new file mode 100644
index 0000000000000000000000000000000000000000..577ede73cb01ff68c34d690f04c7bae8c24bd373
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/** The mode of an [[AggregateFunction2]]. */
+private[sql] sealed trait AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object Partial extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the aggregation buffer is returned.
+ */
+private[sql] case object PartialMerge extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function and then generate the final result.
+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
+ * When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Final extends AggregateMode
+
+/**
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
+ * from original input rows without any partial aggregation.
+ * This function updates the given aggregation buffer with the original input of this
+ * function. When it has processed all input rows, the final result of this function is returned.
+ */
+private[sql] case object Complete extends AggregateMode
+
+/**
+ * A placeholder expression used in code-gen; it does not change the corresponding value
+ * in the row.
+ */
+private[sql] case object NoOp extends Expression with Unevaluable {
+  override def nullable: Boolean = true
+  override def eval(input: InternalRow): Any = {
+    throw new TreeNodeException(
+      this, s"No function to evaluate expression. type: ${this.nodeName}")
+  }
+  override def dataType: DataType = NullType
+  override def children: Seq[Expression] = Nil
+}
+
+/**
+ * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * (`isDistinct`) indicating if the DISTINCT keyword is specified for this function.
+ * @param aggregateFunction the aggregate function to evaluate
+ * @param mode the [[AggregateMode]] in which this function is evaluated
+ * @param isDistinct whether the DISTINCT keyword was specified for this function
+ */
+private[sql] case class AggregateExpression2(
+    aggregateFunction: AggregateFunction2,
+    mode: AggregateMode,
+    isDistinct: Boolean) extends AggregateExpression {
+
+  override def children: Seq[Expression] = aggregateFunction :: Nil
+  override def dataType: DataType = aggregateFunction.dataType
+  override def foldable: Boolean = false
+  override def nullable: Boolean = aggregateFunction.nullable
+
+  override def references: AttributeSet = {
+    val childReferences = mode match {
+      case Partial | Complete => aggregateFunction.references.toSeq
+      case PartialMerge | Final => aggregateFunction.bufferAttributes
+    }
+
+    AttributeSet(childReferences)
+  }
+
+  override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
+}
+
+abstract class AggregateFunction2
+  extends Expression with ImplicitCastInputTypes {
+
+  self: Product =>
+
+  /** An aggregate function is not foldable. */
+  override def foldable: Boolean = false
+
+  /**
+   * The offset of this function's buffer in the underlying buffer shared with other functions.
+   */
+  var bufferOffset: Int = 0
+
+  /** The schema of the aggregation buffer. */
+  def bufferSchema: StructType
+
+  /** Attributes of fields in bufferSchema. */
+  def bufferAttributes: Seq[AttributeReference]
+
+  /** Clones bufferAttributes. */
+  def cloneBufferAttributes: Seq[Attribute]
+
+  /**
+   * Initializes its aggregation buffer located in `buffer`.
+   * It will use bufferOffset to find the starting point of
+   * its buffer in the given `buffer` shared with other functions.
+   */
+  def initialize(buffer: MutableRow): Unit
+
+  /**
+   * Updates its aggregation buffer located in `buffer` based on the given `input`.
+   * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
+   * shared with other functions.
+   */
+  def update(buffer: MutableRow, input: InternalRow): Unit
+
+  /**
+   * Updates its aggregation buffer located in `buffer1` by combining intermediate results
+   * in the current buffer and intermediate results from another buffer `buffer2`.
+   * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
+   * and `buffer2`.
+   */
+  def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
+/**
+ * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ */
+abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
+  self: Product =>
+
+  val initialValues: Seq[Expression]
+  val updateExpressions: Seq[Expression]
+  val mergeExpressions: Seq[Expression]
+  val evaluateExpression: Expression
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  /**
+   * A helper class for representing an attribute used in merging two
+   * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`,
+   * we merge buffer values and then update bufferLeft. A [[RichAttribute]]
+   * of an [[AttributeReference]] `a` has two functions `left` and `right`,
+   * which represent `a` in `bufferLeft` and `bufferRight`, respectively.
+   * @param a the underlying buffer attribute
+   */
+  implicit class RichAttribute(a: AttributeReference) {
+    /** Represents this attribute at the mutable buffer side. */
+    def left: AttributeReference = a
+
+    /** Represents this attribute at the input buffer side (the data value is read-only). */
+    def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
+  }
+
+  /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
+  override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    var i = 0
+    while (i < bufferAttributes.size) {
+      buffer(i + bufferOffset) = initialValues(i).eval()
+      i += 1
+    }
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's update should not be called directly")
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's merge should not be called directly")
+  }
+
+  override def eval(buffer: InternalRow): Any = {
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's eval should not be called directly")
+  }
+}
+
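To illustrate the contract above, here is a minimal hypothetical `AlgebraicAggregate` (not part of this patch): a sum over longs that treats null input as 0. `SimpleLongSum` is invented for illustration; note how `left`/`right` from `RichAttribute` name the two sides of a merge:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AlgebraicAggregate
import org.apache.spark.sql.types._

// Hypothetical example, for illustration only.
case class SimpleLongSum(child: Expression) extends AlgebraicAggregate {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)

  private val sum = AttributeReference("sum", LongType)()

  override val bufferAttributes = sum :: Nil
  override val initialValues = Seq(Literal(0L))
  // Treat null inputs as 0 so the buffer itself never goes null.
  override val updateExpressions = Seq(sum + Coalesce(Seq(child, Literal(0L))))
  // sum.left is this buffer's slot; sum.right is the incoming buffer's slot.
  override val mergeExpressions = Seq(sum.left + sum.right)
  override val evaluateExpression = sum
}
```
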
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index d705a1286065cd0262395431b1712339dee3e8b8..e07c920a41d0ace78a286e257fc7a1454cbc2be7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -27,7 +27,9 @@ import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
 
 
-trait AggregateExpression extends Expression with Unevaluable {
+trait AggregateExpression extends Expression with Unevaluable
+
+trait AggregateExpression1 extends AggregateExpression {
 
   /**
    * Aggregate expressions should not be foldable.
@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable {
    * Creates a new instance that can be used to compute this aggregate expression for a group
   * of input rows.
    */
-  def newInstance(): AggregateFunction
+  def newInstance(): AggregateFunction1
 }
 
 /**
@@ -54,10 +56,10 @@ case class SplitEvaluation(
     partialEvaluations: Seq[NamedExpression])
 
 /**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
+ * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
  * These partial evaluations can then be combined to compute the actual answer.
  */
-trait PartialAggregate extends AggregateExpression {
+trait PartialAggregate1 extends AggregateExpression1 {
 
   /**
    * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
@@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression {
 
 /**
  * A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
+ * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
  */
-abstract class AggregateFunction
-  extends LeafExpression with AggregateExpression with Serializable {
+abstract class AggregateFunction1
+  extends LeafExpression with AggregateExpression1 with Serializable {
 
   /** Base should return the generic aggregate expression that this function is computing */
-  val base: AggregateExpression
+  val base: AggregateExpression1
 
   override def nullable: Boolean = base.nullable
   override def dataType: DataType = base.dataType
@@ -81,12 +83,12 @@ abstract class AggregateFunction
   def update(input: InternalRow): Unit
 
   // Do we really need this?
-  override def newInstance(): AggregateFunction = {
+  override def newInstance(): AggregateFunction1 = {
     makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
   }
 }
 
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate
     TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 }
 
-case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: InternalRow): Any = currentMin.value
 }
 
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate
     TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 }
 
-case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
@@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: InternalRow): Any = currentMax.value
 }
 
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat
   override def newInstance(): CountFunction = new CountFunction(child, this)
 }
 
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var count: Long = _
@@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: InternalRow): Any = count
 }
 
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
+case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = expressions
@@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
 
 case class CountDistinctFunction(
     @transient expr: Seq[Expression],
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -220,7 +222,7 @@ case class CountDistinctFunction(
   override def eval(input: InternalRow): Any = seen.size.toLong
 }
 
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
+case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = expressions
@@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
 
 case class CollectHashSetFunction(
     @transient expr: Seq[Expression],
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -255,7 +257,7 @@ case class CollectHashSetFunction(
   }
 }
 
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
+case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = inputSet :: Nil
@@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
 
 case class CombineSetsAndCountFunction(
     @transient inputSet: Expression,
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
 }
 
 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression {
+  extends UnaryExpression with AggregateExpression1 {
 
   override def nullable: Boolean = false
   override def dataType: DataType = HyperLogLogUDT
@@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
 
 case class ApproxCountDistinctPartitionFunction(
     expr: Expression,
-    base: AggregateExpression,
+    base: AggregateExpression1,
     relativeSD: Double)
-  extends AggregateFunction {
+  extends AggregateFunction1 {
   def this() = this(null, null, 0) // Required for serialization.
 
   private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction(
 }
 
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression {
+  extends UnaryExpression with AggregateExpression1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
 
 case class ApproxCountDistinctMergeFunction(
     expr: Expression,
-    base: AggregateExpression,
+    base: AggregateExpression1,
     relativeSD: Double)
-  extends AggregateFunction {
+  extends AggregateFunction1 {
   def this() = this(null, null, 0) // Required for serialization.
 
   private val hyperLogLog = new HyperLogLog(relativeSD)
@@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction(
 }
 
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
-  extends UnaryExpression with PartialAggregate {
+  extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = false
   override def dataType: LongType.type = LongType
@@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
   override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
 }
 
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   override def prettyName: String = "avg"
 
@@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
     TypeUtils.checkForNumericExpr(child.dataType, "function average")
 }
 
-case class AverageFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class AverageFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
   }
 }
 
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   override def nullable: Boolean = true
 
@@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
 }
 
-case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   private val calcType =
@@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
  *          <-- null         <-- no data
  * null     <-- null         <-- no data
  */
-case class CombineSum(child: Expression) extends AggregateExpression {
+case class CombineSum(child: Expression) extends AggregateExpression1 {
   def this() = this(null)
 
   override def children: Seq[Expression] = child :: Nil
@@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression {
   override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
 }
 
-case class CombineSumFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
   }
 }
 
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {
+case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   def this() = this(null)
   override def nullable: Boolean = true
@@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
     TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
 }
 
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
   }
 }
 
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
   def this() = this(null, null)
 
   override def children: Seq[Expression] = inputSet :: Nil
@@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg
 
 case class CombineSetsAndSumFunction(
     @transient inputSet: Expression,
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
+    @transient base: AggregateExpression1)
+  extends AggregateFunction1 {
 
   def this() = this(null, null) // Required for serialization.
 
@@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction(
   }
 }
 
-case class First(child: Expression) extends UnaryExpression with PartialAggregate {
+case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
   override def toString: String = s"FIRST($child)"
@@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat
   override def newInstance(): FirstFunction = new FirstFunction(child, this)
 }
 
-case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var result: Any = null
@@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: InternalRow): Any = result
 }
 
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
+case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
   override def references: AttributeSet = child.references
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
@@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate
   override def newInstance(): LastFunction = new LastFunction(child, this)
 }
 
-case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
   def this() = this(null, null) // Required for serialization.
 
   var result: Any = null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 03b4b3c216f496492b52722c5f2200be124a4527..d838268f46956b176775774fe973adb462d26195 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
 
   protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
     val ctx = newCodeGenContext()
-    val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
-      val evaluationCode = e.gen(ctx)
-      evaluationCode.code +
-        s"""
-          if(${evaluationCode.isNull})
-            mutableRow.setNullAt($i);
-          else
-            ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
-        """
+    val projectionCode = expressions.zipWithIndex.map {
+      case (NoOp, _) => ""
+      case (e, i) =>
+        val evaluationCode = e.gen(ctx)
+        evaluationCode.code +
+          s"""
+            if(${evaluationCode.isNull})
+              mutableRow.setNullAt($i);
+            else
+              ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
+          """
     }
     // collect projections into blocks as function has 64kb codesize limit in JVM
     val projectionBlocks = new ArrayBuffer[String]()
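
The `NoOp` case makes codegen skip a slot entirely, which allows one shared mutable buffer to be updated by several aggregate functions without clobbering each other's columns. A sketch of the effect:

```scala
// For expressions Seq(a + 1, NoOp) over a two-column mutableRow, codegen emits
// an assignment for column 0 only; column 1 keeps the value it already holds.
```
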
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 179a348d5baac53df9fe97b04f25bcf8f15fb07b..b8e3b0d53a50599f343616b56b75fe7a65fcff4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,10 +129,10 @@ object PartialAggregation {
     case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
       // Collect all aggregate expressions.
       val allAggregates =
-        aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+        aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
       // Collect all aggregate expressions that can be computed partially.
       val partialAggregates =
-        aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+        aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
 
       // Only do partial aggregation if supported by all aggregate expressions.
       if (allAggregates.size == partialAggregates.size) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 986c315b3173a2068b7cf4fec258fe15fe46f777..6aefa9f67556a4ed5091e2d3af1ff2c673327b40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 78c780bdc5797742553e4744bc066311550f43e5..1474b170ba896718cd1ede862638d1342f42ab44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -402,6 +402,9 @@ private[spark] object SQLConf {
     defaultValue = Some(true),
     isPublic = false)
 
+  val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
+    defaultValue = Some(true), doc = "<TODO>")
+
   val USE_SQL_SERIALIZER2 = booleanConf(
     "spark.sql.useSerializer2",
     defaultValue = Some(true), isPublic = false)
@@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
 
+  private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
+
   private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
 
   private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
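
The flag defaults to `true`; it can be flipped at runtime to fall back to the old aggregation code path. An illustrative usage:

```scala
// Illustrative: disable the AggregateFunction2-based code path.
sqlContext.setConf("spark.sql.useAggregate2", "false")
// or, from SQL:
//   SET spark.sql.useAggregate2=false
```
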
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8b4528b5d52fed4c093ac55cf098af628008ba0f..49bfe74b680af1ffa9a26112813b7a6c6d8127a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   val udf: UDFRegistration = new UDFRegistration(this)
 
+  @transient
+  val udaf: UDAFRegistration = new UDAFRegistration(this)
+
   /**
    * Returns true if the table is currently cached in-memory.
    * @group cachemgmt
@@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       DDLStrategy ::
       TakeOrderedAndProject ::
       HashAggregation ::
+      Aggregation ::
       LeftSemiJoin ::
       HashJoin ::
       InMemoryScans ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5b872f5e3eecdcaa585257ae8d52a2f6de1d7d9d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+
+class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
+
+  private val functionRegistry = sqlContext.functionRegistry
+
+  def register(
+      name: String,
+      func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+    def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
+    functionRegistry.registerFunction(name, builder)
+    func
+  }
+}
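
A usage sketch for the new hook (`MyAverage`, `my_avg`, and the table are hypothetical):

```scala
// Hypothetical usage; MyAverage stands in for a concrete
// UserDefinedAggregateFunction implementation.
val myAvg = sqlContext.udaf.register("my_avg", new MyAverage)
sqlContext.sql("SELECT key, my_avg(value) FROM records GROUP BY key")
```
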
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 3cd60a2aa55edd4a7db6b594969e906944d21f78..c2c945321db954efe40e1ab4745be74fc433761a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -68,14 +68,14 @@ case class Aggregate(
    *                        output.
    */
   case class ComputedAggregate(
-      unbound: AggregateExpression,
-      aggregate: AggregateExpression,
+      unbound: AggregateExpression1,
+      aggregate: AggregateExpression1,
       resultAttribute: AttributeReference)
 
   /** A list of aggregates that need to be computed for each group. */
   private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
     agg.collect {
-      case a: AggregateExpression =>
+      case a: AggregateExpression1 =>
         ComputedAggregate(
           a,
           BindReferences.bindReference(a, child.output),
@@ -87,8 +87,8 @@ case class Aggregate(
   private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
 
   /** Creates a new aggregate buffer for a group. */
-  private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
-    val buffer = new Array[AggregateFunction](computedAggregates.length)
+  private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
+    val buffer = new Array[AggregateFunction1](computedAggregates.length)
     var i = 0
     while (i < computedAggregates.length) {
       buffer(i) = computedAggregates(i).aggregate.newInstance()
@@ -146,7 +146,7 @@ case class Aggregate(
       }
     } else {
       child.execute().mapPartitions { iter =>
-        val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
+        val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
         val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
 
         var currentRow: InternalRow = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 2750053594f990fd9aeefeb9b9b2f0caa5676df3..d31e265a293e9ad5d1507ee269da3805b53c2945 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
         }
 
         def addSortIfNecessary(child: SparkPlan): SparkPlan = {
-          if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
-            sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+          if (rowOrdering.nonEmpty) {
+            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to
+            // sort: the required ordering must be a prefix of the child's output ordering.
+            if (child.outputOrdering.size >= rowOrdering.size &&
+                child.outputOrdering.take(rowOrdering.size) == rowOrdering) {
+              child
+            } else {
+              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+            }
           } else {
             child
           }
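
Concretely, under the prefix rule:

```scala
//   child sorted by [a, b], required [a]    -> satisfied, no extra sort
//   child sorted by [a],    required [a, b] -> not satisfied, sort on [a, b]
//   child unsorted,         required [a]    -> sort on [a]
```
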
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ecde9c57139a68670b437760fa21f9dbb83cd79c..0e63f2fe29cb3e90baac8aeea6e0c6e829536f84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -69,7 +69,7 @@ case class GeneratedAggregate(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val aggregatesToCompute = aggregateExpressions.flatMap { a =>
-      a.collect { case agg: AggregateExpression => agg}
+      a.collect { case agg: AggregateExpression1 => agg}
     }
 
     // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8cef7f200d2dcc48e48b2e891411ceb6c7d4443d..f54aa2027f6a6804c6781ccdf178806af7371189 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{SQLContext, Strategy, execution}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              if canBeCodeGened(
                   allAggregates(partialComputation) ++
                   allAggregates(rewrittenAggregateExpressions)) &&
-               codegenEnabled =>
+               codegenEnabled &&
+               !canBeConvertedToNewAggregation(plan) =>
           execution.GeneratedAggregate(
             partial = false,
             namedGroupingAttributes,
@@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              rewrittenAggregateExpressions,
              groupingExpressions,
              partialComputation,
-             child) =>
+             child) if !canBeConvertedToNewAggregation(plan) =>
         execution.Aggregate(
           partial = false,
           namedGroupingAttributes,
@@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => Nil
     }
 
-    def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
+    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+      aggregate.Utils.tryConvert(
+        plan,
+        sqlContext.conf.useSqlAggregate2,
+        sqlContext.conf.codegenEnabled).isDefined
+    }
+
+    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
       case _: CombineSum | _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => false
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
@@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => true
     }
 
-    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] =
-      exprs.flatMap(_.collect { case a: AggregateExpression => a })
+    def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
+      exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
   }
 
+  /**
+   * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
+   */
+  object Aggregation extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case p: logical.Aggregate =>
+        val converted =
+          aggregate.Utils.tryConvert(
+            p,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled)
+        converted match {
+          case None => Nil // Cannot convert to new aggregation code path.
+          case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+            // Extract the deduplicated set of aggregate expressions from resultExpressions.
+            val aggregateExpressions = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression2 => agg
+              }
+            }.toSet.toSeq
+            // For those distinct aggregate expressions, we create a map from the
+            // aggregate function to the corresponding attribute of the function.
+            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+              val aggregateFunction = agg.aggregateFunction
+              (aggregateFunction, agg.isDistinct) ->
+                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+            }.toMap
+
+            val (functionsWithDistinct, functionsWithoutDistinct) =
+              aggregateExpressions.partition(_.isDistinct)
+            if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+              // This is a sanity check. We should not reach here when we have multiple distinct
+              // column sets (aggregate.Utils.tryConvert will return None).
+              sys.error(
+                "Multiple distinct column sets are not supported by the new aggregation " +
+                  "code path.")
+            }
+
+            val aggregateOperator =
+              if (functionsWithDistinct.isEmpty) {
+                aggregate.Utils.planAggregateWithoutDistinct(
+                  groupingExpressions,
+                  aggregateExpressions,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              } else {
+                aggregate.Utils.planAggregateWithOneDistinct(
+                  groupingExpressions,
+                  functionsWithDistinct,
+                  functionsWithoutDistinct,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              }
+
+            aggregateOperator
+        }
+
+      case _ => Nil
+    }
+  }
+
   object BroadcastNestedLoopJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+      case a @ logical.Aggregate(group, agg, child) =>
+        val useNewAggregation =
+          aggregate.Utils.tryConvert(
+            a,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled).isDefined
+        if (useNewAggregation) {
+          // If this logical.Aggregate can be planned with the new aggregation code
+          // path (i.e. it can be planned by the Aggregation strategy), we do not use
+          // the old aggregation code path.
+          Nil
+        } else {
+          execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+        }
       case logical.Window(projectList, windowExpressions, spec, child) =>
         execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0c9082897f390e7519989215a010428afff0d56e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+
+case class Aggregate2Sort(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def references: AttributeSet = {
+    val referencesInResults =
+      AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
+
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++
+      aggregateExpressions.flatMap(_.references) ++
+      referencesInResults)
+  }
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    // TODO: We should not sort the input rows if they are just in reversed order.
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    // It is possible that the child.outputOrdering starts with the required
+    // ordering expressions (e.g. we require [a] as the sort expression and the
+    // child's outputOrdering is [a, b]). We can only guarantee the output rows
+    // are sorted by values of groupingExpressions.
+    groupingExpressions.map(SortOrder(_, Ascending))
+  }
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      if (aggregateExpressions.isEmpty) {
+        new GroupingIterator(
+          groupingExpressions,
+          resultExpressions,
+          newMutableProjection,
+          child.output,
+          iter)
+      } else {
+        val aggregationIterator: SortAggregationIterator = {
+          aggregateExpressions.map(_.mode).distinct.toList match {
+            case Partial :: Nil =>
+              new PartialSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case PartialMerge :: Nil =>
+              new PartialMergeSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case Final :: Nil =>
+              new FinalSortAggregationIterator(
+                groupingExpressions,
+                aggregateExpressions,
+                aggregateAttributes,
+                resultExpressions,
+                newMutableProjection,
+                child.output,
+                iter)
+            case other =>
+              sys.error(
+                s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
+                  s"modes $other in this operator.")
+          }
+        }
+
+        aggregationIterator
+      }
+    }
+  }
+}
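+
+// For orientation, an illustrative sketch (names and attributes are hypothetical) of how
+// the planner stacks two of these operators for SELECT key, SUM(value) FROM t GROUP BY key
+// (see aggregate.Utils.planAggregateWithoutDistinct):
+//
+//   Aggregate2Sort(Some(key :: Nil), key :: Nil,
+//     AggregateExpression2(Sum(value), Final, isDistinct = false) :: Nil, ...)
+//     Aggregate2Sort(None, key :: Nil,
+//       AggregateExpression2(Sum(value), Partial, isDistinct = false) :: Nil, ...)
+//       <child>
+//
+// requiredChildOrdering above makes the planner add a sort by `key` below each operator
+// whenever the child's output ordering does not already satisfy it.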
+
+case class FinalAndCompleteAggregate2Sort(
+    previousGroupingExpressions: Seq[NamedExpression],
+    groupingExpressions: Seq[NamedExpression],
+    finalAggregateExpressions: Seq[AggregateExpression2],
+    finalAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+  override def references: AttributeSet = {
+    val referencesInResults =
+      AttributeSet(resultExpressions.flatMap(_.references)) --
+        AttributeSet(finalAggregateExpressions) --
+        AttributeSet(completeAggregateExpressions)
+
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++
+        finalAggregateExpressions.flatMap(_.references) ++
+        completeAggregateExpressions.flatMap(_.references) ++
+        referencesInResults)
+  }
+
+  override def requiredChildDistribution: List[Distribution] = {
+    if (groupingExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(groupingExpressions) :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+
+      new FinalAndCompleteSortAggregationIterator(
+        previousGroupingExpressions.length,
+        groupingExpressions,
+        finalAggregateExpressions,
+        finalAggregateAttributes,
+        completeAggregateExpressions,
+        completeAggregateAttributes,
+        resultExpressions,
+        newMutableProjection,
+        child.output,
+        iter)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ce1cbdc9cb090333e3d733daa4738334cbd11e86
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -0,0 +1,749 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.NullType
+
+/**
+ * An iterator used to evaluate aggregate functions. It assumes that input rows
+ * are already grouped by values of `groupingExpressions`.
+ */
+private[sql] abstract class SortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends Iterator[InternalRow] {
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Static fields for this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected val aggregateFunctions: Array[AggregateFunction2] = {
+    var bufferOffset = initialBufferOffset
+    val functions = new Array[AggregateFunction2](aggregateExpressions.length)
+    var i = 0
+    while (i < aggregateExpressions.length) {
+      val func = aggregateExpressions(i).aggregateFunction
+      val funcWithBoundReferences = aggregateExpressions(i).mode match {
+        case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // AlgebraicAggregate (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, inputAttributes)
+        case _ => func
+      }
+      // Set bufferOffset for this function. It is important that setting bufferOffset
+      // happens after all potential bindReference operations because bindReference
+      // will create a new instance of the function.
+      funcWithBoundReferences.bufferOffset = bufferOffset
+      bufferOffset += funcWithBoundReferences.bufferSchema.length
+      functions(i) = funcWithBoundReferences
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions.
+  protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    aggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // Positions of those non-algebraic aggregate functions in aggregateFunctions.
+  // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+  // func2 and func3 are non-algebraic aggregate functions.
+  // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+  protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      aggregateFunctions(i) match {
+        case _: AlgebraicAggregate => // Nothing to record for algebraic aggregates.
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
+  }
+
+  // This projection is used to compute the grouping key of an input row.
+  protected val groupGenerator =
+    newMutableProjection(groupingExpressions, inputAttributes)()
+
+  // The underlying buffer shared by all aggregate functions.
+  protected val buffer: MutableRow = {
+    // The number of elements of the underlying buffer of this operator.
+    // All aggregate functions share this underlying buffer and find their
+    // buffer values through bufferOffset.
+    var size = initialBufferOffset
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      size += aggregateFunctions(i).bufferSchema.length
+      i += 1
+    }
+    new GenericMutableRow(size)
+  }
+
+  protected val joinedRow = new JoinedRow4
+
+  protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+  // This projection is used to initialize buffer values for all AlgebraicAggregates.
+  protected val algebraicInitialProjection = {
+    val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(initExpressions, Nil)().target(buffer)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The grouping key of the current group.
+  protected var currentGroupingKey: InternalRow = _
+  // The grouping key of the next group.
+  protected var nextGroupingKey: InternalRow = _
+  // The first row of the next group.
+  protected var firstRowInNextGroup: InternalRow = _
+  // Indicates whether we have a new group of rows to process.
+  protected var hasNewGroup: Boolean = true
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Private methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Initializes buffer values for all aggregate functions. */
+  protected def initializeBuffer(): Unit = {
+    algebraicInitialProjection(EmptyRow)
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).initialize(buffer)
+      i += 1
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      // The input iterator is empty.
+      hasNewGroup = false
+    }
+  }
+
+  /** Processes rows in the current group. It stops when it finds a new group. */
+  private def processCurrentGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we start to find all rows belonging to this group.
+    // We use a variable to track whether we have seen the next group.
+    var foundNextGroup = false
+    // firstRowInNextGroup is the first row of this group. We process it first.
+    processRow(firstRowInNextGroup)
+    // The search stops when we see the next group or there are no input rows
+    // left in the iterator.
+    while (inputIter.hasNext && !foundNextGroup) {
+      val currentRow = inputIter.next()
+      // Get the grouping key based on the grouping expressions.
+      // For the equality check below, we do not need to make a copy of groupingKey.
+      val groupingKey = groupGenerator(currentRow)
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(currentRow)
+      } else {
+        // We have found a new group.
+        foundNextGroup = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // If we have not seen a new group, there are no input rows left and the
+    // current group is the last group of the iterator.
+    if (!foundNextGroup) {
+      hasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = hasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentGroup()
+      // Generate output row for the current group.
+      val outputRow = generateOutput()
+      // Initialize buffer values for the next group.
+      initializeBuffer()
+
+      outputRow
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods that need to be implemented
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected def initialBufferOffset: Int
+
+  protected def processRow(row: InternalRow): Unit
+
+  protected def generateOutput(): InternalRow
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initialize this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  initialize()
+}
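+
+// A simplified, self-contained model of the grouping loop implemented above (an
+// illustrative sketch only; the real iterator shares one mutable buffer across all
+// aggregate functions instead of summing a single column):
+//
+//   def sumPerGroup(sorted: Iterator[(String, Int)]): Iterator[(String, Int)] =
+//     new Iterator[(String, Int)] {
+//       // Plays the role of (nextGroupingKey, firstRowInNextGroup).
+//       private var pending = if (sorted.hasNext) Some(sorted.next()) else None
+//       override def hasNext: Boolean = pending.isDefined
+//       override def next(): (String, Int) = {
+//         val (key, first) = pending.getOrElse(throw new NoSuchElementException)
+//         var sum = first
+//         pending = None
+//         // Consume rows until the grouping key changes; the changed row becomes
+//         // the first row of the next group.
+//         while (sorted.hasNext && pending.isEmpty) {
+//           val (k, v) = sorted.next()
+//           if (k == key) sum += v else pending = Some((k, v))
+//         }
+//         (key, sum)
+//       }
+//     }
+//
+//   sumPerGroup(Iterator(("a", 1), ("a", 2), ("b", 3))).toList
+//   // => List(("a", 3), ("b", 3))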
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+    groupingExpressions: Seq[NamedExpression],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    Nil,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val resultProjection =
+    newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Since we only do grouping, there is nothing to do here.
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    resultProjection(currentGroupingKey)
+  }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // This projection is used to update buffer values for all AlgebraicAggregates.
+  private val algebraicUpdateProjection = {
+    val bufferSchema = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.bufferAttributes
+      case agg: AggregateFunction2 => agg.bufferAttributes
+    }
+    val updateExpressions = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.updateExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicUpdateProjection(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We just output the grouping expressions and the underlying buffer.
+    joinedRow(currentGroupingKey, buffer).copy()
+  }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer the same
+ * length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
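+ *
+ * For example (an illustrative sketch), with one grouping column `key` and a single
+ * SUM buffer:
+ * input row      : |key|sumBuffer|
+ * internal buffer: |placeholder|sumBuffer|
+ * The placeholder occupies the slot that `key` occupies in an input row, so the merge
+ * projection can read the buffer and an input row through one common schema.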
+ */
+class PartialMergeSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val placeholderAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to extract aggregation buffers from the underlying buffer.
+  // We need it because the underlying buffer has placeholders at its beginning.
+  private val extractBufferValues = {
+    val expressions = aggregateFunctions.flatMap(_.bufferAttributes)
+    newMutableProjection(expressions, inputAttributes)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We output grouping expressions and aggregation buffers.
+    joinedRow(currentGroupingKey, extractBufferValues(buffer))
+  }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer the same
+ * length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
+  private val resultProjection =
+    newMutableProjection(
+      resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to evaluate all AlgebraicAggregates.
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial values. Because merging
+        // two buffers that hold initial values produces a buffer that still holds
+        // initial values, we can use a copy of the current buffer as the current row.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
+
+/**
+ * An iterator used to do both final aggregations (for those aggregate functions with mode
+ * Final) and complete aggregations (for those aggregate functions with mode Complete).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
+ * col1 to colM are columns used by aggregate functions with Complete mode.
+ * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
+ * Final mode.
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
+ * The first N placeholders represent the slots of the grouping expressions.
+ * The next M placeholders represent the slots of col1 to colM.
+ * Of the aggregation buffers, the first N are used by the N aggregate functions with
+ * mode Final, and the last M are used by the M aggregate functions with mode
+ * Complete. The reason we have placeholders here is to make the underlying buffer
+ * the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
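+ *
+ * For example (an illustrative sketch), with one function with mode Final (a SUM) and
+ * one function with mode Complete (an AVG over an already-deduplicated column `value`):
+ * input row      : |key|value|sumBuffer|
+ * internal buffer: |placeholder(key)|placeholder(value)|sumBuffer|avgSum|avgCount|
+ * processRow merges sumBuffer from the input row (Final) and updates avgSum and
+ * avgCount from `value` (Complete).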
+ */
+class FinalAndCompleteSortAggregationIterator(
+    override protected val initialBufferOffset: Int,
+    groupingExpressions: Seq[NamedExpression],
+    finalAggregateExpressions: Seq[AggregateExpression2],
+    finalAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    // TODO: document the ordering
+    finalAggregateExpressions ++ completeAggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow =
+    new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
+  private val resultProjection = {
+    val inputSchema =
+      groupingExpressions.map(_.toAttribute) ++
+        finalAggregateAttributes ++
+        completeAggregateAttributes
+    newMutableProjection(resultExpressions, inputSchema)()
+  }
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // All aggregate functions with mode Final.
+  private val finalAggregateFunctions: Array[AggregateFunction2] = {
+    val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
+    var i = 0
+    while (i < finalAggregateExpressions.length) {
+      functions(i) = aggregateFunctions(i)
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions with mode Final.
+  private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    finalAggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // All aggregate functions with mode Complete.
+  private val completeAggregateFunctions: Array[AggregateFunction2] = {
+    val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
+    var i = 0
+    while (i < completeAggregateExpressions.length) {
+      functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions with mode Complete.
+  private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    completeAggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates with mode
+  // Final.
+  private val finalAlgebraicMergeProjection = {
+    val numCompleteOffsetAttributes =
+      completeAggregateFunctions.map(_.bufferAttributes.length).sum
+    val completeOffsetAttributes =
+      Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
+    val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
+
+    val bufferSchemata =
+      offsetAttributes ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      } ++ completeOffsetAttributes
+    val mergeExpressions =
+      placeholderExpressions ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.mergeExpressions
+        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+      } ++ completeOffsetExpressions
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to update buffer values for all AlgebraicAggregates with mode
+  // Complete.
+  private val completeAlgebraicUpdateProjection = {
+    val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
+    val finalOffsetAttributes =
+      Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
+    val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+
+    val bufferSchema =
+      offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      }
+    val updateExpressions =
+      placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.updateExpressions
+        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+      }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  // This projection is used to evaluate all AlgebraicAggregates.
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial values. Because merging
+        // two buffers that hold initial values produces a buffer that still holds
+        // initial values, we can use a copy of the current buffer as the current row.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    val input = joinedRow(buffer, row)
+    // For all aggregate functions with mode Complete, update buffers.
+    completeAlgebraicUpdateProjection(input)
+    var i = 0
+    while (i < completeNonAlgebraicAggregateFunctions.length) {
+      completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+
+    // For all aggregate functions with mode Final, merge buffers.
+    finalAlgebraicMergeProjection.target(buffer)(input)
+    i = 0
+    while (i < finalNonAlgebraicAggregateFunctions.length) {
+      finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1cb27710e048092a7c3920dc1b81786aa346e4b3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert a plan to the new aggregation code path.
+ */
+object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case _: ArrayType | _: MapType | _: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
+  private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate if supportsGroupingKeySchema(p) =>
+      val converted = p.transformExpressionsDown {
+        case expressions.Average(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Average(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Count(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        // We do not support multiple COUNT DISTINCT columns for now.
+        case expressions.CountDistinct(children) if children.length == 1 =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(children.head),
+            mode = aggregate.Complete,
+            isDistinct = true)
+
+        case expressions.First(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.First(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Last(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Last(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Max(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Max(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Min(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Min(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Sum(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.SumDistinct(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = true)
+      }
+      // Check if there is any expressions.AggregateExpression1 left.
+      // If so, we cannot convert this plan.
+      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+        // For every expression, check whether it contains an AggregateExpression1.
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct column sets.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
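+
+  // An illustrative example of the conversion above: an old-interface expression such as
+  //   expressions.SumDistinct(value)
+  // is rewritten to
+  //   AggregateExpression2(aggregate.Sum(value), Complete, isDistinct = true)
+  // and a plan whose aggregate expressions still contain an AggregateExpression1 after
+  // the rewrite, or that uses multiple distinct column sets, makes tryConvert return None.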
+
+  private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we do a final round of checks to see if the
+    // original logical.Aggregate contains both AggregateExpression1 and
+    // AggregateExpression2. If so, we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and it cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(
+      plan: LogicalPlan,
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
+      val converted = tryConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case p: Aggregate =>
+      checkInvalidAggregateFunction2(p)
+      None
+    case other => None
+  }
+
+  def planAggregateWithoutDistinct(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+    // 1. Create an Aggregate Operator for partial aggregations.
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Partial, isDistinct)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        namedGroupingExpressions.map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
+    val finalAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Final, isDistinct)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAggregate = Aggregate2Sort(
+      Some(namedGroupingAttributes),
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      rewrittenResultExpressions,
+      partialAggregate)
+
+    finalAggregate :: Nil
+  }
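+
+  // A worked example of the rewrite above (attribute names are illustrative): for
+  //   SELECT key + 1, SUM(value) + 1 FROM t GROUP BY key + 1
+  // namedGroupingExpressions maps (key + 1) to Alias(key + 1, "(key + 1)") and
+  // aggregateFunctionMap maps (Sum(value), isDistinct = false) to an attribute sum(value),
+  // so rewrittenResultExpressions becomes
+  //   [Alias(attribute of (key + 1)), Alias(attribute of sum(value) + 1)],
+  // i.e. grouping expressions (matched via semanticEquals) and aggregate expressions
+  // are both replaced by attributes produced by the operators below them.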
+
+  def planAggregateWithOneDistinct(
+      groupingExpressions: Seq[Expression],
+      functionsWithDistinct: Seq[AggregateExpression2],
+      functionsWithoutDistinct: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+    // The grouping expressions are original groupingExpressions and
+    // distinct columns. For example, for avg(distinct value) ... group by key
+    // the grouping expressions of this Aggregate Operator will be [key, value].
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    // It is safe to call head here since functionsWithDistinct has at least one
+    // AggregateExpression2.
+    val distinctColumnExpressions =
+      functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+    val partialAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Partial, false)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, PartialMerge, false)
+    }
+    val partialMergeAggregateAttributes =
+      partialMergeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val partialMergeAggregate =
+      Aggregate2Sort(
+        Some(namedGroupingAttributes),
+        namedGroupingAttributes ++ distinctColumnAttributes,
+        partialMergeAggregateExpressions,
+        partialMergeAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+        partialAggregate)
+
+    // 3. Create an Aggregate Operator for the final and complete aggregations.
+    val finalAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Final, false)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // The children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function as a non-distinct aggregation because
+        // the rows it consumes already contain distinct argument values.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transform {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+      namedGroupingAttributes ++ distinctColumnAttributes,
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      rewrittenResultExpressions,
+      partialMergeAggregate)
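+    // Final step of the illustration: COUNT(DISTINCT b) has been rewritten to a
+    // non-distinct COUNT(b) evaluated in Complete mode over the distinct
+    // (key, b) rows, while SUM(c) is finished in Final mode.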
+
+    finalAndCompleteAggregate :: Nil
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6c49a906c848a1e79713d305363c9ea771385b88
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate function.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] that represents the data types of the input arguments of this
+   * aggregate function. For example, if a [[UserDefinedAggregateFunction]] expects
+   * two input arguments with types [[DoubleType]] and [[LongType]], the returned
+   * [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the
+   * corresponding input argument; users are free to choose the names.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] that represents the data types of the values in the aggregation
+   * buffer. For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with types [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleSum", DoubleType), StructField("longCount", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the
+   * corresponding buffer value; users are free to choose the names.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /** Indicates if this function is deterministic. */
+  def deterministic: Boolean
+
+  /**
+   * Initializes the given aggregation buffer. The initial values set by this method
+   * should satisfy the condition that merging two buffers that both hold initial
+   * values produces a buffer that still holds the initial values.
+   */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer.
+   */
+  def evaluate(buffer: Row): Any
+}
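+
+// A minimal usage sketch (the class name `MyDoubleSum` below mirrors the test UDAFs
+// added elsewhere in this patch; `sqlContext` and a table `t` are assumed):
+//
+//   class MyDoubleSum extends UserDefinedAggregateFunction {
+//     def inputSchema: StructType = StructType(StructField("input", DoubleType) :: Nil)
+//     def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
+//     def returnDataType: DataType = DoubleType
+//     def deterministic: Boolean = true
+//     def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, null)
+//     def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+//       if (!input.isNullAt(0)) {
+//         val prev = if (buffer.isNullAt(0)) 0.0 else buffer.getDouble(0)
+//         buffer.update(0, prev + input.getDouble(0))
+//       }
+//     }
+//     def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+//       if (!buffer2.isNullAt(0)) {
+//         val prev = if (buffer1.isNullAt(0)) 0.0 else buffer1.getDouble(0)
+//         buffer1.update(0, prev + buffer2.getDouble(0))
+//       }
+//     }
+//     def evaluate(buffer: Row): Any =
+//       if (buffer.isNullAt(0)) null else buffer.getDouble(0)
+//   }
+//
+// After sqlContext.udaf.register("mydoublesum", new MyDoubleSum), the function can
+// be called from SQL, e.g. SELECT mydoublesum(value) FROM t.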
+
+private[sql] abstract class AggregationBuffer(
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int)
+  extends Row {
+
+  override def length: Int = toCatalystConverters.length
+
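+  // The buffer may share its underlying row with other functions' buffers.
+  // For example (hypothetical values), with bufferOffset = 2 and length = 3,
+  // offsets is [2, 3, 4]: the i-th buffer slot maps to position offsets(i) of
+  // the underlying row.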
+  protected val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+}
+
+/**
+ * A mutable [[Row]] representing a mutable aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingBuffer(offsets(i)))
+  }
+
+  def update(i: Int, value: Any): Unit = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not update ${i}th value in this buffer because it only has $length values.")
+    }
+    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+  }
+
+  override def copy(): MutableAggregationBuffer = {
+    new MutableAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingBuffer)
+  }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: Row)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+  }
+
+  override def copy(): InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingInputBuffer)
+  }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` into the
+ * internal aggregation code path.
+ * @param children the input expressions of this UDAF
+ * @param udaf the user-defined aggregate function to be wrapped
+ */
+case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+      s"but ${children.length} are provided.")
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, fallback to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+    new MutableAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
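+  // Projects the input row onto this UDAF's child expressions, converts the
+  // projected row to a Scala Row, and passes it to the user's update function
+  // together with a mutable view over the shared aggregation buffer.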
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    udaf.evaluate(inputAggregateBuffer)
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = udaf.getClass.getSimpleName
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 28159cbd5ab96d16e7712582b319dd9627f36071..bfeecbe8b2ab5d41fdaa4d890f141595826a453f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
    * @since 1.5.0
    */
   def callUDF(udfName: String, cols: Column*): Column = {
-    UnresolvedFunction(udfName, cols.map(_.expr))
+    UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
 
   /**
@@ -2449,7 +2449,7 @@ object functions {
       exprs(i) = cols(i).expr
       i += 1
     }
-    UnresolvedFunction(udfName, exprs)
+    UnresolvedFunction(udfName, exprs, isDistinct = false)
   }
 
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee10173fbc4e910d310003beabec16276d20f6..ab8dce603c117b96cb7f9e2c645a5069571f9d51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       var hasGeneratedAgg = false
       df.queryExecution.executedPlan.foreach {
         case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
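+        // The new Aggregate2Sort operator also counts as a code-generated
+        // aggregate for the purpose of this check.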
+        case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
         case _ =>
       }
       if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT  sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af81a6331f92e7e25ba6db0351deb4f1..3d71deb13e88480f96ac84f728f7a175c07e0d70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
 
 
 class PlannerSuite extends SparkFunSuite {
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be three aggregate operators for
+    // distinct aggregations.
+    assert(
+      aggregations.size == 2 || aggregations.size == 3,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
   test("unions are collapsed") {
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    assert(aggregations.size === 2)
+    testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a36833389f0c2ebee015b320d4db1f26357..24a758f53170a513fc6337b58cf97c69e2d6cf25 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
     "windowing_adjust_rowcontainer_sz"
   )
 
+  // Only run the query tests in the realWhiteList (do not run the other ignored query files).
   override def testCases: Seq[(String, File)] = super.testCases.filter {
     case (name, _) => realWhiteList.contains(name)
   }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d7eae4cddeee6f55a3042ad17a5f53f..1fe4fe9629c02b1b784ccb8afbe5ba5db1eb34c5 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive.execution
 
+import java.io.File
+
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive
 
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "join_reorder4",
     "join_star"
   )
+
+  // Only run the query tests in the realWhiteList (do not run the other ignored query files).
+  override def testCases: Seq[(String, File)] = super.testCases.filter {
+    case (name, _) => realWhiteList.contains(name)
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cec7685bb685974db2cf0219af70e2445172d9dc..4cdb83c5116f9dd4acce2abd1755f8ba83f61366 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
       DataSinks,
       Scripts,
       HashAggregation,
+      Aggregation,
       LeftSemiJoin,
       HashJoin,
       BasicOperators,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f5574509b0b380632b003212f3419b9e457c3e5a..8518e333e805816675512bd486da276fb95c5961 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* UDFs - Must be last otherwise will preempt built in functions */
     case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
-      UnresolvedFunction(name, args.map(nodeToExpr))
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
+    // Aggregate function with DISTINCT keyword.
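+    // For example, COUNT(DISTINCT key) arrives from the Hive parser as a
+    // TOK_FUNCTIONDI node.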
+    case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
     case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
-      UnresolvedFunction(name, UnresolvedStar(None) :: Nil)
+      UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
 
     /* Literals */
     case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4d23c7035c03d3e4233c40af39644257053ada56..3259b50acc765a0daea7e8352de1d5396839208d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction(
 
 private[hive] case class HiveGenericUDAF(
     funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
+    children: Seq[Expression]) extends AggregateExpression1
   with HiveInspectors {
 
   type UDFType = AbstractGenericUDAFResolver
@@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF(
 /** It is used as a wrapper for the hive functions which uses UDAF interface */
 private[hive] case class HiveUDAF(
     funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
+    children: Seq[Expression]) extends AggregateExpression1
   with HiveInspectors {
 
   type UDFType = UDAF
@@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF(
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
     exprs: Seq[Expression],
-    base: AggregateExpression,
+    base: AggregateExpression1,
     isUDAFBridgeRequired: Boolean = false)
-  extends AggregateFunction
+  extends AggregateFunction1
   with HiveInspectors {
 
   def this() = this(null, null, null)
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
new file mode 100644
index 0000000000000000000000000000000000000000..5c9d0e97a99c6cb8f5a497e490bfcf457070bbb0
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleAvg() {
+    List<StructField> inputfields = new ArrayList<StructField>();
+    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputfields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+    bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+    buffer.update(1, 0L);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+        buffer.update(1, 1L);
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+        buffer.update(1, buffer.getLong(1) + 1L);
+      }
+    }
+  }
+
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+        buffer1.update(1, buffer2.getLong(1));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
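+      // The constant 100.0 offset makes the result distinguishable from the
+      // built-in AVG in tests.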
+      return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+    }
+  }
+}
+
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
new file mode 100644
index 0000000000000000000000000000000000000000..1d4587a27c78737f0e96db7741d18a2458097078
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.Row;
+
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleSum() {
+    List<StructField> inputfields = new ArrayList<StructField>();
+    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputfields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+      }
+    }
+  }
+
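+  // Merging two partial sums uses the same null-aware addition as update(),
+  // since the buffer simply holds a running sum.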
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
+      return buffer.getDouble(0);
+    }
+  }
+}
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000000000000000000000000000000000000..573541ac9702dd3969c9bc859d2b91ec1f7e6e56
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
new file mode 100644
index 0000000000000000000000000000000000000000..44b2a42cc26c53977d88671c93c4f5fdc7ab042c
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
@@ -0,0 +1 @@
+unhex(str) - Converts hexadecimal argument to binary
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
new file mode 100644
index 0000000000000000000000000000000000000000..97af3b812a4295d6f54a89a3f80478a2b8f417eb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
@@ -0,0 +1,14 @@
+unhex(str) - Converts hexadecimal argument to binary
+Performs the inverse operation of HEX(str). That is, it interprets
+each pair of hexadecimal digits in the argument as a number and
+converts it to the byte representation of the number. The
+resulting characters are returned as a binary string.
+
+Example:
+> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1;
+'MySQL'
+
+The characters in the argument string must be legal hexadecimal
+digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters
+any nonhexadecimal digits in the argument, it returns NULL. Also,
+if there are an odd number of characters a leading 0 is appended.
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
new file mode 100644
index 0000000000000000000000000000000000000000..b4a6f2b692227cd18a3b47e6d599b29e96cc45a0
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
@@ -0,0 +1 @@
+MySQL	1267	a	-4	
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
new file mode 100644
index 0000000000000000000000000000000000000000..3a67adaf0a9a8c4c5d2daee3e579a369009dcc8c
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
@@ -0,0 +1 @@
+NULL	NULL	NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0375eb79add95d637ff0e2f99a60c02dcaaed358
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.scalatest.BeforeAndAfterAll
+import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+
+class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
+
+  override val sqlContext = TestHive
+  import sqlContext.implicits._
+
+  var originalUseAggregate2: Boolean = _
+
+  override def beforeAll(): Unit = {
+    originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
+    sqlContext.sql("set spark.sql.useAggregate2=true")
+    val data1 = Seq[(Integer, Integer)](
+      (1, 10),
+      (null, -60),
+      (1, 20),
+      (1, 30),
+      (2, 0),
+      (null, -10),
+      (2, -1),
+      (2, null),
+      (2, null),
+      (null, 100),
+      (3, null),
+      (null, null),
+      (3, null)).toDF("key", "value")
+    data1.write.saveAsTable("agg1")
+
+    val data2 = Seq[(Integer, Integer, Integer)](
+      (1, 10, -10),
+      (null, -60, 60),
+      (1, 30, -30),
+      (1, 30, 30),
+      (2, 1, 1),
+      (null, -10, 10),
+      (2, -1, null),
+      (2, 1, 1),
+      (2, null, 1),
+      (null, 100, -10),
+      (3, null, 3),
+      (null, null, null),
+      (3, null, null)).toDF("key", "value1", "value2")
+    data2.write.saveAsTable("agg2")
+
+    val emptyDF = sqlContext.createDataFrame(
+      sqlContext.sparkContext.emptyRDD[Row],
+      StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
+    emptyDF.registerTempTable("emptyTable")
+
+    // Register UDAFs
+    sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
+    sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
+  }
+
+  override def afterAll(): Unit = {
+    sqlContext.sql("DROP TABLE IF EXISTS agg1")
+    sqlContext.sql("DROP TABLE IF EXISTS agg2")
+    sqlContext.dropTempTable("emptyTable")
+    sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
+  }
+
+  test("empty table") {
+    // If there is no GROUP BY clause and the table is empty, we will generate a single row.
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(value),
+          |  COUNT(*),
+          |  COUNT(key),
+          |  COUNT(value),
+          |  FIRST(key),
+          |  LAST(value),
+          |  MAX(key),
+          |  MIN(value),
+          |  SUM(key)
+          |FROM emptyTable
+        """.stripMargin),
+      Row(null, 0, 0, 0, null, null, null, null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(value),
+          |  COUNT(*),
+          |  COUNT(key),
+          |  COUNT(value),
+          |  FIRST(key),
+          |  LAST(value),
+          |  MAX(key),
+          |  MIN(value),
+          |  SUM(key),
+          |  COUNT(DISTINCT value)
+          |FROM emptyTable
+        """.stripMargin),
+      Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil)
+
+    // If there is a GROUP BY clause and the table is empty, there is no output.
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(value),
+          |  COUNT(*),
+          |  COUNT(value),
+          |  FIRST(value),
+          |  LAST(value),
+          |  MAX(value),
+          |  MIN(value),
+          |  SUM(value),
+          |  COUNT(DISTINCT value)
+          |FROM emptyTable
+          |GROUP BY key
+        """.stripMargin),
+      Nil)
+  }
+
+  test("only do grouping") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT DISTINCT value1, key
+          |FROM agg2
+        """.stripMargin),
+      Row(10, 1) ::
+        Row(-60, null) ::
+        Row(30, 1) ::
+        Row(1, 2) ::
+        Row(-10, null) ::
+        Row(-1, 2) ::
+        Row(null, 2) ::
+        Row(100, null) ::
+        Row(null, 3) ::
+        Row(null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT value1, key
+          |FROM agg2
+          |GROUP BY key, value1
+        """.stripMargin),
+      Row(10, 1) ::
+        Row(-60, null) ::
+        Row(30, 1) ::
+        Row(1, 2) ::
+        Row(-10, null) ::
+        Row(-1, 2) ::
+        Row(null, 2) ::
+        Row(100, null) ::
+        Row(null, 3) ::
+        Row(null, null) :: Nil)
+  }
+
+  test("case in-sensitive resolution") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value), kEY - 100
+          |FROM agg1
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT sum(distinct value1), kEY - 100, count(distinct value1)
+          |FROM agg2
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT valUe * key - 100
+          |FROM agg1
+          |GROUP BY vAlue * keY - 100
+        """.stripMargin),
+      Row(-90) ::
+        Row(-80) ::
+        Row(-70) ::
+        Row(-100) ::
+        Row(-102) ::
+        Row(null) :: Nil)
+  }
+
+  test("test average no key in output") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
+  }
+
+  test("test average") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value) + 1.5, key + 10
+          |FROM agg1
+          |GROUP BY key + 10
+        """.stripMargin),
+      Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value) FROM agg1
+        """.stripMargin),
+      Row(11.125) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+
+  test("udaf") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  mydoubleavg(value),
+          |  avg(value - key),
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+        Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+        Row(3, null, null, null, null, null) ::
+        Row(null, null, 110.0, null, null, 10.0) :: Nil)
+  }
+
+  test("non-AlgebraicAggregate aggreguate function") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT mydoublesum(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT mydoublesum(value) FROM agg1
+        """.stripMargin),
+      Row(89.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT mydoublesum(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+
+  test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT mydoublesum(value), key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1, 20.0) ::
+        Row(-1.0, 2, -0.5) ::
+        Row(null, 3, null) ::
+        Row(30.0, null, 10.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  mydoublesum(value + 1.5 * key),
+          |  avg(value - key),
+          |  key,
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(64.5, 19.0, 1, 55.5, 20.0) ::
+        Row(5.0, -2.5, 2, -7.0, -0.5) ::
+        Row(null, null, 3, null, null) ::
+        Row(null, null, null, null, 10.0) :: Nil)
+  }
+
+  test("single distinct column set") {
+    // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  min(distinct value1),
+          |  sum(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  max(distinct value1)
+          |FROM agg2
+        """.stripMargin),
+      Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  mydoubleavg(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  key,
+          |  mydoubleavg(value1 - 1),
+          |  mydoubleavg(distinct value1) * 0.1,
+          |  avg(value1 + value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+        Row(null, null, 3.0, 3, null, null, null) ::
+        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoubleavg(distinct value1),
+          |  mydoublesum(value2),
+          |  mydoublesum(distinct value1),
+          |  mydoubleavg(distinct value1),
+          |  mydoubleavg(value1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+        Row(3, null, 3.0, null, null, null) ::
+        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+  }
+
+  test("test count") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  count(value2),
+          |  value1,
+          |  count(*),
+          |  count(1),
+          |  key
+          |FROM agg2
+          |GROUP BY key, value1
+        """.stripMargin),
+      Row(1, 10, 1, 1, 1) ::
+        Row(1, -60, 1, 1, null) ::
+        Row(2, 30, 2, 2, 1) ::
+        Row(2, 1, 2, 2, 2) ::
+        Row(1, -10, 1, 1, null) ::
+        Row(0, -1, 1, 1, 2) ::
+        Row(1, null, 1, 1, 2) ::
+        Row(1, 100, 1, 1, null) ::
+        Row(1, null, 2, 2, 3) ::
+        Row(0, null, 1, 1, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  count(value2),
+          |  value1,
+          |  count(*),
+          |  count(1),
+          |  key,
+          |  count(DISTINCT abs(value2))
+          |FROM agg2
+          |GROUP BY key, value1
+        """.stripMargin),
+      Row(1, 10, 1, 1, 1, 1) ::
+        Row(1, -60, 1, 1, null, 1) ::
+        Row(2, 30, 2, 2, 1, 1) ::
+        Row(2, 1, 2, 2, 2, 1) ::
+        Row(1, -10, 1, 1, null, 1) ::
+        Row(0, -1, 1, 1, 2, 0) ::
+        Row(1, null, 1, 1, 2, 1) ::
+        Row(1, 100, 1, 1, null, 1) ::
+        Row(1, null, 2, 2, 3, 1) ::
+        Row(0, null, 1, 1, null, 0) :: Nil)
+  }
+
+  test("error handling") {
+    sqlContext.sql(s"set spark.sql.useAggregate2=false")
+    var errorMessage = intercept[AnalysisException] {
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  sum(value + 1.5 * key),
+          |  mydoublesum(value),
+          |  mydoubleavg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin).collect()
+    }.getMessage
+    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+    // TODO: once we support Hive UDAF in the new interface,
+    // we can remove the following two tests.
+    sqlContext.sql(s"set spark.sql.useAggregate2=true")
+    errorMessage = intercept[AnalysisException] {
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  stddev_samp(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin).collect()
+    }.getMessage
+    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+    // This query will fall back to the old aggregation code path.
+    val newAggregateOperators = sqlContext.sql(
+      """
+        |SELECT
+        |  key,
+        |  sum(value + 1.5 * key),
+        |  stddev_samp(value)
+        |FROM agg1
+        |GROUP BY key
+      """.stripMargin).queryExecution.executedPlan.collect {
+      case agg: Aggregate2Sort => agg
+    }
+    val message =
+      "We should fallback to the old aggregation code path if there is any aggregate function " +
+        "that cannot be converted to the new interface."
+    assert(newAggregateOperators.isEmpty, message)
+
+    sqlContext.sql(s"set spark.sql.useAggregate2=true")
+  }
+}