diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 5f7341e88c7c9c52162adec56e9d9c8aa8ae906e..8e0fbd109b4133f617d4940888b1f98e8e14970f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.unsafe.KVIterator
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -412,85 +411,3 @@ abstract class AggregationIterator(
    */
   protected def newBuffer: MutableRow
 }
-
-object AggregationIterator {
-  def kvIterator(
-    groupingExpressions: Seq[NamedExpression],
-    newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
-    new KVIterator[InternalRow, InternalRow] {
-      private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: InternalRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): InternalRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-
-  def unsafeKVIterator(
-      groupingExpressions: Seq[NamedExpression],
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
-    new KVIterator[UnsafeRow, InternalRow] {
-      private[this] val groupingKeyGenerator =
-        UnsafeProjection.create(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator.apply(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index f4c14a9b3556f28b8f57c33cd9e3783488a4acb4..4d37106e007f5f36afe02c7872dade0b80ae0edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
 
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -79,18 +78,23 @@ case class SortBasedAggregate(
         // so return an empty iterator.
         Iterator[InternalRow]()
       } else {
-        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
-          groupingExpressions,
+        val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) {
+          UnsafeProjection.create(groupingExpressions, child.output)
+        } else {
+          newMutableProjection(groupingExpressions, child.output)()
+        }
+        val outputIter = new SortBasedAggregationIterator(
+          groupingKeyProjection,
+          groupingExpressions.map(_.toAttribute),
+          child.output,
+          iter,
           nonCompleteAggregateExpressions,
           nonCompleteAggregateAttributes,
           completeAggregateExpressions,
           completeAggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
-          newMutableProjection _,
-          newProjection _,
-          child.output,
-          iter,
+          newMutableProjection,
           outputsUnsafeRows,
           numInputRows,
           numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a9e5d175bf89590f5a3cea3e3d1ba34e3c5dcc16..64c673064f576be763c549936ef99ab004aad931 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.unsafe.KVIterator
 
 /**
  * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
  * sorted by values of [[groupingKeyAttributes]].
  */
 class SortBasedAggregationIterator(
+    groupingKeyProjection: InternalRow => InternalRow,
     groupingKeyAttributes: Seq[Attribute],
     valueAttributes: Seq[Attribute],
-    inputKVIterator: KVIterator[InternalRow, InternalRow],
+    inputIterator: Iterator[InternalRow],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
     nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
@@ -90,6 +90,22 @@ class SortBasedAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
 
+  protected def initialize(): Unit = {
+    if (inputIterator.hasNext) {
+      initializeBuffer(sortBasedAggregationBuffer)
+      val inputRow = inputIterator.next()
+      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      firstRowInNextGroup = inputRow.copy()
+      numInputRows += 1
+      sortedInputHasNewGroup = true
+    } else {
+      // This inputIter is empty.
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  initialize()
+
   /** Processes rows in the current group. It will stop when it find a new group. */
   protected def processCurrentSortedGroup(): Unit = {
     currentGroupingKey = nextGroupingKey
@@ -101,18 +117,15 @@ class SortBasedAggregationIterator(
 
     // The search will stop when we see the next group or there is no
     // input row left in the iter.
-    var hasNext = inputKVIterator.next()
-    while (!findNextPartition && hasNext) {
+    while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
-      val groupingKey = inputKVIterator.getKey
-      val currentRow = inputKVIterator.getValue
+      val currentRow = inputIterator.next()
+      val groupingKey = groupingKeyProjection(currentRow)
       numInputRows += 1
 
       // Check if the current row belongs the current input row.
       if (currentGroupingKey == groupingKey) {
         processRow(sortBasedAggregationBuffer, currentRow)
-
-        hasNext = inputKVIterator.next()
       } else {
         // We find a new group.
         findNextPartition = true
@@ -149,68 +162,8 @@ class SortBasedAggregationIterator(
     }
   }
 
-  protected def initialize(): Unit = {
-    if (inputKVIterator.next()) {
-      initializeBuffer(sortBasedAggregationBuffer)
-
-      nextGroupingKey = inputKVIterator.getKey().copy()
-      firstRowInNextGroup = inputKVIterator.getValue().copy()
-      numInputRows += 1
-      sortedInputHasNewGroup = true
-    } else {
-      // This inputIter is empty.
-      sortedInputHasNewGroup = false
-    }
-  }
-
-  initialize()
-
   def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
     initializeBuffer(sortBasedAggregationBuffer)
     generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
   }
 }
-
-object SortBasedAggregationIterator {
-  // scalastyle:off
-  def createFromInputIterator(
-      groupingExprs: Seq[NamedExpression],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean,
-      numInputRows: LongSQLMetric,
-      numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
-    val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
-      AggregationIterator.unsafeKVIterator(
-        groupingExprs,
-        inputAttributes,
-        inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
-    } else {
-      AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
-    }
-
-    new SortBasedAggregationIterator(
-      groupingExprs.map(_.toAttribute),
-      inputAttributes,
-      kvIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows,
-      numInputRows,
-      numOutputRows)
-  }
-  // scalastyle:on
-}