diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d08f553cefe8cc6c3b2a8175a213c90a6342ab17..4abfdfe87d5e9733a6ef3c76dc5ea407894fb4e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -110,7 +110,11 @@ abstract class AggregateFunction2
    * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
    * will be 2.
    */
-  var mutableBufferOffset: Int = 0
+  protected var mutableBufferOffset: Int = 0
+
+  def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+    mutableBufferOffset = newMutableBufferOffset
+  }
 
   /**
    * The offset of this function's start buffer value in the
@@ -126,7 +130,11 @@ abstract class AggregateFunction2
    * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
    * will be 3 (position 0 is used for the value of key`).
    */
-  var inputBufferOffset: Int = 0
+  protected var inputBufferOffset: Int = 0
+
+  def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+    inputBufferOffset = newInputBufferOffset
+  }
 
   /** The schema of the aggregation buffer. */
   def bufferSchema: StructType
@@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
   override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
 
   override def initialize(buffer: MutableRow): Unit = {
-    var i = 0
-    while (i < bufferAttributes.size) {
-      buffer(i + mutableBufferOffset) = initialValues(i).eval()
-      i += 1
-    }
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's initialize should not be called directly")
   }
 
   override final def update(buffer: MutableRow, input: InternalRow): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cf568dc048674f2c01ecf9c849df46c97fcb17c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types
