diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
index 10742cf7348f803a4bca1631a6ba8463096339d2..6a8850129f1ac1218ae0b00e1578896e9cff2728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -27,7 +27,7 @@ object GroupedIterator {
       keyExpressions: Seq[Expression],
       inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
     if (input.hasNext) {
-      new GroupedIterator(input, keyExpressions, inputSchema)
+      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
     } else {
       Iterator.empty
     }
@@ -64,7 +64,7 @@ object GroupedIterator {
  * @param inputSchema The schema of the rows in the `input` iterator.
  */
 class GroupedIterator private(
-    input: Iterator[InternalRow],
+    input: BufferedIterator[InternalRow],
     groupingExpressions: Seq[Expression],
     inputSchema: Seq[Attribute])
   extends Iterator[(InternalRow, Iterator[InternalRow])] {
@@ -83,10 +83,17 @@ class GroupedIterator private(
 
   /** Holds a copy of an input row that is in the current group. */
   var currentGroup = currentRow.copy()
-  var currentIterator: Iterator[InternalRow] = null
+
   assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+  var currentIterator = createGroupValuesIterator()
 
-  // Return true if we already have the next iterator or fetching a new iterator is successful.
+  /**
+   * Return true if we already have the next iterator or fetching a new iterator is successful.
+   *
+   * Note that, if we get the iterator by `next`, we should consume it before call `hasNext`,
+   * because we will consume the input data to skip to next group while fetching a new iterator,
+   * thus make the previous iterator empty.
+   */
   def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
 
   def next(): (InternalRow, Iterator[InternalRow]) = {
@@ -96,46 +103,64 @@ class GroupedIterator private(
     ret
   }
 
-  def fetchNextGroupIterator(): Boolean = {
-    if (currentRow != null || input.hasNext) {
-      val inputIterator = new Iterator[InternalRow] {
-        // Return true if we have a row and it is in the current group, or if fetching a new row is
-        // successful.
-        def hasNext = {
-          (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
-            fetchNextRowInGroup()
-        }
+  private def fetchNextGroupIterator(): Boolean = {
+    assert(currentIterator == null)
+
+    if (currentRow == null && input.hasNext) {
+      currentRow = input.next()
+    }
+
+    if (currentRow == null) {
+      // These is no data left, return false.
+      false
+    } else {
+      // Skip to next group.
+      while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+        currentRow = input.next()
+      }
+
+      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+        // We are in the last group, there is no more groups, return false.
+        false
+      } else {
+        // Now the `currentRow` is the first row of next group.
+        currentGroup = currentRow.copy()
+        currentIterator = createGroupValuesIterator()
+        true
+      }
+    }
+  }
+
+  private def createGroupValuesIterator(): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
+
+      def next(): InternalRow = {
+        assert(hasNext)
+        val res = currentRow
+        currentRow = null
+        res
+      }
 
-        def fetchNextRowInGroup(): Boolean = {
-          if (currentRow != null || input.hasNext) {
+      private def fetchNextRowInGroup(): Boolean = {
+        assert(currentRow == null)
+
+        if (input.hasNext) {
+          // The inner iterator should NOT consume the input into next group, here we use `head` to
+          // peek the next input, to see if we should continue to process it.
+          if (keyOrdering.compare(currentGroup, input.head) == 0) {
+            // Next input is in the current group.  Continue the inner iterator.
             currentRow = input.next()
-            if (keyOrdering.compare(currentGroup, currentRow) == 0) {
-              // The row is in the current group.  Continue the inner iterator.
-              true
-            } else {
-              // We got a row, but its not in the right group.  End this inner iterator and prepare
-              // for the next group.
-              currentIterator = null
-              currentGroup = currentRow.copy()
-              false
-            }
+            true
           } else {
-            // There is no more input so we are done.
+            // Next input is not in the right group.  End this inner iterator.
             false
           }
-        }
-
-        def next(): InternalRow = {
-          assert(hasNext) // Ensure we have fetched the next row.
-          val res = currentRow
-          currentRow = null
-          res
+        } else {
+          // There is no more data, return false.
+          false
         }
       }
-      currentIterator = inputIterator
-      true
-    } else {
-      false
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e7a08481cfa805a343e617251f5816da85eaeab2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}
+
+class GroupedIteratorSuite extends SparkFunSuite {
+
+  test("basic") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 1)
+        key.getInt(0) -> data.map(encoder.fromRow).toSeq
+    }.toSeq
+
+    assert(result ==
+      1 -> Seq(input(0), input(1)) ::
+      2 -> Seq(input(2)) :: Nil)
+  }
+
+  test("group by 2 columns") {
+    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+
+    val input = Seq(
+      Row(1, 2L, "a"),
+      Row(1, 2L, "b"),
+      Row(1, 3L, "c"),
+      Row(2, 1L, "d"),
+      Row(3, 2L, "e"))
+
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 2)
+        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
+    }.toSeq
+
+    assert(result ==
+      (1, 2L, Seq(input(0), input(1))) ::
+      (1, 3L, Seq(input(2))) ::
+      (2, 1L, Seq(input(3))) ::
+      (3, 2L, Seq(input(4))) :: Nil)
+  }
+
+  test("do nothing to the value iterator") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    assert(grouped.length == 2)
+  }
+}