diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 88fb516e64aaf379141273542013f882c98df3eb..a73024d6adba10328df7962dedacbc16287c0868 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
-  // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
-  // to the default data type of the NumericType.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is included here
+  // so that the expression is resolved for cases like `select avg(null)`.
+  // Once we remove the old aggregate functions, we can use our analyzer to cast NullType to the
+  // default data type of the NumericType, and NullType will no longer be needed here.
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
 
   private val resultType = child.dataType match {
@@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is included here
+  // so that the expression is resolved for cases like `select sum(null)`.
+  // Once we remove the old aggregate functions, we can use our analyzer to cast NullType to the
+  // default data type of the NumericType, and NullType will no longer be needed here.
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
 
   private val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
+    // TODO: Remove this line once we remove the NullType from inputTypes.
+    case NullType => IntegerType
     case _ => child.dataType
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a730ffbb217c01b458cdafadb2af9421f4b746bf..c5aaebe6732252b68e934fd6682df55b0fc40ed3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             // aggregate function to the corresponding attribute of the function.
             val aggregateFunctionMap = aggregateExpressions.map { agg =>
               val aggregateFunction = agg.aggregateFunction
+              val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
               (aggregateFunction, agg.isDistinct) ->
-                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+                (aggregateFunction -> attribute)
             }.toMap
 
             val (functionsWithDistinct, functionsWithoutDistinct) =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 16498da080c88128b5339286e4780fc425f0f9a2..39f8f992a9f009c915d2941167ddf7c6365c2450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.io._
 import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
@@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+    // When `out` is backed by ChainedBufferOutputStream, we will get an
+    // UnsupportedOperationException when we call dOut.writeInt because it internally calls
+    // ChainedBufferOutputStream's write(b: Int), which is not supported.
+    // To work around this issue, we use an array to store the bytes of the int value.
+    // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
+    // run SparkSqlSerializer2SortMergeShuffleSuite.
+    private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
     private[this] val dOut: DataOutputStream = new DataOutputStream(out)
 
     override def writeValue[T: ClassTag](value: T): SerializationStream = {
       val row = value.asInstanceOf[UnsafeRow]
-      dOut.writeInt(row.getSizeInBytes)
+      val size = row.getSizeInBytes
+      // This part is based on DataOutputStream's writeInt.
+      // It is equivalent to dOut.writeInt(row.getSizeInBytes).
+      intBuffer(0) = ((size >>> 24) & 0xFF).toByte
+      intBuffer(1) = ((size >>> 16) & 0xFF).toByte
+      intBuffer(2) = ((size >>> 8) & 0xFF).toByte
+      intBuffer(3) = ((size >>> 0) & 0xFF).toByte
+      dOut.write(intBuffer, 0, 4)
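+      // DataOutputStream.writeInt writes the high-order byte first (big-endian), which is
+      // the byte order produced above.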
+
       row.writeToStream(out, writeBuffer)
       this
     }
@@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
+      intBuffer = null
       dOut.writeInt(EOF)
       dOut.close()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
deleted file mode 100644
index cf568dc048674f2c01ecf9c849df46c97fcb17c2..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types
- * of the grouping expressions and aggregate functions, it determines if it uses
- * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to
- * process input rows.
- */
-case class Aggregate(
-    requiredChildDistributionExpressions: Option[Seq[Expression]],
-    groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    initialInputBufferOffset: Int,
-    resultExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  private[this] val allAggregateExpressions =
-    nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
-  private[this] val hasNonAlgebricAggregateFunctions =
-    !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
-
-  // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
-  // grouping key and aggregation buffer is supported; and (3) all
-  // aggregate functions are algebraic.
-  private[this] val supportsHybridIterator: Boolean = {
-    val aggregationBufferSchema: StructType =
-      StructType.fromAttributes(
-        allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
-    val groupKeySchema: StructType =
-      StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
-
-    val schemaSupportsUnsafe: Boolean =
-      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-        UnsafeProjection.canSupport(groupKeySchema)
-
-    // TODO: Use the hybrid iterator for non-algebric aggregate functions.
-    sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions
-  }
-
-  // We need to use sorted input if we have grouping expressions, and
-  // we cannot use the hybrid iterator or the hybrid is disabled.
-  private[this] val requiresSortedInput: Boolean = {
-    groupingExpressions.nonEmpty && !supportsHybridIterator
-  }
-
-  override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions
-
-  // If result expressions' data types are all fixed length, we generate unsafe rows
-  // (We have this requirement instead of check the result of UnsafeProjection.canSupport
-  // is because we use a mutable projection to generate the result).
-  override def outputsUnsafeRows: Boolean = {
-    // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
-    // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix
-    // any issue we get.
-    false
-  }
-
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
-      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    if (requiresSortedInput) {
-      // TODO: We should not sort the input rows if they are just in reversed order.
-      groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-    } else {
-      Seq.fill(children.size)(Nil)
-    }
-  }
-
-  override def outputOrdering: Seq[SortOrder] = {
-    if (requiresSortedInput) {
-      // It is possible that the child.outputOrdering starts with the required
-      // ordering expressions (e.g. we require [a] as the sort expression and the
-      // child's outputOrdering is [a, b]). We can only guarantee the output rows
-      // are sorted by values of groupingExpressions.
-      groupingExpressions.map(SortOrder(_, Ascending))
-    } else {
-      Nil
-    }
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    child.execute().mapPartitions { iter =>
-      // Because the constructor of an aggregation iterator will read at least the first row,
-      // we need to get the value of iter.hasNext first.
-      val hasInput = iter.hasNext
-      val useHybridIterator =
-        hasInput &&
-          supportsHybridIterator &&
-          groupingExpressions.nonEmpty
-      if (useHybridIterator) {
-        UnsafeHybridAggregationIterator.createFromInputIterator(
-          groupingExpressions,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
-          initialInputBufferOffset,
-          resultExpressions,
-          newMutableProjection _,
-          child.output,
-          iter,
-          outputsUnsafeRows)
-      } else {
-        if (!hasInput && groupingExpressions.nonEmpty) {
-          // This is a grouped aggregate and the input iterator is empty,
-          // so return an empty iterator.
-          Iterator[InternalRow]()
-        } else {
-          val outputIter = SortBasedAggregationIterator.createFromInputIterator(
-            groupingExpressions,
-            nonCompleteAggregateExpressions,
-            nonCompleteAggregateAttributes,
-            completeAggregateExpressions,
-            completeAggregateAttributes,
-            initialInputBufferOffset,
-            resultExpressions,
-            newMutableProjection _ ,
-            newProjection _,
-            child.output,
-            iter,
-            outputsUnsafeRows)
-          if (!hasInput && groupingExpressions.isEmpty) {
-            // There is no input and there is no grouping expressions.
-            // We need to output a single row as the output.
-            Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
-          } else {
-            outputIter
-          }
-        }
-      }
-    }
-  }
-
-  override def simpleString: String = {
-    val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) {
-      classOf[UnsafeHybridAggregationIterator].getSimpleName
-    } else {
-      classOf[SortBasedAggregationIterator].getSimpleName
-    }
-
-    s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}"""
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ad428ad663f307634d83cce5ee44f360f424cf08
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+case class SortBasedAggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = false
+
+  override def canProcessUnsafeRows: Boolean = false
+
+  override def canProcessSafeRows: Boolean = true
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+  }
+
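+  // Rows are consumed in the order required above and groups are emitted as they are
+  // encountered, so the output stays sorted by the grouping expressions.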
+  override def outputOrdering: Seq[SortOrder] = {
+    groupingExpressions.map(SortOrder(_, Ascending))
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      // Because the constructor of an aggregation iterator will read at least the first row,
+      // we need to get the value of iter.hasNext first.
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input iterator is empty,
+        // so return an empty iterator.
+        Iterator[InternalRow]()
+      } else {
+        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+          groupingExpressions,
+          nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
+          completeAggregateExpressions,
+          completeAggregateAttributes,
+          initialInputBufferOffset,
+          resultExpressions,
+          newMutableProjection _,
+          newProjection _,
+          child.output,
+          iter,
+          outputsUnsafeRows)
+        if (!hasInput && groupingExpressions.isEmpty) {
+          // There is no input and there are no grouping expressions.
+          // We need to output a single row as the result.
+          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+        } else {
+          outputIter
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 40f6bff53d2b7de71bbfadae179d98423973529f..67ebafde25ad3e5620d1a6df8ea2d32b74a30b88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -204,31 +204,5 @@ object SortBasedAggregationIterator {
       newMutableProjection,
       outputsUnsafeRows)
   }
-
-  def createFromKVIterator(
-      groupingKeyAttributes: Seq[Attribute],
-      valueAttributes: Seq[Attribute],
-      inputKVIterator: KVIterator[InternalRow, InternalRow],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
-    new SortBasedAggregationIterator(
-      groupingKeyAttributes,
-      valueAttributes,
-      inputKVIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows)
-  }
   // scalastyle:on
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5a0b4d47d62f8dc10c16e3a495911581dea3dbe0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+
+case class TungstenAggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def canProcessSafeRows: Boolean = false
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
+  // aggregation once it has processed a given number of input rows.
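+  // For example, a test can set spark.sql.TungstenAggregate.testFallbackStartsAt to "0" to
+  // force sort-based aggregation from the very first input row.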
+  private val testFallbackStartsAt: Option[Int] = {
+    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+      case null | "" => None
+      case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input iterator is empty,
+        // so return an empty iterator.
+        Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
+      } else {
+        val aggregationIterator =
+          new TungstenAggregationIterator(
+            groupingExpressions,
+            nonCompleteAggregateExpressions,
+            completeAggregateExpressions,
+            initialInputBufferOffset,
+            resultExpressions,
+            newMutableProjection,
+            child.output,
+            iter.asInstanceOf[Iterator[UnsafeRow]],
+            testFallbackStartsAt)
+
+        if (!hasInput && groupingExpressions.isEmpty) {
+          Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+        } else {
+          aggregationIterator
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+    testFallbackStartsAt match {
+      case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}"
+      case Some(fallbackStartsAt) =>
+        s"TungstenAggregateWithControlledFallback ${groupingExpressions} " +
+          s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt"
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b9d44aace1009b01117e5cda473fe84258770e6e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
+ *
+ * This iterator first uses hash-based aggregation to process input rows. It uses
+ * a hash map to store groups and their corresponding aggregation buffers. If
+ * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * it switches to sort-based aggregation. The switch has the following steps:
+ *  - Step 1: Sort all entries of the hash map based on values of grouping expressions and
+ *            spill them to disk.
+ *  - Step 2: Create an external sorter based on the spilled sorted map entries.
+ *  - Step 3: Redirect all input rows to the external sorter.
+ *  - Step 4: Get a sorted [[KVIterator]] from the external sorter.
+ *  - Step 5: Initialize sort-based aggregation.
+ * From then on, this iterator works as a sort-based aggregation iterator.
+ *
+ * The code of this class is organized as follows:
+ *  - Part 1: Initializing aggregate functions.
+ *  - Part 2: Methods and fields used for setting aggregation buffer values,
+ *            processing input rows from inputIter, and generating output
+ *            rows.
+ *  - Part 3: Methods and fields used by hash-based aggregation.
+ *  - Part 4: The function used to switch this iterator from hash-based
+ *            aggregation to sort-based aggregation.
+ *  - Part 5: Methods and fields used by sort-based aggregation.
+ *  - Part 6: Loads and processes input rows.
+ *  - Part 7: Public methods of this iterator.
+ *  - Part 8: A utility function used to generate a result when there is no
+ *            input and there is no grouping expression.
+ *
+ * @param groupingExpressions
+ *   expressions for grouping keys
+ * @param nonCompleteAggregateExpressions
+ *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
+ *   [[PartialMerge]], or [[Final]].
+ * @param completeAggregateExpressions
+ *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param initialInputBufferOffset
+ *   If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]],
+ *   the input rows have the format of `grouping keys + aggregation buffer`.
+ *   This offset indicates the starting position of the aggregation buffer in an input row.
+ * @param resultExpressions
+ *   expressions for generating output rows.
+ * @param newMutableProjection
+ *   the function used to create mutable projections.
+ * @param originalInputAttributes
+ *   attributes representing input rows from `inputIter`.
+ * @param inputIter
+ *   the iterator containing input [[UnsafeRow]]s.
+ */
+class TungstenAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    originalInputAttributes: Seq[Attribute],
+    inputIter: Iterator[UnsafeRow],
+    testFallbackStartsAt: Option[Int])
+  extends Iterator[UnsafeRow] with Logging {
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 1: Initializing aggregate functions.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // A Seq containing all AggregateExpressions.
+  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+  // are at the beginning of allAggregateExpressions.
+  private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+    nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  // Check to make sure we do not have more than two modes in our AggregateExpressions.
+  // If we do, the user is hitting a bug and we throw an IllegalStateException.
+  if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
+    throw new IllegalStateException(
+      s"$allAggregateExpressions should have no more than 2 kinds of modes.")
+  }
+
+  //
+  // The modes of AggregateExpressions. Right now, we can handle the following mode:
+  //  - Partial-only:
+  //      All AggregateExpressions have the mode of Partial.
+  //      For this case, aggregationMode is (Some(Partial), None).
+  //  - PartialMerge-only:
+  //      All AggregateExpressions have the mode of PartialMerge.
+  //      For this case, aggregationMode is (Some(PartialMerge), None).
+  //  - Final-only:
+  //      All AggregateExpressions have the mode of Final.
+  //      For this case, aggregationMode is (Some(Final), None).
+  //  - Final-Complete:
+  //      Some AggregateExpressions have the mode of Final and
+  //      others have the mode of Complete. For this case,
+  //      aggregationMode is (Some(Final), Some(Complete)).
+  //  - Complete-only:
+  //      nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+  //      with mode Complete in completeAggregateExpressions. For this case,
+  //      aggregationMode is (None, Some(Complete)).
+  //  - Grouping-only:
+  //      There is no AggregateExpression. For this case, aggregationMode is (None, None).
+  //
+  private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
+    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+      completeAggregateExpressions.map(_.mode).distinct.headOption
+  }
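+  // For example, on the map side of `SELECT key, avg(value) FROM t GROUP BY key`, all
+  // aggregate expressions have mode Partial, so aggregationMode is (Some(Partial), None).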
+
+  // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
+  // If any function is not an AlgebraicAggregate, we throw an IllegalStateException.
+  private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
+    if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+      throw new IllegalStateException(
+        "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+    }
+
+    allAggregateExpressions
+      .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+      .toArray
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 2: Methods and fields used for setting aggregation buffer values,
+  //         processing input rows from inputIter, and generating output
+  //         rows.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The projection used to initialize buffer values.
+  private[this] val algebraicInitialProjection: MutableProjection = {
+    val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+    newMutableProjection(initExpressions, Nil)()
+  }
+
+  // Creates a new aggregation buffer and initializes buffer values.
+  // This function should be called at most three times (when we create the hash map,
+  // when we switch to sort-based aggregation, and when we create the re-used buffer for
+  // sort-based aggregation).
+  private def createNewAggregationBuffer(): UnsafeRow = {
+    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferRowSize: Int = bufferSchema.length
+
+    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+    val unsafeProjection =
+      UnsafeProjection.create(bufferSchema.map(_.dataType))
+    val buffer = unsafeProjection.apply(genericMutableBuffer)
+    algebraicInitialProjection.target(buffer)(EmptyRow)
+    buffer
+  }
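+  // Note: every buffer created by this method has the same fixed layout, since the buffer
+  // schema is determined by allAggregateFunctions; this is expected to make the
+  // copyFrom-based buffer resets used later safe.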
+
+  // Creates a function used to process a row based on the given inputAttributes.
+  private def generateProcessRow(
+      inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+
+    val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
+    val inputSchema = StructType.fromAttributes(inputAttributes)
+    val unsafeRowJoiner =
+      GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+
+    aggregationMode match {
+      // Partial-only
+      case (Some(Partial), None) =>
+        val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
+        val algebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          algebraicUpdateProjection.target(currentBuffer)
+          algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
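+        // The joined row has the layout (buffer fields ++ input fields), which matches the
+        // input schema (aggregationBufferAttributes ++ inputAttributes) bound into the
+        // update projection, so update expressions can reference both sides.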
+
+      // PartialMerge-only or Final-only
+      case (Some(PartialMerge), None) | (Some(Final), None) =>
+        val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
+        // This projection is used to merge buffer values for all AlgebraicAggregates.
+        val algebraicMergeProjection =
+          newMutableProjection(
+            mergeExpressions,
+            aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          // Process all algebraic aggregate functions.
+          algebraicMergeProjection.target(currentBuffer)
+          algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
+
+      // Final-Complete
+      case (Some(Final), Some(Complete)) =>
+        val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+        val completeAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+        val completeOffsetExpressions =
+          Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
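+        // NoOp placeholders make a mutable projection skip the buffer slots of the Complete
+        // functions, so the merge projection below only touches the Final functions' buffers.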
+        val mergeExpressions =
+          nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+        val finalAlgebraicMergeProjection =
+          newMutableProjection(
+            mergeExpressions,
+            aggregationBufferAttributes ++ inputAttributes)()
+
+        // We do not touch buffer values of aggregate functions with the Final mode.
+        val finalOffsetExpressions =
+          Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+        val updateExpressions =
+          finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          val input = unsafeRowJoiner.join(currentBuffer, row)
+          // For all aggregate functions with mode Complete, update the given currentBuffer.
+          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+
+          // For all aggregate functions with mode Final, merge buffer values in row to
+          // currentBuffer.
+          finalAlgebraicMergeProjection.target(currentBuffer)(input)
+        }
+
+      // Complete-only
+      case (None, Some(Complete)) =>
+        val completeAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+        val updateExpressions =
+          completeAggregateFunctions.flatMap(_.updateExpressions)
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          completeAlgebraicUpdateProjection.target(currentBuffer)
+          // For all aggregate functions with mode Complete, update the given currentBuffer.
+          completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
+
+      // Grouping only.
+      case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+
+      case other =>
+        throw new IllegalStateException(
+          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+    }
+  }
+
+  // Creates a function used to generate output rows.
+  private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+    val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferSchema = StructType.fromAttributes(bufferAttributes)
+    val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+    aggregationMode match {
+      // Partial-only or PartialMerge-only: every output row is basically the values of
+      // the grouping expressions and the corresponding aggregation buffer.
+      case (Some(Partial), None) | (Some(PartialMerge), None) =>
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
+        }
+
+      // Final-only, Complete-only and Final-Complete: an output row is generated based on
+      // resultExpressions.
+      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+        val resultProjection =
+          UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
+        }
+
+      // Grouping-only: an output row is generated from values of grouping expressions.
+      case (None, None) =>
+        val resultProjection =
+          UnsafeProjection.create(resultExpressions, groupingAttributes)
+
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          resultProjection(currentGroupingKey)
+        }
+
+      case other =>
+        throw new IllegalStateException(
+          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+    }
+  }
+
+  // An UnsafeProjection used to extract grouping keys from the input rows.
+  private[this] val groupProjection =
+    UnsafeProjection.create(groupingExpressions, originalInputAttributes)
+
+  // A function used to process an input row. Its first argument is the aggregation buffer
+  // and the second argument is the input row.
+  private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+    generateProcessRow(originalInputAttributes)
+
+  // A function used to generate output rows based on the grouping keys (first argument)
+  // and the corresponding aggregation buffer (second argument).
+  private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
+    generateResultProjection()
+
+  // An aggregation buffer containing initial buffer values. It is used to
+  // initialize other aggregation buffers.
+  private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 3: Methods and fields used by hash-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // This is the hash map used for hash-based aggregation. It is backed by an
+  // UnsafeFixedWidthAggregationMap and stores all groups and their corresponding
+  // aggregation buffers.
+  private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
+    initialAggregationBuffer,
+    StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+    StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
+    TaskContext.get.taskMemoryManager(),
+    SparkEnv.get.shuffleMemoryManager,
+    1024 * 16, // initial capacity
+    SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
+    false // disable tracking of performance metrics
+  )
+
+  // The function used to read and process input rows. When processing input rows,
+  // it first uses hash-based aggregation by putting groups and their buffers in
+  // hashMap. If we could not allocate more memory for the map, we switch to
+  // sort-based aggregation (by calling switchToSortBasedAggregation).
+  private def processInputs(): Unit = {
+    while (!sortBased && inputIter.hasNext) {
+      val newInput = inputIter.next()
+      val groupingKey = groupProjection.apply(newInput)
+      val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, newInput)
+      } else {
+        processRow(buffer, newInput)
+      }
+    }
+  }
+
+  // This function is only used for testing. It is basically the same as processInputs except
+  // that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
+  // been processed.
+  private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+    var i = 0
+    while (!sortBased && inputIter.hasNext) {
+      val newInput = inputIter.next()
+      val groupingKey = groupProjection.apply(newInput)
+      val buffer: UnsafeRow = if (i < fallbackStartsAt) {
+        hashMap.getAggregationBuffer(groupingKey)
+      } else {
+        null
+      }
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, newInput)
+      } else {
+        processRow(buffer, newInput)
+      }
+      i += 1
+    }
+  }
+
+  // The iterator created from hashMap. It is used to generate output rows when we
+  // are using hash-based aggregation.
+  private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null
+
+  // Indicates if aggregationBufferMapIterator still has key-value pairs.
+  private[this] var mapIteratorHasNext: Boolean = false
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 4: The function used to switch this iterator from hash-based
+  // aggregation to sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
+    logInfo("falling back to sort based aggregation.")
+    // Step 1: Get the ExternalSorter containing sorted entries of the map.
+    val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+
+    // Step 2: Free the memory used by the map.
+    hashMap.free()
+
+    // Step 3: If we have aggregate functions with mode Partial or Complete,
+    // we need to process input rows to get aggregation buffers, so that later the
+    // sort-based aggregation can merge them.
+    // If all aggregate functions have mode Final or PartialMerge,
+    // we just need to project the aggregation buffer from each input row.
+    val needsProcess = aggregationMode match {
+      case (Some(Partial), None) => true
+      case (None, Some(Complete)) => true
+      case (Some(Final), Some(Complete)) => true
+      case _ => false
+    }
+
+    if (needsProcess) {
+      // First, we create a buffer.
+      val buffer = createNewAggregationBuffer()
+
+      // Process firstKey and firstInput.
+      // Initialize buffer.
+      buffer.copyFrom(initialAggregationBuffer)
+      processRow(buffer, firstInput)
+      externalSorter.insertKV(firstKey, buffer)
+
+      // Process the rest of input rows.
+      while (inputIter.hasNext) {
+        val newInput = inputIter.next()
+        val groupingKey = groupProjection.apply(newInput)
+        buffer.copyFrom(initialAggregationBuffer)
+        processRow(buffer, newInput)
+        externalSorter.insertKV(groupingKey, buffer)
+      }
+    } else {
+      // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
+      // We need to project the aggregation buffer part from an input row.
+      val buffer = createNewAggregationBuffer()
+      // The originalInputAttributes use cloneBufferAttributes, so we need to use
+      // allAggregateFunctions.flatMap(_.cloneBufferAttributes) here.
+      val bufferExtractor = newMutableProjection(
+        allAggregateFunctions.flatMap(_.cloneBufferAttributes),
+        originalInputAttributes)()
+      bufferExtractor.target(buffer)
+
+      // Insert firstKey and its buffer.
+      bufferExtractor(firstInput)
+      externalSorter.insertKV(firstKey, buffer)
+
+      // Insert the rest of input rows.
+      while (inputIter.hasNext) {
+        val newInput = inputIter.next()
+        val groupingKey = groupProjection.apply(newInput)
+        bufferExtractor(newInput)
+        externalSorter.insertKV(groupingKey, buffer)
+      }
+    }
+
+    // Step 4: Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
+    val newAggregationMode = aggregationMode match {
+      case (Some(Partial), None) => (Some(PartialMerge), None)
+      case (None, Some(Complete)) => (Some(Final), None)
+      case (Some(Final), Some(Complete)) => (Some(Final), None)
+      case other => other
+    }
+    aggregationMode = newAggregationMode
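+    // For example, (Some(Partial), None) becomes (Some(PartialMerge), None): the rows fed to
+    // the sort-based aggregation are now pre-aggregated buffers that need to be merged,
+    // not raw input rows that need to be updated.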
+
+    // The values of the KVIterator returned by externalSorter are just
+    // aggregation buffers, so we use cloneBufferAttributes here.
+    val newInputAttributes: Seq[Attribute] =
+      allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+
+    // Set up new processRow and generateOutput.
+    processRow = generateProcessRow(newInputAttributes)
+    generateOutput = generateResultProjection()
+
+    // Step 5: Get the sorted iterator from the externalSorter.
+    sortedKVIterator = externalSorter.sortedIterator()
+
+    // Step 6: Pre-load the first key-value pair from the sorted iterator to make
+    // hasNext idempotent.
+    sortedInputHasNewGroup = sortedKVIterator.next()
+
+    // Copy the first key and value (aggregation buffer).
+    if (sortedInputHasNewGroup) {
+      val key = sortedKVIterator.getKey
+      val value = sortedKVIterator.getValue
+      nextGroupingKey = key.copy()
+      currentGroupingKey = key.copy()
+      firstRowInNextGroup = value.copy()
+    }
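+    // The copies are needed because the sorter's iterator reuses its key and value rows
+    // across calls to next().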
+
+    // Step 7: set sortBased to true.
+    sortBased = true
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 5: Methods and fields used by sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // Indicates if we are using sort-based aggregation. Because we first try to use
+  // hash-based aggregation, its initial value is false.
+  private[this] var sortBased: Boolean = false
+
+  // The KVIterator containing input rows for the sort-based aggregation. It will be
+  // set in switchToSortBasedAggregation when we switch to sort-based aggregation.
+  private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null
+
+  // The grouping key of the current group.
+  private[this] var currentGroupingKey: UnsafeRow = null
+
+  // The grouping key of next group.
+  private[this] var nextGroupingKey: UnsafeRow = null
+
+  // The first row of next group.
+  private[this] var firstRowInNextGroup: UnsafeRow = null
+
+  // Indicates if we have a new group of rows from the sorted input iterator.
+  private[this] var sortedInputHasNewGroup: Boolean = false
+
+  // The aggregation buffer used by the sort-based aggregation.
+  private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+  // Processes rows in the current group. It will stop when it finds a new group.
+  private def processCurrentSortedGroup(): Unit = {
+    // First, we need to copy nextGroupingKey to currentGroupingKey.
+    currentGroupingKey.copyFrom(nextGroupingKey)
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track if we see the next group.
+    var findNextPartition = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+    // The search will stop when we see the next group or there is no
+    // input row left in the iter.
+    // We pre-load the first key-value pair so that evaluating the condition of the
+    // while loop has no side effect (we do not trigger loading a new key-value pair
+    // when we evaluate the condition).
+    var hasNext = sortedKVIterator.next()
+    while (!findNextPartition && hasNext) {
+      // Get the grouping key and value (aggregation buffer).
+      val groupingKey = sortedKVIterator.getKey
+      val inputAggregationBuffer = sortedKVIterator.getValue
+
+      // Check if the current key-value pair belongs to the current group.
+      if (currentGroupingKey.equals(groupingKey)) {
+        processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+
+        hasNext = sortedKVIterator.next()
+      } else {
+        // We found a new group. Copy its grouping key and first row (an aggregation buffer)
+        // so that they survive the next call to sortedKVIterator.next().
+        findNextPartition = true
+        nextGroupingKey.copyFrom(groupingKey)
+        firstRowInNextGroup.copyFrom(inputAggregationBuffer)
+      }
+    }
+    // If we have not seen a new group, there are no input rows left in the iter and
+    // the current group is the last group of the sortedKVIterator.
+    if (!findNextPartition) {
+      sortedInputHasNewGroup = false
+      sortedKVIterator.close()
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 6: Loads input rows and sets up aggregationBufferMapIterator if we
+  //         have not switched to sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // Starts to process input rows.
+  testFallbackStartsAt match {
+    case None =>
+      processInputs()
+    case Some(fallbackStartsAt) =>
+      // This is the testing path. processInputsWithControlledFallback is the same as
+      // processInputs
+      // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+      // have been processed.
+      processInputsWithControlledFallback(fallbackStartsAt)
+  }
+
+  // If we did not switch to sort-based aggregation in processInputs,
+  // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+  if (!sortBased) {
+    // First, set aggregationBufferMapIterator.
+    aggregationBufferMapIterator = hashMap.iterator()
+    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+    mapIteratorHasNext = aggregationBufferMapIterator.next()
+    // If the map is empty, we just free it.
+    if (!mapIteratorHasNext) {
+      hashMap.free()
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 7: Iterator's public methods.
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = {
+    (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
+  }
+
+  override final def next(): UnsafeRow = {
+    if (hasNext) {
+      if (sortBased) {
+        // Process the current group.
+        processCurrentSortedGroup()
+        // Generate output row for the current group.
+        val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+        // Initialize buffer values for the next group.
+        sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+
+        outputRow
+      } else {
+        // We did not fall back to sort-based aggregation.
+        val result =
+          generateOutput(
+            aggregationBufferMapIterator.getKey,
+            aggregationBufferMapIterator.getValue)
+
+        // Pre-load the next key-value pair from aggregationBufferMapIterator to make hasNext
+        // idempotent.
+        mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+        if (!mapIteratorHasNext) {
+          // If there are no more entries in aggregationBufferMapIterator, we copy the
+          // current result.
+          val resultCopy = result.copy()
+          // Then, we free the map.
+          hashMap.free()
+
+          resultCopy
+        } else {
+          result
+        }
+      }
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 8: A utility function used to generate an output row when there is no
+  // input and there is no grouping expression.
+  ///////////////////////////////////////////////////////////////////////////
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+    if (groupingExpressions.isEmpty) {
+      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+      // We create an output row and copy it, so that we can free the map.
+      val resultCopy =
+        generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
+      hashMap.free()
+      resultCopy
+    } else {
+      throw new IllegalStateException(
+        "This method should not be called when groupingExpressions is not empty.")
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
deleted file mode 100644
index b465787fe8cbd38a73bc1c618cdf04a73557f2be..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
+++ /dev/null
@@ -1,372 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An iterator used to evaluate [[AggregateFunction2]].
- * It first tries to use in-memory hash-based aggregation. If we cannot allocate more
- * space for the hash map, we spill the sorted map entries, free the map, and then
- * switch to sort-based aggregation.
- */
-class UnsafeHybridAggregationIterator(
-    groupingKeyAttributes: Seq[Attribute],
-    valueAttributes: Seq[Attribute],
-    inputKVIterator: KVIterator[UnsafeRow, InternalRow],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    initialInputBufferOffset: Int,
-    resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean)
-  extends AggregationIterator(
-    groupingKeyAttributes,
-    valueAttributes,
-    nonCompleteAggregateExpressions,
-    nonCompleteAggregateAttributes,
-    completeAggregateExpressions,
-    completeAggregateAttributes,
-    initialInputBufferOffset,
-    resultExpressions,
-    newMutableProjection,
-    outputsUnsafeRows) {
-
-  require(groupingKeyAttributes.nonEmpty)
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Unsafe Aggregation buffers
-  ///////////////////////////////////////////////////////////////////////////
-
-  // This is the Unsafe Aggregation Map used to store all buffers.
-  private[this] val buffers = new UnsafeFixedWidthAggregationMap(
-    newBuffer,
-    StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
-    StructType.fromAttributes(groupingKeyAttributes),
-    TaskContext.get.taskMemoryManager(),
-    SparkEnv.get.shuffleMemoryManager,
-    1024 * 16, // initial capacity
-    SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
-    false // disable tracking of performance metrics
-  )
-
-  override protected def newBuffer: UnsafeRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
-    val bufferRowSize: Int = bufferSchema.length
-
-    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-    val unsafeProjection =
-      UnsafeProjection.create(bufferSchema.map(_.dataType))
-    val buffer = unsafeProjection.apply(genericMutableBuffer)
-    initializeBuffer(buffer)
-    buffer
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods and variables related to switching to sort-based aggregation
-  ///////////////////////////////////////////////////////////////////////////
-  private[this] var sortBased = false
-
-  private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _
-
-  // The value part of the input KV iterator is used to store original input values of
-  // aggregate functions, we need to convert them to aggregation buffers.
-  private def processOriginalInput(
-      firstKey: UnsafeRow,
-      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
-    new KVIterator[UnsafeRow, UnsafeRow] {
-      private[this] var isFirstRow = true
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] val buffer: UnsafeRow = newBuffer
-
-      override def next(): Boolean = {
-        initializeBuffer(buffer)
-        if (isFirstRow) {
-          isFirstRow = false
-          groupingKey = firstKey
-          processRow(buffer, firstValue)
-
-          true
-        } else if (inputKVIterator.next()) {
-          groupingKey = inputKVIterator.getKey()
-          val value = inputKVIterator.getValue()
-          processRow(buffer, value)
-
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): UnsafeRow = {
-        buffer
-      }
-
-      override def close(): Unit = {
-        // Do nothing.
-      }
-    }
-  }
-
-  // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer.
-  // We need to project the aggregation buffer out.
-  private def projectInputBufferToUnsafe(
-      firstKey: UnsafeRow,
-      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
-    new KVIterator[UnsafeRow, UnsafeRow] {
-      private[this] var isFirstRow = true
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
-
-      private[this] val value: UnsafeRow = {
-        val genericMutableRow = new GenericMutableRow(bufferSchema.length)
-        UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow)
-      }
-
-      private[this] val projectInputBuffer = {
-        newMutableProjection(bufferSchema, valueAttributes)().target(value)
-      }
-
-      override def next(): Boolean = {
-        if (isFirstRow) {
-          isFirstRow = false
-          groupingKey = firstKey
-          projectInputBuffer(firstValue)
-
-          true
-        } else if (inputKVIterator.next()) {
-          groupingKey = inputKVIterator.getKey()
-          projectInputBuffer(inputKVIterator.getValue())
-
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): UnsafeRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing.
-      }
-    }
-  }
-
-  /**
-   * We need to fall back to sort based aggregation because we do not have enough memory
-   * for our in-memory hash map (i.e. `buffers`).
-   */
-  private def switchToSortBasedAggregation(
-      currentGroupingKey: UnsafeRow,
-      currentRow: InternalRow): Unit = {
-    logInfo("falling back to sort based aggregation.")
-
-    // Step 1: Get the ExternalSorter containing entries of the map.
-    val externalSorter = buffers.destructAndCreateExternalSorter()
-
-    // Step 2: Free the memory used by the map.
-    buffers.free()
-
-    // Step 3: If we have aggregate function with mode Partial or Complete,
-    // we need to process them to get aggregation buffer.
-    // So, later in the sort-based aggregation iterator, we can do merge.
-    // If aggregate functions are with mode Final and PartialMerge,
-    // we just need to project the aggregation buffer from the input.
-    val needsProcess = aggregationMode match {
-      case (Some(Partial), None) => true
-      case (None, Some(Complete)) => true
-      case (Some(Final), Some(Complete)) => true
-      case _ => false
-    }
-
-    val processedIterator = if (needsProcess) {
-      processOriginalInput(currentGroupingKey, currentRow)
-    } else {
-      // The input value's format is groupingExprs + buffer.
-      // We need to project the buffer part out.
-      projectInputBufferToUnsafe(currentGroupingKey, currentRow)
-    }
-
-    // Step 4: Redirect processedIterator to externalSorter.
-    while (processedIterator.next()) {
-      externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue())
-    }
-
-    // Step 5: Get the sorted iterator from the externalSorter.
-    val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
-
-    // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
-    // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
-    // will be PartialMerge. For a aggregate function with mode Complete,
-    // its mode in the SortBasedAggregationIterator will be Final.
-    val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
-        case AggregateExpression2(func, Partial, isDistinct) =>
-          AggregateExpression2(func, PartialMerge, isDistinct)
-        case AggregateExpression2(func, Complete, isDistinct) =>
-          AggregateExpression2(func, Final, isDistinct)
-        case other => other
-      }
-    val newNonCompleteAggregateAttributes =
-      nonCompleteAggregateAttributes ++ completeAggregateAttributes
-
-    val newValueAttributes =
-      allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
-
-    sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
-      groupingKeyAttributes = groupingKeyAttributes,
-      valueAttributes = newValueAttributes,
-      inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
-      nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
-      completeAggregateExpressions = Nil,
-      completeAggregateAttributes = Nil,
-      initialInputBufferOffset = 0,
-      resultExpressions = resultExpressions,
-      newMutableProjection = newMutableProjection,
-      outputsUnsafeRows = outputsUnsafeRows)
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods used to initialize this iterator.
-  ///////////////////////////////////////////////////////////////////////////
-
-  /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
-  protected def initialize(): Unit = {
-    var hasNext = inputKVIterator.next()
-    while (!sortBased && hasNext) {
-      val groupingKey = inputKVIterator.getKey()
-      val currentRow = inputKVIterator.getValue()
-      val buffer = buffers.getAggregationBuffer(groupingKey)
-      if (buffer == null) {
-        // buffer == null means that we could not allocate more memory.
-        // Now, we need to spill the map and switch to sort-based aggregation.
-        switchToSortBasedAggregation(groupingKey, currentRow)
-        sortBased = true
-      } else {
-        processRow(buffer, currentRow)
-        hasNext = inputKVIterator.next()
-      }
-    }
-  }
-
-  // This is the starting point of this iterator.
-  initialize()
-
-  // Creates the iterator for the Hash Aggregation Map after we have populated
-  // contents of that map.
-  private[this] val aggregationBufferMapIterator = buffers.iterator()
-
-  private[this] var _mapIteratorHasNext = false
-
-  // Pre-load the first key-value pair from the map to make hasNext idempotent.
-  if (!sortBased) {
-    _mapIteratorHasNext = aggregationBufferMapIterator.next()
-    // If the map is empty, we just free it.
-    if (!_mapIteratorHasNext) {
-      buffers.free()
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Iterator's public methods
-  ///////////////////////////////////////////////////////////////////////////
-
-  override final def hasNext: Boolean = {
-    (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
-  }
-
-
-  override final def next(): InternalRow = {
-    if (hasNext) {
-      if (sortBased) {
-        sortBasedAggregationIterator.next()
-      } else {
-        // We did not fall back to the sort-based aggregation.
-        val result =
-          generateOutput(
-            aggregationBufferMapIterator.getKey,
-            aggregationBufferMapIterator.getValue)
-        // Pre-load next key-value pair form aggregationBufferMapIterator.
-        _mapIteratorHasNext = aggregationBufferMapIterator.next()
-
-        if (!_mapIteratorHasNext) {
-          val resultCopy = result.copy()
-          buffers.free()
-          resultCopy
-        } else {
-          result
-        }
-      }
-    } else {
-      // no more result
-      throw new NoSuchElementException
-    }
-  }
-}
-
-object UnsafeHybridAggregationIterator {
-  // scalastyle:off
-  def createFromInputIterator(
-      groupingExprs: Seq[NamedExpression],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
-    new UnsafeHybridAggregationIterator(
-      groupingExprs.map(_.toAttribute),
-      inputAttributes,
-      AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter),
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows)
-  }
-  // scalastyle:on
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 960be08f84d941c9e39ce180c9a1c6e06110a275..80816a095ea8c0a295a8a16c2fcd19046fafebb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,20 +17,41 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object Utils {
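+  // The Tungsten-based hash map (UnsafeFixedWidthAggregationMap) can only be used when the
+  // aggregation buffer schema is supported (i.e. all fields have fixed-width, in-place
+  // updatable types) and the grouping expressions can be compiled into an UnsafeProjection.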
+  def supportsTungstenAggregate(
+      groupingExpressions: Seq[Expression],
+      aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+
+    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+      UnsafeProjection.canSupport(groupingExpressions)
+  }
+
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
+    // Check if we can use TungstenAggregate.
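+    // The Tungsten path currently requires every aggregate function to be an
+    // AlgebraicAggregate, since final results are produced via its evaluateExpression.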
+    val usesTungstenAggregate =
+      child.sqlContext.conf.unsafeEnabled &&
+      aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+      supportsTungstenAggregate(
+        groupingExpressions,
+        aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
     // 1. Create an Aggregate Operator for partial aggregations.
     val namedGroupingExpressions = groupingExpressions.map {
       case ne: NamedExpression => ne -> ne
@@ -44,11 +65,23 @@ object Utils {
     val groupExpressionMap = namedGroupingExpressions.toMap
     val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
     val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
-    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
-      agg.aggregateFunction.bufferAttributes
-    }
-    val partialAggregate =
-      Aggregate(
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+    val partialResultExpressions =
+      namedGroupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+    val partialAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = namedGroupingExpressions.map(_._2),
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialResultExpressions,
+        child = child)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
         groupingExpressions = namedGroupingExpressions.map(_._2),
         nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -56,29 +89,57 @@ object Utils {
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
         initialInputBufferOffset = 0,
-        resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes,
+        resultExpressions = partialResultExpressions,
         child = child)
+    }
 
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
       }
-    val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transformDown {
-        case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression =>
-          // We do not rely on the equality check at here since attributes may
-          // different cosmetically. Instead, we use semanticEquals.
-          groupExpressionMap.collectFirst {
-            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-          }.getOrElse(expression)
-      }.asInstanceOf[NamedExpression]
-    }
-    val finalAggregate =
-      Aggregate(
+
+    val finalAggregate = if (usesTungstenAggregate) {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case agg: AggregateExpression2 =>
+            // aggregateFunctionMap contains unique aggregate functions.
+            val aggregateFunction =
+              aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
+            aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+          case expression =>
+            // We do not rely on the equality check here since attributes may differ
+            // cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = namedGroupingAttributes.length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialAggregate)
+    } else {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case agg: AggregateExpression2 =>
+            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+          case expression =>
+            // We do not rely on the equality check here since attributes may differ
+            // cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -88,6 +149,7 @@ object Utils {
         initialInputBufferOffset = namedGroupingAttributes.length,
         resultExpressions = rewrittenResultExpressions,
         child = partialAggregate)
+    }
 
     finalAggregate :: Nil
   }
@@ -96,10 +158,18 @@ object Utils {
       groupingExpressions: Seq[Expression],
       functionsWithDistinct: Seq[AggregateExpression2],
       functionsWithoutDistinct: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
+    val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
+    val usesTungstenAggregate =
+      child.sqlContext.conf.unsafeEnabled &&
+        aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+        supportsTungstenAggregate(
+          groupingExpressions,
+          aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
     // 1. Create an Aggregate Operator for partial aggregations.
     // The grouping expressions are original groupingExpressions and
     // distinct columns. For example, for avg(distinct value) ... group by key
@@ -129,19 +199,26 @@ object Utils {
     val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
     val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
 
-    val partialAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, Partial, false)
-    }
-    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
-      agg.aggregateFunction.bufferAttributes
-    }
+    val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
     val partialAggregateGroupingExpressions =
       (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
     val partialAggregateResult =
-      namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes
-    val partialAggregate =
-      Aggregate(
+      namedGroupingAttributes ++
+        distinctColumnAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+    val partialAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = partialAggregateGroupingExpressions,
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialAggregateResult,
+        child = child)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
         groupingExpressions = partialAggregateGroupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -151,20 +228,27 @@ object Utils {
         initialInputBufferOffset = 0,
         resultExpressions = partialAggregateResult,
         child = child)
+    }
 
     // 2. Create an Aggregate Operator for partial merge aggregations.
-    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, PartialMerge, false)
-    }
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
     val partialMergeAggregateAttributes =
-      partialMergeAggregateExpressions.flatMap { agg =>
-        agg.aggregateFunction.bufferAttributes
-      }
+      partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
     val partialMergeAggregateResult =
-      namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes
-    val partialMergeAggregate =
-      Aggregate(
+      namedGroupingAttributes ++
+        distinctColumnAttributes ++
+        partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+    val partialMergeAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+        nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = partialMergeAggregateResult,
+        child = partialAggregate)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
         nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
@@ -174,48 +258,91 @@ object Utils {
         initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
         resultExpressions = partialMergeAggregateResult,
         child = partialAggregate)
+    }
 
     // 3. Create an Aggregate Operator for partial merge aggregations.
-    val finalAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, Final, false)
-    }
+    val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
       }
+    // Create a map to store the rewritten aggregate functions. We always use both the
+    // function and its corresponding isDistinct flag as the key, because the function itself
+    // does not know whether it has the DISTINCT keyword or not.
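+    // For example (hypothetical query), in SELECT sum(DISTINCT value) FROM t, the child of
+    // sum is rewritten from `value` to the attribute generated for the distinct column, so
+    // the rewritten function differs from the one recorded in aggregateFunctionMap.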
+    val rewrittenAggregateFunctions =
+      mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
     val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
-      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+      case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
         val rewrittenAggregateFunction = aggregateFunction.transformDown {
           case expr if distinctColumnExpressionMap.contains(expr) =>
             distinctColumnExpressionMap(expr).toAttribute
         }.asInstanceOf[AggregateFunction2]
+        // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
+        // to track the old version and the new version of this function.
+        rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
         // We rewrite the aggregate function to a non-distinct aggregation because
         // its input will have distinct arguments.
+        // We keep the isDistinct flag set to true so that, when users look at the query plan,
+        // they can still see distinct aggregations.
         val rewrittenAggregateExpression =
-          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+          AggregateExpression2(rewrittenAggregateFunction, Complete, true)
 
-        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        val aggregateFunctionAttribute =
+          aggregateFunctionMap(agg.aggregateFunction, true)._2
         (rewrittenAggregateExpression -> aggregateFunctionAttribute)
     }.unzip
 
-    val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transform {
-        case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression =>
-          // We do not rely on the equality check at here since attributes may
-          // different cosmetically. Instead, we use semanticEquals.
-          groupExpressionMap.collectFirst {
-            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-          }.getOrElse(expression)
-      }.asInstanceOf[NamedExpression]
-    }
-    val finalAndCompleteAggregate =
-      Aggregate(
+    val finalAndCompleteAggregate = if (usesTungstenAggregate) {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transform {
+          case agg: AggregateExpression2 =>
+            val function = agg.aggregateFunction
+            val isDistinct = agg.isDistinct
+            val aggregateFunction =
+              if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
+                // If this function has been rewritten, we get the rewritten version from
+                // rewrittenAggregateFunctions.
+                rewrittenAggregateFunctions(function, isDistinct)
+              } else {
+                // Otherwise, we get it from aggregateFunctionMap, which contains unique
+                // aggregate functions that have not been rewritten.
+                aggregateFunctionMap(function, isDistinct)._1
+              }
+            aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+          case expression =>
+            // We do not rely on the equality check here since attributes may differ
+            // cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        completeAggregateExpressions = completeAggregateExpressions,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialMergeAggregate)
+    } else {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transform {
+          case agg: AggregateExpression2 =>
+            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+          case expression =>
+            // We do not rely on the equality check here since attributes may differ
+            // cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -225,6 +352,7 @@ object Utils {
         initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
         resultExpressions = rewrittenResultExpressions,
         child = partialMergeAggregate)
+    }
 
     finalAndCompleteAggregate :: Nil
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cef40dd324d9efd5fd6fb000b43d9ec7e74c17ff..c64aa7a07dc2bfb149ae799512467c11e62824cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     val df = sql(sqlText)
     // First, check if we have GeneratedAggregate.
     val hasGeneratedAgg = df.queryExecution.executedPlan
-      .collect { case _: aggregate.Aggregate => true }
+      .collect { case _: aggregate.TungstenAggregate => true }
       .nonEmpty
     if (!hasGeneratedAgg) {
       fail(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4b35c8fd83533c45c210442d69f9f0d7993fed37..7b5aa4763fd9ea2f6e8e052910e2f8b22e4150d0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
+import org.apache.spark.sql._
 import org.scalatest.BeforeAndAfterAll
-import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
 
@@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
       Nil)
   }
 
+  test("null literal") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(null),
+          |  COUNT(null),
+          |  FIRST(null),
+          |  LAST(null),
+          |  MAX(null),
+          |  MIN(null),
+          |  SUM(null)
+        """.stripMargin),
+      Row(null, 0, null, null, null, null, null) :: Nil)
+  }
+
   test("only do grouping") {
     checkAnswer(
       sqlContext.sql(
@@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
           |SELECT avg(value) FROM agg1
         """.stripMargin),
       Row(11.125) :: Nil)
-
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT avg(null)
-        """.stripMargin),
-      Row(null) :: Nil)
   }
 
   test("udaf") {
@@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
           |  max(distinct value1)
           |FROM agg2
         """.stripMargin),
-      Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+      Row(-60, 70.0, 101.0/9.0, 5.6, 100))
 
     checkAnswer(
       sqlContext.sql(
@@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
         Row(3, null, 3.0, null, null, null) ::
         Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
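+    // count(value1) skips nulls, count(*) and count(1) count every row in the group, and
+    // count(DISTINCT value1) counts distinct non-null values.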
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  count(value1),
+          |  count(*),
+          |  count(1),
+          |  count(DISTINCT value1),
+          |  key
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(3, 3, 3, 2, 1) ::
+        Row(3, 4, 4, 2, 2) ::
+        Row(0, 2, 2, 0, 3) ::
+        Row(3, 4, 4, 3, null) :: Nil)
   }
 
   test("test count") {
@@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
           |FROM agg1
           |GROUP BY key
         """.stripMargin).queryExecution.executedPlan.collect {
-        case agg: aggregate.Aggregate => agg
+        case agg: aggregate.SortBasedAggregate => agg
+        case agg: aggregate.TungstenAggregate => agg
       }
       val message =
         "We should fallback to the old aggregation code path if " +
@@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite {
     sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
   }
 }
+
+class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
+
+  var originalUnsafeEnabled: Boolean = _
+
+  override def beforeAll(): Unit = {
+    originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+    sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt")
+  }
+
+  override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = {
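+    // Run each query with fallbackStartsAt = 0, 1, and 2: 0 switches to sort-based
+    // aggregation before any input row is consumed, while 1 and 2 exercise the switch after
+    // the hash map already holds some entries.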
+    (0 to 2).foreach { fallbackStartsAt =>
+      sqlContext.setConf(
+        "spark.sql.TungstenAggregate.testFallbackStartsAt",
+        fallbackStartsAt.toString)
+
+      // Create a new df to make sure its physical operator picks up
+      // spark.sql.TungstenAggregate.testFallbackStartsAt.
+      val newActual = DataFrame(sqlContext, actual.logicalPlan)
+
+      QueryTest.checkAnswer(newActual, expectedAnswer) match {
+        case Some(errorMessage) =>
+          val newErrorMessage =
+            s"""
+              |The following aggregation query failed when using TungstenAggregate with
+              |controlled fallback (it falls back to sort-based aggregation once it has processed
+              |$fallbackStartsAt input rows). The query is
+              |${actual.queryExecution}
+              |
+              |$errorMessage
+            """.stripMargin
+
+          fail(newErrorMessage)
+        case None =>
+      }
+    }
+  }
+
+  // Override this overload to make sure it delegates to the checkAnswer overridden above.
+  override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+    checkAnswer(df, Seq(expectedAnswer))
+  }
+
+  // Override this overload to make sure it delegates to the checkAnswer overridden above.
+  override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+    checkAnswer(df, expectedAnswer.collect())
+  }
+}