diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index ab20ee573ab5d5888fd660ed7c03a959ce33d2cb..c117dff9c8b1dfb74b214b4ee8c4766c67fea796 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -199,99 +199,115 @@ case class SortMergeOuterJoin( } } - +/** + * An iterator for outputting rows in left outer join. + */ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var rightIdx: Int = 0 - assert(smjScanner.getBufferedMatches.length == 0) - - private def advanceLeft(): Boolean = { - rightIdx = 0 - if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withLeft(smjScanner.getStreamedRow) - if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching right rows, so return nulls for the right row - joinedRow.withRight(rightNullRow) - } else { - // Find the next row from the right input that satisfied the bound condition - if (!advanceRightUntilBoundConditionSatisfied()) { - joinedRow.withRight(rightNullRow) - } - } - true - } else { - // Left input has been exhausted - false - } - } - - private def advanceRightUntilBoundConditionSatisfied(): Boolean = { - var foundMatch: Boolean = false - while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) - rightIdx += 1 - } - foundMatch - } - - override def advanceNext(): Boolean = { - val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() - if (r) numRows += 1 - r - } + numOutputRows: 
LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { - override def getRow: InternalRow = resultProj(joinedRow) + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) } +/** + * An iterator for outputting rows in right outer join. + */ private class RightOuterIterator( smjScanner: SortMergeJoinScanner, leftNullRow: InternalRow, boundCondition: InternalRow => Boolean, resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var leftIdx: Int = 0 + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 1 or many rows, one for each matching row on the buffered side. + * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. 
+ * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) - private def advanceRight(): Boolean = { - leftIdx = 0 + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. 
+ */ + private def advanceStream(): Boolean = { + bufferIndex = 0 if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withRight(smjScanner.getStreamedRow) + setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching left rows, so return nulls for the left row - joinedRow.withLeft(leftNullRow) + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) } else { - // Find the next row from the left input that satisfied the bound condition - if (!advanceLeftUntilBoundConditionSatisfied()) { - joinedRow.withLeft(leftNullRow) + // Find the next row in the buffer that satisfies the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) } } true } else { - // Right input has been exhausted + // Stream has been exhausted false } } - private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. + */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) - leftIdx += 1 + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 } foundMatch } override def advanceNext(): Boolean = { - val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() - if (r) numRows += 1 + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 r }