Commit f79ebf2a authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-11370] [SQL] fix a bug in GroupedIterator and create unit test for it

Before this PR, the user had to consume the iterator of one group before processing the next group, otherwise we would get into an infinite loop.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9330 from cloud-fan/group.
parent 87f28fc2
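The contract being fixed is easiest to see in isolation: GroupedIterator walks a key-sorted input and hands out one (key, value-iterator) pair per group, and after this change it can advance to the next group even when the caller never touched the previous group's values. The standalone Scala sketch below shows the same idea in a simplified, eager form (it materializes each group instead of handing out a lazy inner iterator); `groupSorted` and `GroupSortedSketch` are made-up names for illustration only, not part of Spark.

object GroupSortedSketch {

  // Group consecutive elements of a key-sorted iterator by `key`.
  def groupSorted[A, K](input: Iterator[A])(key: A => K): Iterator[(K, Seq[A])] =
    new Iterator[(K, Seq[A])] {
      private val buffered = input.buffered  // lets us peek at the next element without consuming it

      def hasNext: Boolean = buffered.hasNext

      def next(): (K, Seq[A]) = {
        val groupKey = key(buffered.head)
        val values = scala.collection.mutable.ArrayBuffer.empty[A]
        // Consume rows only while the peeked row still belongs to the current group.
        while (buffered.hasNext && key(buffered.head) == groupKey) {
          values += buffered.next()
        }
        groupKey -> values.toSeq
      }
    }

  def main(args: Array[String]): Unit = {
    val rows = Iterator((1, "a"), (1, "b"), (2, "c"))
    // Prints "1 -> List((1,a), (1,b))" then "2 -> List((2,c))"; the caller never has to
    // drain one group before asking for the next.
    groupSorted(rows)(_._1).foreach { case (k, vs) => println(s"$k -> $vs") }
  }
}

Spark's GroupedIterator keeps each group's values lazy instead of materializing them, which is exactly why it needs the BufferedIterator peeking introduced in the diff below.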
@@ -27,7 +27,7 @@ object GroupedIterator {
       keyExpressions: Seq[Expression],
       inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
     if (input.hasNext) {
-      new GroupedIterator(input, keyExpressions, inputSchema)
+      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
     } else {
       Iterator.empty
     }
@@ -64,7 +64,7 @@ object GroupedIterator {
  * @param inputSchema The schema of the rows in the `input` iterator.
  */
 class GroupedIterator private(
-    input: Iterator[InternalRow],
+    input: BufferedIterator[InternalRow],
     groupingExpressions: Seq[Expression],
     inputSchema: Seq[Attribute])
   extends Iterator[(InternalRow, Iterator[InternalRow])] {
@@ -83,10 +83,17 @@ class GroupedIterator private(
   /** Holds a copy of an input row that is in the current group. */
   var currentGroup = currentRow.copy()
-  var currentIterator: Iterator[InternalRow] = null
+
   assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+  var currentIterator = createGroupValuesIterator()
 
-  // Return true if we already have the next iterator or fetching a new iterator is successful.
+  /**
+   * Return true if we already have the next iterator or fetching a new iterator is successful.
+   *
+   * Note that if we get the iterator by `next`, we should consume it before calling `hasNext`,
+   * because fetching a new iterator consumes the input data to skip to the next group, which
+   * makes the previous iterator empty.
+   */
   def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator

   def next(): (InternalRow, Iterator[InternalRow]) = {
@@ -96,46 +103,64 @@ class GroupedIterator private(
     ret
   }

-  def fetchNextGroupIterator(): Boolean = {
-    if (currentRow != null || input.hasNext) {
-      val inputIterator = new Iterator[InternalRow] {
-        // Return true if we have a row and it is in the current group, or if fetching a new row is
-        // successful.
-        def hasNext = {
-          (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
-            fetchNextRowInGroup()
-        }
-
-        def fetchNextRowInGroup(): Boolean = {
-          if (currentRow != null || input.hasNext) {
-            currentRow = input.next()
-            if (keyOrdering.compare(currentGroup, currentRow) == 0) {
-              // The row is in the current group. Continue the inner iterator.
-              true
-            } else {
-              // We got a row, but its not in the right group. End this inner iterator and prepare
-              // for the next group.
-              currentIterator = null
-              currentGroup = currentRow.copy()
-              false
-            }
-          } else {
-            // There is no more input so we are done.
-            false
-          }
-        }
-
-        def next(): InternalRow = {
-          assert(hasNext) // Ensure we have fetched the next row.
-          val res = currentRow
-          currentRow = null
-          res
-        }
-      }
-      currentIterator = inputIterator
-      true
-    } else {
-      false
-    }
-  }
+  private def fetchNextGroupIterator(): Boolean = {
+    assert(currentIterator == null)
+
+    if (currentRow == null && input.hasNext) {
+      currentRow = input.next()
+    }
+
+    if (currentRow == null) {
+      // There is no data left, return false.
+      false
+    } else {
+      // Skip to next group.
+      while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+        currentRow = input.next()
+      }
+
+      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+        // We are in the last group and there are no more groups, return false.
+        false
+      } else {
+        // Now `currentRow` is the first row of the next group.
+        currentGroup = currentRow.copy()
+        currentIterator = createGroupValuesIterator()
+        true
+      }
+    }
+  }
+
+  private def createGroupValuesIterator(): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
+
+      def next(): InternalRow = {
+        assert(hasNext)
+        val res = currentRow
+        currentRow = null
+        res
+      }
+
+      private def fetchNextRowInGroup(): Boolean = {
+        assert(currentRow == null)
+
+        if (input.hasNext) {
+          // The inner iterator should NOT consume the input into the next group, so we use `head`
+          // to peek at the next input and decide whether to continue processing it.
+          if (keyOrdering.compare(currentGroup, input.head) == 0) {
+            // Next input is in the current group. Continue the inner iterator.
+            currentRow = input.next()
+            true
+          } else {
+            // Next input is not in the right group. End this inner iterator.
+            false
+          }
+        } else {
+          // There is no more data, return false.
+          false
+        }
+      }
+    }
+  }
 }
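The heart of the fix is wrapping the input in a BufferedIterator (`input.buffered`) so the inner value iterator can peek at the next row with `head` and stop at a group boundary without pulling a row that belongs to the following group. A tiny, Spark-independent illustration of that peeking behavior (plain Scala, values chosen arbitrarily):

val it = Iterator(1, 1, 2).buffered

// `head` peeks at the next element without consuming it.
assert(it.head == 1)
assert(it.next() == 1)      // the peeked element is still returned by next()

// Skipping the rest of the current "group" is then a loop over peeked values.
while (it.hasNext && it.head == 1) { it.next() }
assert(it.head == 2)        // now positioned at the first element of the next group

The commit also adds a new test suite, GroupedIteratorSuite, shown below; its last test exercises exactly the case that used to loop forever, asking for the number of groups without consuming any value iterator.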
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)
    assert(grouped.length == 2)
  }
}