+ * of the grouping expressions and aggregate functions, it determines if it uses
+ * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to
+ * process input rows.
+ */
+case class Aggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  private[this] val allAggregateExpressions =
+    nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  private[this] val hasNonAlgebricAggregateFunctions =
+    !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
+
+  // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
+  // grouping key and aggregation buffer is supported; and (3) all
+  // aggregate functions are algebraic.
+  private[this] val supportsHybridIterator: Boolean = {
+    val aggregationBufferSchema: StructType =
+      StructType.fromAttributes(
+        allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+    val groupKeySchema: StructType =
+      StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
+
+    val schemaSupportsUnsafe: Boolean =
+      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+        UnsafeProjection.canSupport(groupKeySchema)
+
+    // TODO: Use the hybrid iterator for non-algebric aggregate functions.
+    sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions
+  }
+
+  // We need to use sorted input if we have grouping expressions, and
+  // we cannot use the hybrid iterator or the hybrid is disabled.
+  private[this] val requiresSortedInput: Boolean = {
+    groupingExpressions.nonEmpty && !supportsHybridIterator
+  }
+
+  override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions
+
+  // If result expressions' data types are all fixed length, we generate unsafe rows
+  // (We have this requirement instead of check the result of UnsafeProjection.canSupport
+  // is because we use a mutable projection to generate the result).
+  override def outputsUnsafeRows: Boolean = {
+    // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
+    // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix
+    // any issue we get.
+    false
+  }
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (requiresSortedInput) {
+      // TODO: We should not sort the input rows if they are just in reversed order.
+      groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+    } else {
+      Seq.fill(children.size)(Nil)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiresSortedInput) {
+      // It is possible that the child.outputOrdering starts with the required
+      // ordering expressions (e.g. we require [a] as the sort expression and the
+      // child's outputOrdering is [a, b]). We can only guarantee the output rows
+      // are sorted by values of groupingExpressions.
+      groupingExpressions.map(SortOrder(_, Ascending))
+    } else {
+      Nil
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      // Because the constructor of an aggregation iterator will read at least the first row,
+      // we need to get the value of iter.hasNext first.
+      val hasInput = iter.hasNext
+      val useHybridIterator =
+        hasInput &&
+          supportsHybridIterator &&
+          groupingExpressions.nonEmpty
+      if (useHybridIterator) {
+        UnsafeHybridAggregationIterator.createFromInputIterator(
+          groupingExpressions,
+          nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
+          completeAggregateExpressions,
+          completeAggregateAttributes,
+          initialInputBufferOffset,
+          resultExpressions,
+          newMutableProjection _,
+          child.output,
+          iter,
+          outputsUnsafeRows)
+      } else {
+        if (!hasInput && groupingExpressions.nonEmpty) {
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
+          Iterator[InternalRow]()
+        } else {
+          val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+            groupingExpressions,
+            nonCompleteAggregateExpressions,
+            nonCompleteAggregateAttributes,
+            completeAggregateExpressions,
+            completeAggregateAttributes,
+            initialInputBufferOffset,
+            resultExpressions,
+            newMutableProjection _ ,
+            newProjection _,
+            child.output,
+            iter,
+            outputsUnsafeRows)
+          if (!hasInput && groupingExpressions.isEmpty) {
+            // There is no input and there is no grouping expressions.
+            // We need to output a single row as the output.
+            Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+          } else {
+            outputIter
+          }
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) {
+      classOf[UnsafeHybridAggregationIterator].getSimpleName
+    } else {
+      classOf[SortBasedAggregationIterator].getSimpleName
+    }
+
+    s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}"""
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..abca373b0c4f9dc869446b8705462c3129a7fb83
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.unsafe.KVIterator
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * It mainly contains two parts:
+ * 1. It initializes aggregate functions.
+ * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
+ *    its aggregate functions. `processRow` is the function to handle an input. `generateOutput`
+ *    is used to generate result.
+ */
+abstract class AggregationIterator(
+    groupingKeyAttributes: Seq[Attribute],
+    valueAttributes: Seq[Attribute],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    outputsUnsafeRows: Boolean)
+  extends Iterator[InternalRow] with Logging {
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initializing functions.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // An Seq of all AggregateExpressions.
+  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+  // are at the beginning of the allAggregateExpressions.
+  protected val allAggregateExpressions =
+    nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  require(
+    allAggregateExpressions.map(_.mode).distinct.length <= 2,
+    s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
+
+  /**
+   * The distinct modes of AggregateExpressions. Right now, we can handle the following mode:
+   *  - Partial-only: all AggregateExpressions have the mode of Partial;
+   *  - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge);
+   *  - Final-only: all AggregateExpressions have the mode of Final;
+   *  - Final-Complete: some AggregateExpressions have the mode of Final and
+   *    others have the mode of Complete;
+   *  - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+   *    with mode Complete in completeAggregateExpressions; and
+   *  - Grouping-only: there is no AggregateExpression.
+   */
+  protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
+    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+      completeAggregateExpressions.map(_.mode).distinct.headOption
+
+  // Initialize all AggregateFunctions by binding references if necessary,
+  // and set inputBufferOffset and mutableBufferOffset.
+  protected val allAggregateFunctions: Array[AggregateFunction2] = {
+    var mutableBufferOffset = 0
+    var inputBufferOffset: Int = initialInputBufferOffset
+    val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+    var i = 0
+    while (i < allAggregateExpressions.length) {
+      val func = allAggregateExpressions(i).aggregateFunction
+      val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+        case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // AlgebraicAggregate (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, valueAttributes)
+        case _ =>
+          // We only need to set inputBufferOffset for aggregate functions with mode
+          // PartialMerge and Final.
+          func.withNewInputBufferOffset(inputBufferOffset)
+          inputBufferOffset += func.bufferSchema.length
+          func
+      }
+      // Set mutableBufferOffset for this function. It is important that setting
+      // mutableBufferOffset happens after all potential bindReference operations
+      // because bindReference will create a new instance of the function.
+      funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
+      mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
+      functions(i) = funcWithBoundReferences
+      i += 1
+    }
+    functions
+  }
+
+  // Positions of those non-algebraic aggregate functions in allAggregateFunctions.
+  // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+  // func2 and func3 are non-algebraic aggregate functions.
+  // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+  private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < allAggregateFunctions.length) {
+      allAggregateFunctions(i) match {
+        case agg: AlgebraicAggregate =>
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
+  }
+
+  // All AggregateFunctions functions with mode Partial, PartialMerge, or Final.
+  private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+    allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+
+  // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final.
+  private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+    nonCompleteAggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }
+
+  // The projection used to initialize buffer values for all AlgebraicAggregates.
+  private[this] val algebraicInitialProjection = {
+    val initExpressions = allAggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(initExpressions, Nil)()
+  }
+
+  // All non-Algebraic AggregateFunctions.
+  private[this] val allNonAlgebraicAggregateFunctions =
+    allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions)
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods and fields used by sub-classes.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // Initializing functions used to process a row.
+  protected val processRow: (MutableRow, InternalRow) => Unit = {
+    val rowToBeProcessed = new JoinedRow
+    val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    aggregationMode match {
+      // Partial-only
+      case (Some(Partial), None) =>
+        val updateExpressions = nonCompleteAggregateFunctions.flatMap {
+          case ae: AlgebraicAggregate => ae.updateExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+        }
+        val algebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+        (currentBuffer: MutableRow, row: InternalRow) => {
+          algebraicUpdateProjection.target(currentBuffer)
+          // Process all algebraic aggregate functions.
+          algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row))
+          // Process all non-algebraic aggregate functions.
+          var i = 0
+          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+            nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
+        }
+
+      // PartialMerge-only or Final-only
+      case (Some(PartialMerge), None) | (Some(Final), None) =>
+        val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) {
+          // If initialInputBufferOffset, the input value does not contain
+          // grouping keys.
+          // This part is pretty hacky.
+          allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq
+        } else {
+          groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+        }
+        // val inputAggregationBufferSchema =
+        //  groupingKeyAttributes ++
+        //    allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+        val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
+          case ae: AlgebraicAggregate => ae.mergeExpressions
+          case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+        }
+        // This projection is used to merge buffer values for all AlgebraicAggregates.
+        val algebraicMergeProjection =
+          newMutableProjection(
+            mergeExpressions,
+            aggregationBufferSchema ++ inputAggregationBufferSchema)()
+
+        (currentBuffer: MutableRow, row: InternalRow) => {
+          // Process all algebraic aggregate functions.
+          algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
+          // Process all non-algebraic aggregate functions.
+          var i = 0
+          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+            nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+            i += 1
+          }
+        }
+
+      // Final-Complete
+      case (Some(Final), Some(Complete)) =>
+        val completeAggregateFunctions: Array[AggregateFunction2] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+        // All non-algebraic aggregate functions with mode Complete.
+        val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+          completeAggregateFunctions.collect {
+            case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+          }
+
+        // The first initialInputBufferOffset values of the input aggregation buffer is
+        // for grouping expressions and distinct columns.
+        val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
+
+        val completeOffsetExpressions =
+          Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+        // We do not touch buffer values of aggregate functions with the Final mode.
+        val finalOffsetExpressions =
+          Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+        val mergeInputSchema =
+          aggregationBufferSchema ++
+            groupingAttributesAndDistinctColumns ++
+            nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes)
+        val mergeExpressions =
+          nonCompleteAggregateFunctions.flatMap {
+            case ae: AlgebraicAggregate => ae.mergeExpressions
+            case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+          } ++ completeOffsetExpressions
+        val finalAlgebraicMergeProjection =
+          newMutableProjection(mergeExpressions, mergeInputSchema)()
+
+        val updateExpressions =
+          finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+            case ae: AlgebraicAggregate => ae.updateExpressions
+            case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+          }
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+        (currentBuffer: MutableRow, row: InternalRow) => {
+          val input = rowToBeProcessed(currentBuffer, row)
+          // For all aggregate functions with mode Complete, update buffers.
+          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+          var i = 0
+          while (i < completeNonAlgebraicAggregateFunctions.length) {
+            completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
+
+          // For all aggregate functions with mode Final, merge buffers.
+          finalAlgebraicMergeProjection.target(currentBuffer)(input)
+          i = 0
+          while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+            nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+            i += 1
+          }
+        }
+
+      // Complete-only
+      case (None, Some(Complete)) =>
+        val completeAggregateFunctions: Array[AggregateFunction2] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+        // All non-algebraic aggregate functions with mode Complete.
+        val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+          completeAggregateFunctions.collect {
+            case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+          }
+
+        val updateExpressions =
+          completeAggregateFunctions.flatMap {
+            case ae: AlgebraicAggregate => ae.updateExpressions
+            case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+          }
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+        (currentBuffer: MutableRow, row: InternalRow) => {
+          val input = rowToBeProcessed(currentBuffer, row)
+          // For all aggregate functions with mode Complete, update buffers.
+          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+          var i = 0
+          while (i < completeNonAlgebraicAggregateFunctions.length) {
+            completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+            i += 1
+          }
+        }
+
+      // Grouping only.
+      case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {}
+
+      case other =>
+        sys.error(
+          s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+            s"support evaluate modes $other in this iterator.")
+    }
+  }
+
+  // Initializing the function used to generate the output row.
+  protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
+    val rowToBeEvaluated = new JoinedRow
+    val safeOutoutRow = new GenericMutableRow(resultExpressions.length)
+    val mutableOutput = if (outputsUnsafeRows) {
+      UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow)
+    } else {
+      safeOutoutRow
+    }
+
+    aggregationMode match {
+      // Partial-only or PartialMerge-only: every output row is basically the values of
+      // the grouping expressions and the corresponding aggregation buffer.
+      case (Some(Partial), None) | (Some(PartialMerge), None) =>
+        // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not
+        // support generic getter), we create a mutable projection to output the
+        // JoinedRow(currentGroupingKey, currentBuffer)
+        val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes)
+        val resultProjection =
+          newMutableProjection(
+            groupingKeyAttributes ++ bufferSchema,
+            groupingKeyAttributes ++ bufferSchema)()
+        resultProjection.target(mutableOutput)
+
+        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+          resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer))
+          // rowToBeEvaluated(currentGroupingKey, currentBuffer)
+        }
+
+      // Final-only, Complete-only and Final-Complete: every output row contains values representing
+      // resultExpressions.
+      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+        val bufferSchemata =
+          allAggregateFunctions.flatMap(_.bufferAttributes)
+        val evalExpressions = allAggregateFunctions.map {
+          case ae: AlgebraicAggregate => ae.evaluateExpression
+          case agg: AggregateFunction2 => NoOp
+        }
+        val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
+        val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+        // TODO: Use unsafe row.
+        val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+        val resultProjection =
+          newMutableProjection(
+            resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
+        resultProjection.target(mutableOutput)
+
+        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+          // Generate results for all algebraic aggregate functions.
+          algebraicEvalProjection.target(aggregateResult)(currentBuffer)
+          // Generate results for all non-algebraic aggregate functions.
+          var i = 0
+          while (i < allNonAlgebraicAggregateFunctions.length) {
+            aggregateResult.update(
+              allNonAlgebraicAggregateFunctionPositions(i),
+              allNonAlgebraicAggregateFunctions(i).eval(currentBuffer))
+            i += 1
+          }
+          resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
+        }
+
+      // Grouping-only: we only output values of grouping expressions.
+      case (None, None) =>
+        val resultProjection =
+          newMutableProjection(resultExpressions, groupingKeyAttributes)()
+        resultProjection.target(mutableOutput)
+
+        (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+          resultProjection(currentGroupingKey)
+        }
+
+      case other =>
+        sys.error(
+          s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+            s"support evaluate modes $other in this iterator.")
+    }
+  }
+
+  /** Initializes buffer values for all aggregate functions. */
+  protected def initializeBuffer(buffer: MutableRow): Unit = {
+    algebraicInitialProjection.target(buffer)(EmptyRow)
+    var i = 0
+    while (i < allNonAlgebraicAggregateFunctions.length) {
+      allNonAlgebraicAggregateFunctions(i).initialize(buffer)
+      i += 1
+    }
+  }
+
+  /**
+   * Creates a new aggregation buffer and initializes buffer values
+   * for all aggregate functions.
+   */
+  protected def newBuffer: MutableRow
+}
+
+object AggregationIterator {
+  def kvIterator(
+    groupingExpressions: Seq[NamedExpression],
+    newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
+    new KVIterator[InternalRow, InternalRow] {
+      private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
+
+      private[this] var groupingKey: InternalRow = _
+
+      private[this] var value: InternalRow = _
+
+      override def next(): Boolean = {
+        if (inputIter.hasNext) {
+          // Read the next input row.
+          val inputRow = inputIter.next()
+          // Get groupingKey based on groupingExpressions.
+          groupingKey = groupingKeyGenerator(inputRow)
+          // The value is the inputRow.
+          value = inputRow
+          true
+        } else {
+          false
+        }
+      }
+
+      override def getKey(): InternalRow = {
+        groupingKey
+      }
+
+      override def getValue(): InternalRow = {
+        value
+      }
+
+      override def close(): Unit = {
+        // Do nothing
+      }
+    }
+  }
+
+  def unsafeKVIterator(
+      groupingExpressions: Seq[NamedExpression],
+      inputAttributes: Seq[Attribute],
+      inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
+    new KVIterator[UnsafeRow, InternalRow] {
+      private[this] val groupingKeyGenerator =
+        UnsafeProjection.create(groupingExpressions, inputAttributes)
+
+      private[this] var groupingKey: UnsafeRow = _
+
+      private[this] var value: InternalRow = _
+
+      override def next(): Boolean = {
+        if (inputIter.hasNext) {
+          // Read the next input row.
+          val inputRow = inputIter.next()
+          // Get groupingKey based on groupingExpressions.
+          groupingKey = groupingKeyGenerator.apply(inputRow)
+          // The value is the inputRow.
+          value = inputRow
+          true
+        } else {
+          false
+        }
+      }
+
+      override def getKey(): UnsafeRow = {
+        groupingKey
+      }
+
+      override def getValue(): InternalRow = {
+        value
+      }
+
+      override def close(): Unit = {
+        // Do nothing
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..78bcee16c9d005e6ba5795e24b1c88eff110be95
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
+ * sorted by values of [[groupingKeyAttributes]].
+ */
+class SortBasedAggregationIterator(
+    groupingKeyAttributes: Seq[Attribute],
+    valueAttributes: Seq[Attribute],
+    inputKVIterator: KVIterator[InternalRow, InternalRow],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    outputsUnsafeRows: Boolean)
+  extends AggregationIterator(
+    groupingKeyAttributes,
+    valueAttributes,
+    nonCompleteAggregateExpressions,
+    nonCompleteAggregateAttributes,
+    completeAggregateExpressions,
+    completeAggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection,
+    outputsUnsafeRows) {
+
+  override protected def newBuffer: MutableRow = {
+    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferRowSize: Int = bufferSchema.length
+
+    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+    val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
+
+    val buffer = if (useUnsafeBuffer) {
+      val unsafeProjection =
+        UnsafeProjection.create(bufferSchema.map(_.dataType))
+      unsafeProjection.apply(genericMutableBuffer)
+    } else {
+      genericMutableBuffer
+    }
+    initializeBuffer(buffer)
+    buffer
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states for sort based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The partition key of the current partition.
+  private[this] var currentGroupingKey: InternalRow = _
+
+  // The partition key of next partition.
+  private[this] var nextGroupingKey: InternalRow = _
+
+  // The first row of next partition.
+  private[this] var firstRowInNextGroup: InternalRow = _
+
+  // Indicates if we has new group of rows from the sorted input iterator
+  private[this] var sortedInputHasNewGroup: Boolean = false
+
+  // The aggregation buffer used by the sort-based aggregation.
+  private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
+
+  /** Processes rows in the current group. It will stop when it find a new group. */
+  protected def processCurrentSortedGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track if we see the next group.
+    var findNextPartition = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+    // The search will stop when we see the next group or there is no
+    // input row left in the iter.
+    var hasNext = inputKVIterator.next()
+    while (!findNextPartition && hasNext) {
+      // Get the grouping key.
+      val groupingKey = inputKVIterator.getKey
+      val currentRow = inputKVIterator.getValue
+
+      // Check if the current row belongs the current input row.
+      if (currentGroupingKey == groupingKey) {
+        processRow(sortBasedAggregationBuffer, currentRow)
+
+        hasNext = inputKVIterator.next()
+      } else {
+        // We find a new group.
+        findNextPartition = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means that there is no new row in the input
+    // iter. The current group is the last group of the iter.
+    if (!findNextPartition) {
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Iterator's public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = sortedInputHasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentSortedGroup()
+      // Generate output row for the current group.
+      val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+      // Initialize buffer values for the next group.
+      initializeBuffer(sortBasedAggregationBuffer)
+
+      outputRow
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputKVIterator.next()) {
+      initializeBuffer(sortBasedAggregationBuffer)
+
+      nextGroupingKey = inputKVIterator.getKey().copy()
+      firstRowInNextGroup = inputKVIterator.getValue().copy()
+
+      sortedInputHasNewGroup = true
+    } else {
+      // This inputIter is empty.
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  initialize()
+
+  def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+    initializeBuffer(sortBasedAggregationBuffer)
+    generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+  }
+}
+
+object SortBasedAggregationIterator {
+  // scalastyle:off
+  def createFromInputIterator(
+      groupingExprs: Seq[NamedExpression],
+      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+      nonCompleteAggregateAttributes: Seq[Attribute],
+      completeAggregateExpressions: Seq[AggregateExpression2],
+      completeAggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+      inputAttributes: Seq[Attribute],
+      inputIter: Iterator[InternalRow],
+      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+    val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
+      AggregationIterator.unsafeKVIterator(
+        groupingExprs,
+        inputAttributes,
+        inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
+    } else {
+      AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
+    }
+
+    new SortBasedAggregationIterator(
+      groupingExprs.map(_.toAttribute),
+      inputAttributes,
+      kvIterator,
+      nonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      newMutableProjection,
+      outputsUnsafeRows)
+  }
+
+  def createFromKVIterator(
+      groupingKeyAttributes: Seq[Attribute],
+      valueAttributes: Seq[Attribute],
+      inputKVIterator: KVIterator[InternalRow, InternalRow],
+      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+      nonCompleteAggregateAttributes: Seq[Attribute],
+      completeAggregateExpressions: Seq[AggregateExpression2],
+      completeAggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+    new SortBasedAggregationIterator(
+      groupingKeyAttributes,
+      valueAttributes,
+      inputKVIterator,
+      nonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      newMutableProjection,
+      outputsUnsafeRows)
+  }
+  // scalastyle:on
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..37d34eb7ccf09def949af717448416a6d8591122
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]].
+ * It first tries to use in-memory hash-based aggregation. If we cannot allocate more
+ * space for the hash map, we spill the sorted map entries, free the map, and then
+ * switch to sort-based aggregation.
+ */
+class UnsafeHybridAggregationIterator(
+    groupingKeyAttributes: Seq[Attribute],
+    valueAttributes: Seq[Attribute],
+    inputKVIterator: KVIterator[UnsafeRow, InternalRow],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    outputsUnsafeRows: Boolean)
+  extends AggregationIterator(
+    groupingKeyAttributes,
+    valueAttributes,
+    nonCompleteAggregateExpressions,
+    nonCompleteAggregateAttributes,
+    completeAggregateExpressions,
+    completeAggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection,
+    outputsUnsafeRows) {
+
+  require(groupingKeyAttributes.nonEmpty)
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Unsafe Aggregation buffers
+  ///////////////////////////////////////////////////////////////////////////
+
+  // This is the Unsafe Aggregation Map used to store all buffers.
+  private[this] val buffers = new UnsafeFixedWidthAggregationMap(
+    newBuffer,
+    StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+    StructType.fromAttributes(groupingKeyAttributes),
+    TaskContext.get.taskMemoryManager(),
+    SparkEnv.get.shuffleMemoryManager,
+    1024 * 16, // initial capacity
+    SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
+    false // disable tracking of performance metrics
+  )
+
+  override protected def newBuffer: UnsafeRow = {
+    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferRowSize: Int = bufferSchema.length
+
+    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+    val unsafeProjection =
+      UnsafeProjection.create(bufferSchema.map(_.dataType))
+    val buffer = unsafeProjection.apply(genericMutableBuffer)
+    initializeBuffer(buffer)
+    buffer
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods and variables related to switching to sort-based aggregation
+  ///////////////////////////////////////////////////////////////////////////
+  private[this] var sortBased = false
+
+  private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _
+
+  // The value part of the input KV iterator is used to store original input values of
+  // aggregate functions, we need to convert them to aggregation buffers.
+  private def processOriginalInput(
+      firstKey: UnsafeRow,
+      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
+    new KVIterator[UnsafeRow, UnsafeRow] {
+      private[this] var isFirstRow = true
+
+      private[this] var groupingKey: UnsafeRow = _
+
+      private[this] val buffer: UnsafeRow = newBuffer
+
+      override def next(): Boolean = {
+        initializeBuffer(buffer)
+        if (isFirstRow) {
+          isFirstRow = false
+          groupingKey = firstKey
+          processRow(buffer, firstValue)
+
+          true
+        } else if (inputKVIterator.next()) {
+          groupingKey = inputKVIterator.getKey()
+          val value = inputKVIterator.getValue()
+          processRow(buffer, value)
+
+          true
+        } else {
+          false
+        }
+      }
+
+      override def getKey(): UnsafeRow = {
+        groupingKey
+      }
+
+      override def getValue(): UnsafeRow = {
+        buffer
+      }
+
+      override def close(): Unit = {
+        // Do nothing.
+      }
+    }
+  }
+
+  // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer.
+  // We need to project the aggregation buffer out.
+  private def projectInputBufferToUnsafe(
+      firstKey: UnsafeRow,
+      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
+    new KVIterator[UnsafeRow, UnsafeRow] {
+      private[this] var isFirstRow = true
+
+      private[this] var groupingKey: UnsafeRow = _
+
+      private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+
+      private[this] val value: UnsafeRow = {
+        val genericMutableRow = new GenericMutableRow(bufferSchema.length)
+        UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow)
+      }
+
+      private[this] val projectInputBuffer = {
+        newMutableProjection(bufferSchema, valueAttributes)().target(value)
+      }
+
+      override def next(): Boolean = {
+        if (isFirstRow) {
+          isFirstRow = false
+          groupingKey = firstKey
+          projectInputBuffer(firstValue)
+
+          true
+        } else if (inputKVIterator.next()) {
+          groupingKey = inputKVIterator.getKey()
+          projectInputBuffer(inputKVIterator.getValue())
+
+          true
+        } else {
+          false
+        }
+      }
+
+      override def getKey(): UnsafeRow = {
+        groupingKey
+      }
+
+      override def getValue(): UnsafeRow = {
+        value
+      }
+
+      override def close(): Unit = {
+        // Do nothing.
+      }
+    }
+  }
+
+  /**
+   * We need to fall back to sort based aggregation because we do not have enough memory
+   * for our in-memory hash map (i.e. `buffers`).
+   */
+  private def switchToSortBasedAggregation(
+      currentGroupingKey: UnsafeRow,
+      currentRow: InternalRow): Unit = {
+    logInfo("falling back to sort based aggregation.")
+
+    // Step 1: Get the ExternalSorter containing entries of the map.
+    val externalSorter = buffers.destructAndCreateExternalSorter()
+
+    // Step 2: Free the memory used by the map.
+    buffers.free()
+
+    // Step 3: If we have aggregate function with mode Partial or Complete,
+    // we need to process them to get aggregation buffer.
+    // So, later in the sort-based aggregation iterator, we can do merge.
+    // If aggregate functions are with mode Final and PartialMerge,
+    // we just need to project the aggregation buffer from the input.
+    val needsProcess = aggregationMode match {
+      case (Some(Partial), None) => true
+      case (None, Some(Complete)) => true
+      case (Some(Final), Some(Complete)) => true
+      case _ => false
+    }
+
+    val processedIterator = if (needsProcess) {
+      processOriginalInput(currentGroupingKey, currentRow)
+    } else {
+      // The input value's format is groupingExprs + buffer.
+      // We need to project the buffer part out.
+      projectInputBufferToUnsafe(currentGroupingKey, currentRow)
+    }
+
+    // Step 4: Redirect processedIterator to externalSorter.
+    while (processedIterator.next()) {
+      externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue())
+    }
+
+    // Step 5: Get the sorted iterator from the externalSorter.
+    val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
+
+    // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
+    // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
+    // will be PartialMerge. For a aggregate function with mode Complete,
+    // its mode in the SortBasedAggregationIterator will be Final.
+    val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
+        case AggregateExpression2(func, Partial, isDistinct) =>
+          AggregateExpression2(func, PartialMerge, isDistinct)
+        case AggregateExpression2(func, Complete, isDistinct) =>
+          AggregateExpression2(func, Final, isDistinct)
+        case other => other
+      }
+    val newNonCompleteAggregateAttributes =
+      nonCompleteAggregateAttributes ++ completeAggregateAttributes
+
+    val newValueAttributes =
+      allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+    sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
+      groupingKeyAttributes = groupingKeyAttributes,
+      valueAttributes = newValueAttributes,
+      inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
+      nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
+      completeAggregateExpressions = Nil,
+      completeAggregateAttributes = Nil,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      newMutableProjection = newMutableProjection,
+      outputsUnsafeRows = outputsUnsafeRows)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods used to initialize this iterator.
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
+  protected def initialize(): Unit = {
+    var hasNext = inputKVIterator.next()
+    while (!sortBased && hasNext) {
+      val groupingKey = inputKVIterator.getKey()
+      val currentRow = inputKVIterator.getValue()
+      val buffer = buffers.getAggregationBuffer(groupingKey)
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, currentRow)
+        sortBased = true
+      } else {
+        processRow(buffer, currentRow)
+        hasNext = inputKVIterator.next()
+      }
+    }
+  }
+
+  // This is the starting point of this iterator.
+  initialize()
+
+  // Creates the iterator for the Hash Aggregation Map after we have populated
+  // contents of that map.
+  private[this] val aggregationBufferMapIterator = buffers.iterator()
+
+  private[this] var _mapIteratorHasNext = false
+
+  // Pre-load the first key-value pair from the map to make hasNext idempotent.
+  if (!sortBased) {
+    _mapIteratorHasNext = aggregationBufferMapIterator.next()
+    // If the map is empty, we just free it.
+    if (!_mapIteratorHasNext) {
+      buffers.free()
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Iterator's public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = {
+    (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
+  }
+
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      if (sortBased) {
+        sortBasedAggregationIterator.next()
+      } else {
+        // We did not fall back to the sort-based aggregation.
+        val result =
+          generateOutput(
+            aggregationBufferMapIterator.getKey,
+            aggregationBufferMapIterator.getValue)
+        // Pre-load next key-value pair form aggregationBufferMapIterator.
+        _mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+        if (!_mapIteratorHasNext) {
+          val resultCopy = result.copy()
+          buffers.free()
+          resultCopy
+        } else {
+          result
+        }
+      }
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+}
+
+object UnsafeHybridAggregationIterator {
+  // scalastyle:off
+  def createFromInputIterator(
+      groupingExprs: Seq[NamedExpression],
+      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+      nonCompleteAggregateAttributes: Seq[Attribute],
+      completeAggregateExpressions: Seq[AggregateExpression2],
+      completeAggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+      inputAttributes: Seq[Attribute],
+      inputIter: Iterator[InternalRow],
+      outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
+    new UnsafeHybridAggregationIterator(
+      groupingExprs.map(_.toAttribute),
+      inputAttributes,
+      AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter),
+      nonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      newMutableProjection,
+      outputsUnsafeRows)
+  }
+
+  def createFromKVIterator(
+      groupingKeyAttributes: Seq[Attribute],
+      valueAttributes: Seq[Attribute],
+      inputKVIterator: KVIterator[UnsafeRow, InternalRow],
+      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+      nonCompleteAggregateAttributes: Seq[Attribute],
+      completeAggregateExpressions: Seq[AggregateExpression2],
+      completeAggregateAttributes: Seq[Attribute],
+      initialInputBufferOffset: Int,
+      resultExpressions: Seq[NamedExpression],
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+      outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
+    new UnsafeHybridAggregationIterator(
+      groupingKeyAttributes,
+      valueAttributes,
+      inputKVIterator,
+      nonCompleteAggregateExpressions,
+      nonCompleteAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      initialInputBufferOffset,
+      resultExpressions,
+      newMutableProjection,
+      outputsUnsafeRows)
+  }
+  // scalastyle:on
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
deleted file mode 100644
index 98538c462bc89477b343556c6ba682b1efaa569c..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
-
-case class Aggregate2Sort(
-    requiredChildDistributionExpressions: Option[Seq[Expression]],
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression2],
-    aggregateAttributes: Seq[Attribute],
-    resultExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  override def canProcessUnsafeRows: Boolean = true
-
-  override def references: AttributeSet = {
-    val referencesInResults =
-      AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
-
-    AttributeSet(
-      groupingExpressions.flatMap(_.references) ++
-      aggregateExpressions.flatMap(_.references) ++
-      referencesInResults)
-  }
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
-      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    // TODO: We should not sort the input rows if they are just in reversed order.
-    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-  }
-
-  override def outputOrdering: Seq[SortOrder] = {
-    // It is possible that the child.outputOrdering starts with the required
-    // ordering expressions (e.g. we require [a] as the sort expression and the
-    // child's outputOrdering is [a, b]). We can only guarantee the output rows
-    // are sorted by values of groupingExpressions.
-    groupingExpressions.map(SortOrder(_, Ascending))
-  }
-
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    child.execute().mapPartitions { iter =>
-      if (aggregateExpressions.length == 0) {
-        new FinalSortAggregationIterator(
-          groupingExpressions,
-          Nil,
-          Nil,
-          resultExpressions,
-          newMutableProjection,
-          child.output,
-          iter)
-      } else {
-        val aggregationIterator: SortAggregationIterator = {
-          aggregateExpressions.map(_.mode).distinct.toList match {
-            case Partial :: Nil =>
-              new PartialSortAggregationIterator(
-                groupingExpressions,
-                aggregateExpressions,
-                newMutableProjection,
-                child.output,
-                iter)
-            case PartialMerge :: Nil =>
-              new PartialMergeSortAggregationIterator(
-                groupingExpressions,
-                aggregateExpressions,
-                newMutableProjection,
-                child.output,
-                iter)
-            case Final :: Nil =>
-              new FinalSortAggregationIterator(
-                groupingExpressions,
-                aggregateExpressions,
-                aggregateAttributes,
-                resultExpressions,
-                newMutableProjection,
-                child.output,
-                iter)
-            case other =>
-              sys.error(
-                s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
-                  s"modes $other in this operator.")
-          }
-        }
-
-        aggregationIterator
-      }
-    }
-  }
-}
-
-case class FinalAndCompleteAggregate2Sort(
-    previousGroupingExpressions: Seq[NamedExpression],
-    groupingExpressions: Seq[NamedExpression],
-    finalAggregateExpressions: Seq[AggregateExpression2],
-    finalAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    resultExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-  override def references: AttributeSet = {
-    val referencesInResults =
-      AttributeSet(resultExpressions.flatMap(_.references)) --
-        AttributeSet(finalAggregateExpressions) --
-        AttributeSet(completeAggregateExpressions)
-
-    AttributeSet(
-      groupingExpressions.flatMap(_.references) ++
-        finalAggregateExpressions.flatMap(_.references) ++
-        completeAggregateExpressions.flatMap(_.references) ++
-        referencesInResults)
-  }
-
-  override def requiredChildDistribution: List[Distribution] = {
-    if (groupingExpressions.isEmpty) {
-      AllTuples :: Nil
-    } else {
-      ClusteredDistribution(groupingExpressions) :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    child.execute().mapPartitions { iter =>
-
-      new FinalAndCompleteSortAggregationIterator(
-        previousGroupingExpressions.length,
-        groupingExpressions,
-        finalAggregateExpressions,
-        finalAggregateAttributes,
-        completeAggregateExpressions,
-        completeAggregateAttributes,
-        resultExpressions,
-        newMutableProjection,
-        child.output,
-        iter)
-    }
-  }
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
deleted file mode 100644
index 2ca0cb82c1aabcfce8673624d908e7d04384f87d..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ /dev/null
@@ -1,664 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.types.NullType
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * An iterator used to evaluate aggregate functions. It assumes that input rows
- * are already grouped by values of `groupingExpressions`.
- */
-private[sql] abstract class SortAggregationIterator(
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression2],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends Iterator[InternalRow] {
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Static fields for this iterator
-  ///////////////////////////////////////////////////////////////////////////
-
-  protected val aggregateFunctions: Array[AggregateFunction2] = {
-    var mutableBufferOffset = 0
-    var inputBufferOffset: Int = initialInputBufferOffset
-    val functions = new Array[AggregateFunction2](aggregateExpressions.length)
-    var i = 0
-    while (i < aggregateExpressions.length) {
-      val func = aggregateExpressions(i).aggregateFunction
-      val funcWithBoundReferences = aggregateExpressions(i).mode match {
-        case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
-          // We need to create BoundReferences if the function is not an
-          // AlgebraicAggregate (it does not support code-gen) and the mode of
-          // this function is Partial or Complete because we will call eval of this
-          // function's children in the update method of this aggregate function.
-          // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
-        case _ =>
-          // We only need to set inputBufferOffset for aggregate functions with mode
-          // PartialMerge and Final.
-          func.inputBufferOffset = inputBufferOffset
-          inputBufferOffset += func.bufferSchema.length
-          func
-      }
-      // Set mutableBufferOffset for this function. It is important that setting
-      // mutableBufferOffset happens after all potential bindReference operations
-      // because bindReference will create a new instance of the function.
-      funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
-      mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
-      functions(i) = funcWithBoundReferences
-      i += 1
-    }
-    functions
-  }
-
-  // Positions of those non-algebraic aggregate functions in aggregateFunctions.
-  // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
-  // func2 and func3 are non-algebraic aggregate functions.
-  // nonAlgebraicAggregateFunctionPositions will be [1, 2].
-  protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
-    val positions = new ArrayBuffer[Int]()
-    var i = 0
-    while (i < aggregateFunctions.length) {
-      aggregateFunctions(i) match {
-        case agg: AlgebraicAggregate =>
-        case _ => positions += i
-      }
-      i += 1
-    }
-    positions.toArray
-  }
-
-  // All non-algebraic aggregate functions.
-  protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
-    nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
-
-  // This is used to project expressions for the grouping expressions.
-  protected val groupGenerator =
-    newMutableProjection(groupingExpressions, inputAttributes)()
-
-  // The underlying buffer shared by all aggregate functions.
-  protected val buffer: MutableRow = {
-    // The number of elements of the underlying buffer of this operator.
-    // All aggregate functions are sharing this underlying buffer and they find their
-    // buffer values through bufferOffset.
-    // var size = 0
-    // var i = 0
-    // while (i < aggregateFunctions.length) {
-    //  size += aggregateFunctions(i).bufferSchema.length
-    //  i += 1
-    // }
-    new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
-  }
-
-  protected val joinedRow = new JoinedRow
-
-  // This projection is used to initialize buffer values for all AlgebraicAggregates.
-  protected val algebraicInitialProjection = {
-    val initExpressions = aggregateFunctions.flatMap {
-      case ae: AlgebraicAggregate => ae.initialValues
-      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-    }
-
-    newMutableProjection(initExpressions, Nil)().target(buffer)
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Mutable states
-  ///////////////////////////////////////////////////////////////////////////
-
-  // The partition key of the current partition.
-  protected var currentGroupingKey: InternalRow = _
-  // The partition key of next partition.
-  protected var nextGroupingKey: InternalRow = _
-  // The first row of next partition.
-  protected var firstRowInNextGroup: InternalRow = _
-  // Indicates if we has new group of rows to process.
-  protected var hasNewGroup: Boolean = true
-
-  /** Initializes buffer values for all aggregate functions. */
-  protected def initializeBuffer(): Unit = {
-    algebraicInitialProjection(EmptyRow)
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      nonAlgebraicAggregateFunctions(i).initialize(buffer)
-      i += 1
-    }
-  }
-
-  protected def initialize(): Unit = {
-    if (inputIter.hasNext) {
-      initializeBuffer()
-      val currentRow = inputIter.next().copy()
-      // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
-      // we are making a copy at here.
-      nextGroupingKey = groupGenerator(currentRow).copy()
-      firstRowInNextGroup = currentRow
-    } else {
-      // This iter is an empty one.
-      hasNewGroup = false
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Private methods
-  ///////////////////////////////////////////////////////////////////////////
-
-  /** Processes rows in the current group. It will stop when it find a new group. */
-  private def processCurrentGroup(): Unit = {
-    currentGroupingKey = nextGroupingKey
-    // Now, we will start to find all rows belonging to this group.
-    // We create a variable to track if we see the next group.
-    var findNextPartition = false
-    // firstRowInNextGroup is the first row of this group. We first process it.
-    processRow(firstRowInNextGroup)
-    // The search will stop when we see the next group or there is no
-    // input row left in the iter.
-    while (inputIter.hasNext && !findNextPartition) {
-      val currentRow = inputIter.next()
-      // Get the grouping key based on the grouping expressions.
-      // For the below compare method, we do not need to make a copy of groupingKey.
-      val groupingKey = groupGenerator(currentRow)
-      // Check if the current row belongs the current input row.
-      if (currentGroupingKey == groupingKey) {
-        processRow(currentRow)
-      } else {
-        // We find a new group.
-        findNextPartition = true
-        nextGroupingKey = groupingKey.copy()
-        firstRowInNextGroup = currentRow.copy()
-      }
-    }
-    // We have not seen a new group. It means that there is no new row in the input
-    // iter. The current group is the last group of the iter.
-    if (!findNextPartition) {
-      hasNewGroup = false
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Public methods
-  ///////////////////////////////////////////////////////////////////////////
-
-  override final def hasNext: Boolean = hasNewGroup
-
-  override final def next(): InternalRow = {
-    if (hasNext) {
-      // Process the current group.
-      processCurrentGroup()
-      // Generate output row for the current group.
-      val outputRow = generateOutput()
-      // Initilize buffer values for the next group.
-      initializeBuffer()
-
-      outputRow
-    } else {
-      // no more result
-      throw new NoSuchElementException
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods that need to be implemented
-  ///////////////////////////////////////////////////////////////////////////
-
-  /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */
-  protected def initialInputBufferOffset: Int
-
-  /** The function used to process an input row. */
-  protected def processRow(row: InternalRow): Unit
-
-  /** The function used to generate the result row. */
-  protected def generateOutput(): InternalRow
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Initialize this iterator
-  ///////////////////////////////////////////////////////////////////////////
-
-  initialize()
-}
-
-/**
- * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialSortAggregationIterator(
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression2],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends SortAggregationIterator(
-    groupingExpressions,
-    aggregateExpressions,
-    newMutableProjection,
-    inputAttributes,
-    inputIter) {
-
-  // This projection is used to update buffer values for all AlgebraicAggregates.
-  private val algebraicUpdateProjection = {
-    val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
-    val updateExpressions = aggregateFunctions.flatMap {
-      case ae: AlgebraicAggregate => ae.updateExpressions
-      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-    }
-    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
-  }
-
-  override protected def initialInputBufferOffset: Int = 0
-
-  override protected def processRow(row: InternalRow): Unit = {
-    // Process all algebraic aggregate functions.
-    algebraicUpdateProjection(joinedRow(buffer, row))
-    // Process all non-algebraic aggregate functions.
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      nonAlgebraicAggregateFunctions(i).update(buffer, row)
-      i += 1
-    }
-  }
-
-  override protected def generateOutput(): InternalRow = {
-    // We just output the grouping expressions and the underlying buffer.
-    joinedRow(currentGroupingKey, buffer).copy()
-  }
-}
-
-/**
- * An iterator used to do partial merge aggregations (for those aggregate functions with mode
- * PartialMerge). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialMergeSortAggregationIterator(
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression2],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends SortAggregationIterator(
-    groupingExpressions,
-    aggregateExpressions,
-    newMutableProjection,
-    inputAttributes,
-    inputIter) {
-
-  // This projection is used to merge buffer values for all AlgebraicAggregates.
-  private val algebraicMergeProjection = {
-    val mergeInputSchema =
-      aggregateFunctions.flatMap(_.bufferAttributes) ++
-        groupingExpressions.map(_.toAttribute) ++
-        aggregateFunctions.flatMap(_.cloneBufferAttributes)
-    val mergeExpressions = aggregateFunctions.flatMap {
-      case ae: AlgebraicAggregate => ae.mergeExpressions
-      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-    }
-
-    newMutableProjection(mergeExpressions, mergeInputSchema)()
-  }
-
-  override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
-  override protected def processRow(row: InternalRow): Unit = {
-    // Process all algebraic aggregate functions.
-    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
-    // Process all non-algebraic aggregate functions.
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
-      i += 1
-    }
-  }
-
-  override protected def generateOutput(): InternalRow = {
-    // We output grouping expressions and aggregation buffers.
-    joinedRow(currentGroupingKey, buffer).copy()
-  }
-}
-
-/**
- * An iterator used to do final aggregations (for those aggregate functions with mode
- * Final). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalSortAggregationIterator(
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression2],
-    aggregateAttributes: Seq[Attribute],
-    resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends SortAggregationIterator(
-    groupingExpressions,
-    aggregateExpressions,
-    newMutableProjection,
-    inputAttributes,
-    inputIter) {
-
-  // The result of aggregate functions.
-  private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
-
-  // The projection used to generate the output rows of this operator.
-  // This is only used when we are generating final results of aggregate functions.
-  private val resultProjection =
-    newMutableProjection(
-      resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
-
-  // This projection is used to merge buffer values for all AlgebraicAggregates.
-  private val algebraicMergeProjection = {
-    val mergeInputSchema =
-      aggregateFunctions.flatMap(_.bufferAttributes) ++
-        groupingExpressions.map(_.toAttribute) ++
-        aggregateFunctions.flatMap(_.cloneBufferAttributes)
-    val mergeExpressions = aggregateFunctions.flatMap {
-      case ae: AlgebraicAggregate => ae.mergeExpressions
-      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-    }
-
-    newMutableProjection(mergeExpressions, mergeInputSchema)()
-  }
-
-  // This projection is used to evaluate all AlgebraicAggregates.
-  private val algebraicEvalProjection = {
-    val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
-    val evalExpressions = aggregateFunctions.map {
-      case ae: AlgebraicAggregate => ae.evaluateExpression
-      case agg: AggregateFunction2 => NoOp
-    }
-
-    newMutableProjection(evalExpressions, bufferSchemata)()
-  }
-
-  override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
-  override def initialize(): Unit = {
-    if (inputIter.hasNext) {
-      initializeBuffer()
-      val currentRow = inputIter.next().copy()
-      // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
-      // we are making a copy at here.
-      nextGroupingKey = groupGenerator(currentRow).copy()
-      firstRowInNextGroup = currentRow
-    } else {
-      if (groupingExpressions.isEmpty) {
-        // If there is no grouping expression, we need to generate a single row as the output.
-        initializeBuffer()
-        // Right now, the buffer only contains initial buffer values. Because
-        // merging two buffers with initial values will generate a row that
-        // still store initial values. We set the currentRow as the copy of the current buffer.
-        // Because input aggregation buffer has initialInputBufferOffset extra values at the
-        // beginning, we create a dummy row for this part.
-        val currentRow =
-          joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
-        nextGroupingKey = groupGenerator(currentRow).copy()
-        firstRowInNextGroup = currentRow
-      } else {
-        // This iter is an empty one.
-        hasNewGroup = false
-      }
-    }
-  }
-
-  override protected def processRow(row: InternalRow): Unit = {
-    // Process all algebraic aggregate functions.
-    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
-    // Process all non-algebraic aggregate functions.
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
-      i += 1
-    }
-  }
-
-  override protected def generateOutput(): InternalRow = {
-    // Generate results for all algebraic aggregate functions.
-    algebraicEvalProjection.target(aggregateResult)(buffer)
-    // Generate results for all non-algebraic aggregate functions.
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      aggregateResult.update(
-        nonAlgebraicAggregateFunctionPositions(i),
-        nonAlgebraicAggregateFunctions(i).eval(buffer))
-      i += 1
-    }
-    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
-  }
-}
-
-/**
- * An iterator used to do both final aggregations (for those aggregate functions with mode
- * Final) and complete aggregations (for those aggregate functions with mode Complete).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
- * col1 to colM are columns used by aggregate functions with Complete mode.
- * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
- * Final mode.
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBuffer(N+M)|
- * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
- * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
- * Complete.
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalAndCompleteSortAggregationIterator(
-    override protected val initialInputBufferOffset: Int,
-    groupingExpressions: Seq[NamedExpression],
-    finalAggregateExpressions: Seq[AggregateExpression2],
-    finalAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends SortAggregationIterator(
-    groupingExpressions,
-    // TODO: document the ordering
-    finalAggregateExpressions ++ completeAggregateExpressions,
-    newMutableProjection,
-    inputAttributes,
-    inputIter) {
-
-  // The result of aggregate functions.
-  private val aggregateResult: MutableRow =
-    new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
-
-  // The projection used to generate the output rows of this operator.
-  // This is only used when we are generating final results of aggregate functions.
-  private val resultProjection = {
-    val inputSchema =
-      groupingExpressions.map(_.toAttribute) ++
-        finalAggregateAttributes ++
-        completeAggregateAttributes
-    newMutableProjection(resultExpressions, inputSchema)()
-  }
-
-  // All aggregate functions with mode Final.
-  private val finalAggregateFunctions: Array[AggregateFunction2] = {
-    val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
-    var i = 0
-    while (i < finalAggregateExpressions.length) {
-      functions(i) = aggregateFunctions(i)
-      i += 1
-    }
-    functions
-  }
-
-  // All non-algebraic aggregate functions with mode Final.
-  private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
-    finalAggregateFunctions.collect {
-      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
-    }
-
-  // All aggregate functions with mode Complete.
-  private val completeAggregateFunctions: Array[AggregateFunction2] = {
-    val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
-    var i = 0
-    while (i < completeAggregateExpressions.length) {
-      functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
-      i += 1
-    }
-    functions
-  }
-
-  // All non-algebraic aggregate functions with mode Complete.
-  private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
-    completeAggregateFunctions.collect {
-      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
-    }
-
-  // This projection is used to merge buffer values for all AlgebraicAggregates with mode
-  // Final.
-  private val finalAlgebraicMergeProjection = {
-    // The first initialInputBufferOffset values of the input aggregation buffer is
-    // for grouping expressions and distinct columns.
-    val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset)
-
-    val completeOffsetExpressions =
-      Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
-    val mergeInputSchema =
-      finalAggregateFunctions.flatMap(_.bufferAttributes) ++
-        completeAggregateFunctions.flatMap(_.bufferAttributes) ++
-        groupingAttributesAndDistinctColumns ++
-        finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
-    val mergeExpressions =
-      finalAggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.mergeExpressions
-        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-      } ++ completeOffsetExpressions
-    newMutableProjection(mergeExpressions, mergeInputSchema)()
-  }
-
-  // This projection is used to update buffer values for all AlgebraicAggregates with mode
-  // Complete.
-  private val completeAlgebraicUpdateProjection = {
-    // We do not touch buffer values of aggregate functions with the Final mode.
-    val finalOffsetExpressions =
-      Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
-    val bufferSchema =
-      finalAggregateFunctions.flatMap(_.bufferAttributes) ++
-        completeAggregateFunctions.flatMap(_.bufferAttributes)
-    val updateExpressions =
-      finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.updateExpressions
-        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
-      }
-    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
-  }
-
-  // This projection is used to evaluate all AlgebraicAggregates.
-  private val algebraicEvalProjection = {
-    val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
-    val evalExpressions = aggregateFunctions.map {
-      case ae: AlgebraicAggregate => ae.evaluateExpression
-      case agg: AggregateFunction2 => NoOp
-    }
-
-    newMutableProjection(evalExpressions, bufferSchemata)()
-  }
-
-  override def initialize(): Unit = {
-    if (inputIter.hasNext) {
-      initializeBuffer()
-      val currentRow = inputIter.next().copy()
-      // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
-      // we are making a copy at here.
-      nextGroupingKey = groupGenerator(currentRow).copy()
-      firstRowInNextGroup = currentRow
-    } else {
-      if (groupingExpressions.isEmpty) {
-        // If there is no grouping expression, we need to generate a single row as the output.
-        initializeBuffer()
-        // Right now, the buffer only contains initial buffer values. Because
-        // merging two buffers with initial values will generate a row that
-        // still store initial values. We set the currentRow as the copy of the current buffer.
-        // Because input aggregation buffer has initialInputBufferOffset extra values at the
-        // beginning, we create a dummy row for this part.
-        val currentRow =
-          joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
-        nextGroupingKey = groupGenerator(currentRow).copy()
-        firstRowInNextGroup = currentRow
-      } else {
-        // This iter is an empty one.
-        hasNewGroup = false
-      }
-    }
-  }
-
-  override protected def processRow(row: InternalRow): Unit = {
-    val input = joinedRow(buffer, row)
-    // For all aggregate functions with mode Complete, update buffers.
-    completeAlgebraicUpdateProjection(input)
-    var i = 0
-    while (i < completeNonAlgebraicAggregateFunctions.length) {
-      completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
-      i += 1
-    }
-
-    // For all aggregate functions with mode Final, merge buffers.
-    finalAlgebraicMergeProjection.target(buffer)(input)
-    i = 0
-    while (i < finalNonAlgebraicAggregateFunctions.length) {
-      finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
-      i += 1
-    }
-  }
-
-  override protected def generateOutput(): InternalRow = {
-    // Generate results for all algebraic aggregate functions.
-    algebraicEvalProjection.target(aggregateResult)(buffer)
-    // Generate results for all non-algebraic aggregate functions.
-    var i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      aggregateResult.update(
-        nonAlgebraicAggregateFunctionPositions(i),
-        nonAlgebraicAggregateFunctions(i).eval(buffer))
-      i += 1
-    }
-
-    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index cc54319171bdb81e7e7c55fd7807dd128cffd9f8..5fafc916bfa0b9a3369aaefe81178ec1931d14b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -24,7 +24,154 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti
 import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
-import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
+import org.apache.spark.sql.types._
+
+/**
+ * A helper trait used to create specialized setter and getter for types supported by
+ * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer.
+ * (see UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema).
+ */
+sealed trait BufferSetterGetterUtils {
+
+  def createGetters(schema: StructType): Array[(InternalRow, Int) => Any] = {
+    val dataTypes = schema.fields.map(_.dataType)
+    val getters = new Array[(InternalRow, Int) => Any](dataTypes.length)
+
+    var i = 0
+    while (i < getters.length) {
+      getters(i) = dataTypes(i) match {
+        case BooleanType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
+
+        case ByteType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getByte(ordinal)
+
+        case ShortType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getShort(ordinal)
+
+        case IntegerType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+        case LongType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
+        case FloatType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getFloat(ordinal)
+
+        case DoubleType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getDouble(ordinal)
+
+        case dt: DecimalType =>
+          val precision = dt.precision
+          val scale = dt.scale
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
+
+        case other =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
+      }
+
+      i += 1
+    }
+
+    getters
+  }
+
+  def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = {
+    val dataTypes = schema.fields.map(_.dataType)
+    val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length)
+
+    var i = 0
+    while (i < setters.length) {
+      setters(i) = dataTypes(i) match {
+        case b: BooleanType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setBoolean(ordinal, value.asInstanceOf[Boolean])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case ByteType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setByte(ordinal, value.asInstanceOf[Byte])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case ShortType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setShort(ordinal, value.asInstanceOf[Short])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case IntegerType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setInt(ordinal, value.asInstanceOf[Int])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case LongType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setLong(ordinal, value.asInstanceOf[Long])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case FloatType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setFloat(ordinal, value.asInstanceOf[Float])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case DoubleType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setDouble(ordinal, value.asInstanceOf[Double])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case dt: DecimalType =>
+          val precision = dt.precision
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case other =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.update(ordinal, value)
+            } else {
+              row.setNullAt(ordinal)
+            }
+      }
+
+      i += 1
+    }
+
+    setters
+  }
+}
 
 /**
  * A Mutable [[Row]] representing an mutable aggregation buffer.
@@ -35,7 +182,7 @@ private[sql] class MutableAggregationBufferImpl (
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
     var underlyingBuffer: MutableRow)
-  extends MutableAggregationBuffer {
+  extends MutableAggregationBuffer with BufferSetterGetterUtils {
 
   private[this] val offsets: Array[Int] = {
     val newOffsets = new Array[Int](length)
@@ -47,6 +194,10 @@ private[sql] class MutableAggregationBufferImpl (
     newOffsets
   }
 
+  private[this] val bufferValueGetters = createGetters(schema)
+
+  private[this] val bufferValueSetters = createSetters(schema)
+
   override def length: Int = toCatalystConverters.length
 
   override def get(i: Int): Any = {
@@ -54,7 +205,7 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
+    toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
   }
 
   def update(i: Int, value: Any): Unit = {
@@ -62,7 +213,15 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not update ${i}th value in this buffer because it only has $length values.")
     }
-    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+
+    bufferValueSetters(i)(underlyingBuffer, offsets(i), toCatalystConverters(i)(value))
+  }
+
+  // Because get method call specialized getter based on the schema, we cannot use the
+  // default implementation of the isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt of the underlyingBuffer.
+  override def isNullAt(i: Int): Boolean = {
+    underlyingBuffer.isNullAt(offsets(i))
   }
 
   override def copy(): MutableAggregationBufferImpl = {
@@ -84,7 +243,7 @@ private[sql] class InputAggregationBuffer private[sql] (
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
     var underlyingInputBuffer: InternalRow)
-  extends Row {
+  extends Row with BufferSetterGetterUtils {
 
   private[this] val offsets: Array[Int] = {
     val newOffsets = new Array[Int](length)
@@ -96,6 +255,10 @@ private[sql] class InputAggregationBuffer private[sql] (
     newOffsets
   }
 
+  private[this] val bufferValueGetters = createGetters(schema)
+
+  def getBufferOffset: Int = bufferOffset
+
   override def length: Int = toCatalystConverters.length
 
   override def get(i: Int): Any = {
@@ -103,8 +266,14 @@ private[sql] class InputAggregationBuffer private[sql] (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    // TODO: Use buffer schema to avoid using generic getter.
-    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+    toScalaConverters(i)(bufferValueGetters(i)(underlyingInputBuffer, offsets(i)))
+  }
+
+  // Because get method call specialized getter based on the schema, we cannot use the
+  // default implementation of the isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt of the underlyingInputBuffer.
+  override def isNullAt(i: Int): Boolean = {
+    underlyingInputBuffer.isNullAt(offsets(i))
   }
 
   override def copy(): InputAggregationBuffer = {
@@ -147,7 +316,7 @@ private[sql] case class ScalaUDAF(
 
   override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
 
-  val childrenSchema: StructType = {
+  private[this] val childrenSchema: StructType = {
     val inputFields = children.zipWithIndex.map {
       case (child, index) =>
         StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
@@ -155,7 +324,7 @@ private[sql] case class ScalaUDAF(
     StructType(inputFields)
   }
 
-  lazy val inputProjection = {
+  private lazy val inputProjection = {
     val inputAttributes = childrenSchema.toAttributes
     log.debug(
       s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
@@ -168,40 +337,68 @@ private[sql] case class ScalaUDAF(
     }
   }
 
-  val inputToScalaConverters: Any => Any =
+  private[this] val inputToScalaConverters: Any => Any =
     CatalystTypeConverters.createToScalaConverter(childrenSchema)
 
-  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = {
+    bufferSchema.fields.map { field =>
+      CatalystTypeConverters.createToCatalystConverter(field.dataType)
+    }
   }
 
-  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  private[this] val bufferValuesToScalaConverters: Array[Any => Any] = {
+    bufferSchema.fields.map { field =>
+      CatalystTypeConverters.createToScalaConverter(field.dataType)
+    }
   }
 
-  lazy val inputAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      inputBufferOffset,
-      null)
-
-  lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
-    new MutableAggregationBufferImpl(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      mutableBufferOffset,
-      null)
+  // This buffer is only used at executor side.
+  private[this] var inputAggregateBuffer: InputAggregationBuffer = null
+
+  // This buffer is only used at executor side.
+  private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+
+  // This buffer is only used at executor side.
+  private[this] var evalAggregateBuffer: InputAggregationBuffer = null
+
+  /**
+   * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
+   * `inputAggregateBuffer` based on this new inputBufferOffset.
+   */
+  override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+    super.withNewInputBufferOffset(newInputBufferOffset)
+    // inputBufferOffset has been updated.
+    inputAggregateBuffer =
+      new InputAggregationBuffer(
+        bufferSchema,
+        bufferValuesToCatalystConverters,
+        bufferValuesToScalaConverters,
+        inputBufferOffset,
+        null)
+  }
 
-  lazy val evalAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      mutableBufferOffset,
-      null)
+  /**
+   * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
+   * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
+   */
+  override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+    super.withNewMutableBufferOffset(newMutableBufferOffset)
+    // mutableBufferOffset has been updated.
+    mutableAggregateBuffer =
+      new MutableAggregationBufferImpl(
+        bufferSchema,
+        bufferValuesToCatalystConverters,
+        bufferValuesToScalaConverters,
+        mutableBufferOffset,
+        null)
+    evalAggregateBuffer =
+      new InputAggregationBuffer(
+        bufferSchema,
+        bufferValuesToCatalystConverters,
+        bufferValuesToScalaConverters,
+        mutableBufferOffset,
+        null)
+  }
 
   override def initialize(buffer: MutableRow): Unit = {
     mutableAggregateBuffer.underlyingBuffer = buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 03635baae4a5fcfa2d9d2d2ac4d95d9ce40da12a..960be08f84d941c9e39ce180c9a1c6e06110a275 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,13 +17,9 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -52,13 +48,16 @@ object Utils {
       agg.aggregateFunction.bufferAttributes
     }
     val partialAggregate =
-      Aggregate2Sort(
-        None: Option[Seq[Expression]],
-        namedGroupingExpressions.map(_._2),
-        partialAggregateExpressions,
-        partialAggregateAttributes,
-        namedGroupingAttributes ++ partialAggregateAttributes,
-        child)
+      Aggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = namedGroupingExpressions.map(_._2),
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        nonCompleteAggregateAttributes = partialAggregateAttributes,
+        completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes,
+        child = child)
 
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -78,13 +77,17 @@ object Utils {
           }.getOrElse(expression)
       }.asInstanceOf[NamedExpression]
     }
-    val finalAggregate = Aggregate2Sort(
-      Some(namedGroupingAttributes),
-      namedGroupingAttributes,
-      finalAggregateExpressions,
-      finalAggregateAttributes,
-      rewrittenResultExpressions,
-      partialAggregate)
+    val finalAggregate =
+      Aggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        nonCompleteAggregateAttributes = finalAggregateAttributes,
+        completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
+        initialInputBufferOffset = namedGroupingAttributes.length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialAggregate)
 
     finalAggregate :: Nil
   }
@@ -133,14 +136,21 @@ object Utils {
     val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
       agg.aggregateFunction.bufferAttributes
     }
+    val partialAggregateGroupingExpressions =
+      (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
+    val partialAggregateResult =
+      namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes
     val partialAggregate =
-      Aggregate2Sort(
-        None: Option[Seq[Expression]],
-        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
-        partialAggregateExpressions,
-        partialAggregateAttributes,
-        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
-        child)
+      Aggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = partialAggregateGroupingExpressions,
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        nonCompleteAggregateAttributes = partialAggregateAttributes,
+        completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialAggregateResult,
+        child = child)
 
     // 2. Create an Aggregate Operator for partial merge aggregations.
     val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
@@ -151,14 +161,19 @@ object Utils {
       partialMergeAggregateExpressions.flatMap { agg =>
         agg.aggregateFunction.bufferAttributes
       }
+    val partialMergeAggregateResult =
+      namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes
     val partialMergeAggregate =
-      Aggregate2Sort(
-        Some(namedGroupingAttributes),
-        namedGroupingAttributes ++ distinctColumnAttributes,
-        partialMergeAggregateExpressions,
-        partialMergeAggregateAttributes,
-        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
-        partialAggregate)
+      Aggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+        nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+        nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+        completeAggregateExpressions = Nil,
+        completeAggregateAttributes = Nil,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = partialMergeAggregateResult,
+        child = partialAggregate)
 
     // 3. Create an Aggregate Operator for partial merge aggregations.
     val finalAggregateExpressions = functionsWithoutDistinct.map {
@@ -199,15 +214,17 @@ object Utils {
           }.getOrElse(expression)
       }.asInstanceOf[NamedExpression]
     }
-    val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
-      namedGroupingAttributes ++ distinctColumnAttributes,
-      namedGroupingAttributes,
-      finalAggregateExpressions,
-      finalAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      rewrittenResultExpressions,
-      partialMergeAggregate)
+    val finalAndCompleteAggregate =
+      Aggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        nonCompleteAggregateAttributes = finalAggregateAttributes,
+        completeAggregateExpressions = completeAggregateExpressions,
+        completeAggregateAttributes = completeAggregateAttributes,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialMergeAggregate)
 
     finalAndCompleteAggregate :: Nil
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2294a670c735ff48e85a3528e9d18280722cf427..5a1b000e89875235d8f4f3c0e7a6646c6a446041 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -220,7 +220,6 @@ case class TakeOrderedAndProject(
   override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
-
 /**
  * :: DeveloperApi ::
  * Return a new RDD that has exactly `numPartitions` partitions.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 51fe9d9d98bf357d13517a88555484918f4129c3..bbadc202a4f0620a8f3a20504dc7abec55b98638 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.scalatest.BeforeAndAfterAll
-
 import java.sql.Timestamp
 
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -273,7 +273,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       var hasGeneratedAgg = false
       df.queryExecution.executedPlan.foreach {
         case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
-        case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
+        case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
         case _ =>
       }
       if (!hasGeneratedAgg) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 54f82f89ed18af78680e8b30a0df96375b48701b..7978ed57a937e61ddc5640c6617b24935eb50407 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -138,7 +138,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
           s"Expected $expectedSerializerClass as the serializer of Exchange. " +
           s"However, the serializer was not set."
         val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
-        assert(serializer.getClass === expectedSerializerClass)
+        val isExpectedSerializer =
+          serializer.getClass == expectedSerializerClass ||
+            serializer.getClass == classOf[UnsafeRowSerializer]
+        val wrongSerializerErrorMessage =
+          s"Expected ${expectedSerializerClass.getCanonicalName} or " +
+            s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " +
+            s"${serializer.getClass.getCanonicalName} is used."
+        assert(isExpectedSerializer, wrongSerializerErrorMessage)
       case _ => // Ignore other nodes.
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 0375eb79add95d637ff0e2f99a60c02dcaaed358..6f0db27775e4dd5dd82f77032ebe36632ed2c1e0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,15 +17,15 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
+import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
 import org.scalatest.BeforeAndAfterAll
 import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 
-class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
+abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
 
   override val sqlContext = TestHive
   import sqlContext.implicits._
@@ -34,7 +34,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
 
   override def beforeAll(): Unit = {
     originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
-    sqlContext.sql("set spark.sql.useAggregate2=true")
+    sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true")
     val data1 = Seq[(Integer, Integer)](
       (1, 10),
       (null, -60),
@@ -81,7 +81,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
     sqlContext.sql("DROP TABLE IF EXISTS agg1")
     sqlContext.sql("DROP TABLE IF EXISTS agg2")
     sqlContext.dropTempTable("emptyTable")
-    sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
+    sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString)
   }
 
   test("empty table") {
@@ -454,54 +454,86 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
   }
 
   test("error handling") {
-    sqlContext.sql(s"set spark.sql.useAggregate2=false")
-    var errorMessage = intercept[AnalysisException] {
-      sqlContext.sql(
-        """
-          |SELECT
-          |  key,
-          |  sum(value + 1.5 * key),
-          |  mydoublesum(value),
-          |  mydoubleavg(value)
-          |FROM agg1
-          |GROUP BY key
-        """.stripMargin).collect()
-    }.getMessage
-    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+    withSQLConf("spark.sql.useAggregate2" -> "false") {
+      val errorMessage = intercept[AnalysisException] {
+        sqlContext.sql(
+          """
+            |SELECT
+            |  key,
+            |  sum(value + 1.5 * key),
+            |  mydoublesum(value),
+            |  mydoubleavg(value)
+            |FROM agg1
+            |GROUP BY key
+          """.stripMargin).collect()
+      }.getMessage
+      assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+    }
 
     // TODO: once we support Hive UDAF in the new interface,
     // we can remove the following two tests.
-    sqlContext.sql(s"set spark.sql.useAggregate2=true")
-    errorMessage = intercept[AnalysisException] {
-      sqlContext.sql(
+    withSQLConf("spark.sql.useAggregate2" -> "true") {
+      val errorMessage = intercept[AnalysisException] {
+        sqlContext.sql(
+          """
+            |SELECT
+            |  key,
+            |  mydoublesum(value + 1.5 * key),
+            |  stddev_samp(value)
+            |FROM agg1
+            |GROUP BY key
+          """.stripMargin).collect()
+      }.getMessage
+      assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+      // This will fall back to the old aggregate
+      val newAggregateOperators = sqlContext.sql(
         """
           |SELECT
           |  key,
-          |  mydoublesum(value + 1.5 * key),
+          |  sum(value + 1.5 * key),
           |  stddev_samp(value)
           |FROM agg1
           |GROUP BY key
-        """.stripMargin).collect()
-    }.getMessage
-    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
-
-    // This will fall back to the old aggregate
-    val newAggregateOperators = sqlContext.sql(
-      """
-        |SELECT
-        |  key,
-        |  sum(value + 1.5 * key),
-        |  stddev_samp(value)
-        |FROM agg1
-        |GROUP BY key
-      """.stripMargin).queryExecution.executedPlan.collect {
-      case agg: Aggregate2Sort => agg
+        """.stripMargin).queryExecution.executedPlan.collect {
+        case agg: aggregate.Aggregate => agg
+      }
+      val message =
+        "We should fallback to the old aggregation code path if " +
+          "there is any aggregate function that cannot be converted to the new interface."
+      assert(newAggregateOperators.isEmpty, message)
     }
-    val message =
-      "We should fallback to the old aggregation code path if there is any aggregate function " +
-        "that cannot be converted to the new interface."
-    assert(newAggregateOperators.isEmpty, message)
+  }
+}
+
+class SortBasedAggregationQuerySuite extends AggregationQuerySuite {
 
-    sqlContext.sql(s"set spark.sql.useAggregate2=true")
+  var originalUnsafeEnabled: Boolean = _
+
+  override def beforeAll(): Unit = {
+    originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false")
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+  }
+}
+
+class TungstenAggregationQuerySuite extends AggregationQuerySuite {
+
+  var originalUnsafeEnabled: Boolean = _
+
+  override def beforeAll(): Unit = {
+    originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
   }
 }