diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
index 10742cf7348f803a4bca1631a6ba8463096339d2..6a8850129f1ac1218ae0b00e1578896e9cff2728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -27,7 +27,7 @@ object GroupedIterator {
       keyExpressions: Seq[Expression],
       inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
     if (input.hasNext) {
-      new GroupedIterator(input, keyExpressions, inputSchema)
+      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
     } else {
       Iterator.empty
     }
@@ -64,7 +64,7 @@ object GroupedIterator {
  * @param inputSchema The schema of the rows in the `input` iterator.
  */
 class GroupedIterator private(
-    input: Iterator[InternalRow],
+    input: BufferedIterator[InternalRow],
     groupingExpressions: Seq[Expression],
     inputSchema: Seq[Attribute])
   extends Iterator[(InternalRow, Iterator[InternalRow])] {
@@ -83,10 +83,17 @@ class GroupedIterator private(
 
   /** Holds a copy of an input row that is in the current group. */
   var currentGroup = currentRow.copy()
 
-  var currentIterator: Iterator[InternalRow] = null
+  assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+  var currentIterator = createGroupValuesIterator()
 
-  // Return true if we already have the next iterator or fetching a new iterator is successful.
+  /**
+   * Return true if we already have the next iterator or fetching a new iterator is successful.
+   *
+   * Note that if we get the group iterator from `next`, we should consume it before calling
+   * `hasNext` again, because fetching a new iterator consumes the input data to skip to the
+   * next group, which makes the previous iterator empty.
+   */
   def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
 
   def next(): (InternalRow, Iterator[InternalRow]) = {
@@ -96,46 +103,64 @@
     ret
   }
 
-  def fetchNextGroupIterator(): Boolean = {
-    if (currentRow != null || input.hasNext) {
-      val inputIterator = new Iterator[InternalRow] {
-        // Return true if we have a row and it is in the current group, or if fetching a new row is
-        // successful.
-        def hasNext = {
-          (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
-            fetchNextRowInGroup()
-        }
+  private def fetchNextGroupIterator(): Boolean = {
+    assert(currentIterator == null)
+
+    if (currentRow == null && input.hasNext) {
+      currentRow = input.next()
+    }
+
+    if (currentRow == null) {
+      // There is no data left, return false.
+      false
+    } else {
+      // Skip to the next group.
+      while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+        currentRow = input.next()
+      }
+
+      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+        // We are in the last group; there are no more groups, return false.
+        false
+      } else {
+        // Now `currentRow` is the first row of the next group.
+        currentGroup = currentRow.copy()
+        currentIterator = createGroupValuesIterator()
+        true
+      }
+    }
+  }
+
+  private def createGroupValuesIterator(): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
+
+      def next(): InternalRow = {
+        assert(hasNext)
+        val res = currentRow
+        currentRow = null
+        res
+      }
 
-        def fetchNextRowInGroup(): Boolean = {
-          if (currentRow != null || input.hasNext) {
+      private def fetchNextRowInGroup(): Boolean = {
+        assert(currentRow == null)
+
+        if (input.hasNext) {
+          // The inner iterator should NOT consume the input into the next group; here we use
+          // `head` to peek at the next input, to see if we should continue to process it.
+          if (keyOrdering.compare(currentGroup, input.head) == 0) {
+            // Next input is in the current group. Continue the inner iterator.
             currentRow = input.next()
-            if (keyOrdering.compare(currentGroup, currentRow) == 0) {
-              // The row is in the current group. Continue the inner iterator.
-              true
-            } else {
-              // We got a row, but its not in the right group. End this inner iterator and prepare
-              // for the next group.
-              currentIterator = null
-              currentGroup = currentRow.copy()
-              false
-            }
+            true
           } else {
-            // There is no more input so we are done.
+            // Next input is not in the right group. End this inner iterator.
             false
-          }
-        }
-
-        def next(): InternalRow = {
-          assert(hasNext) // Ensure we have fetched the next row.
-          val res = currentRow
-          currentRow = null
-          res
+          }
+        } else {
+          // There is no more data, return false.
+          false
         }
       }
-      currentIterator = inputIterator
-      true
-    } else {
-      false
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e7a08481cfa805a343e617251f5816da85eaeab2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+
+class GroupedIteratorSuite extends SparkFunSuite {
+
+  test("basic") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 1)
+        key.getInt(0) -> data.map(encoder.fromRow).toSeq
+    }.toSeq
+
+    assert(result ==
+      1 -> Seq(input(0), input(1)) ::
+      2 -> Seq(input(2)) :: Nil)
+  }
+
+  test("group by 2 columns") {
+    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+
+    val input = Seq(
+      Row(1, 2L, "a"),
+      Row(1, 2L, "b"),
+      Row(1, 3L, "c"),
+      Row(2, 1L, "d"),
+      Row(3, 2L, "e"))
+
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 2)
+        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
+    }.toSeq
+
+    assert(result ==
+      (1, 2L, Seq(input(0), input(1))) ::
+      (1, 3L, Seq(input(2))) ::
+      (2, 1L, Seq(input(3))) ::
+      (3, 2L, Seq(input(4))) :: Nil)
+  }
+
+  test("do nothing to the value iterator") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    assert(grouped.length == 2)
+  }
+}
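Usage note (reviewer sketch, not part of the patch): the contract documented on `hasNext` above means the value iterator returned by `next` must be fully consumed before the grouped iterator is advanced again. A minimal sketch of correct usage, reusing the fixtures from GroupedIteratorSuite; the driver loop and printed output are illustrative only:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.execution.GroupedIterator
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    while (grouped.hasNext) {
      val (key, values) = grouped.next()
      // Drain the group eagerly; only after this is it safe to call
      // `grouped.hasNext` again, because fetching the next group consumes
      // the shared underlying input iterator.
      val rows = values.map(encoder.fromRow).toSeq
      println(s"group ${key.getInt(0)}: $rows")
    }

Calling `grouped.hasNext` (or `grouped.next()`) before draining `values` skips the input ahead to the next group and leaves `values` empty; conversely, the `head`-based peeking added in `fetchNextRowInGroup` ensures the value iterator never consumes rows beyond its own group.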