diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index ab20ee573ab5d5888fd660ed7c03a959ce33d2cb..c117dff9c8b1dfb74b214b4ee8c4766c67fea796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -199,99 +199,115 @@ case class SortMergeOuterJoin(
   }
 }
 
-
+/**
+ * An iterator for outputting rows in left outer join.
+ */
 private class LeftOuterIterator(
     smjScanner: SortMergeJoinScanner,
     rightNullRow: InternalRow,
     boundCondition: InternalRow => Boolean,
     resultProj: InternalRow => InternalRow,
-    numRows: LongSQLMetric
-  ) extends RowIterator {
-  private[this] val joinedRow: JoinedRow = new JoinedRow()
-  private[this] var rightIdx: Int = 0
-  assert(smjScanner.getBufferedMatches.length == 0)
-
-  private def advanceLeft(): Boolean = {
-    rightIdx = 0
-    if (smjScanner.findNextOuterJoinRows()) {
-      joinedRow.withLeft(smjScanner.getStreamedRow)
-      if (smjScanner.getBufferedMatches.isEmpty) {
-        // There are no matching right rows, so return nulls for the right row
-        joinedRow.withRight(rightNullRow)
-      } else {
-        // Find the next row from the right input that satisfied the bound condition
-        if (!advanceRightUntilBoundConditionSatisfied()) {
-          joinedRow.withRight(rightNullRow)
-        }
-      }
-      true
-    } else {
-      // Left input has been exhausted
-      false
-    }
-  }
-
-  private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
-    var foundMatch: Boolean = false
-    while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
-      foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
-      rightIdx += 1
-    }
-    foundMatch
-  }
-
-  override def advanceNext(): Boolean = {
-    val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
-    if (r) numRows += 1
-    r
-  }
+    numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(
+    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
 
-  override def getRow: InternalRow = resultProj(joinedRow)
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
 }
 
+/**
+ * An iterator for outputting rows in right outer join.
+ */
 private class RightOuterIterator(
     smjScanner: SortMergeJoinScanner,
     leftNullRow: InternalRow,
     boundCondition: InternalRow => Boolean,
     resultProj: InternalRow => InternalRow,
-    numRows: LongSQLMetric
-  ) extends RowIterator {
-  private[this] val joinedRow: JoinedRow = new JoinedRow()
-  private[this] var leftIdx: Int = 0
+    numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(
+    smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+}
+
+/**
+ * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
+ *
+ * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
+ * streamed side will output 0 or many rows, one for each matching row on the buffered side.
+ * If there are no matches, then the buffered side of the joined output will be a null row.
+ *
+ * In left outer join, the left is the streamed side and the right is the buffered side.
+ * In right outer join, the right is the streamed side and the left is the buffered side.
+ *
+ * @param smjScanner a scanner that streams rows and buffers any matching rows
+ * @param bufferedSideNullRow the default row to return when a streamed row has no matches
+ * @param boundCondition an additional filter condition for buffered rows
+ * @param resultProj how the output should be projected
+ * @param numOutputRows an accumulator metric for the number of rows output
+ */
+private abstract class OneSideOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    bufferedSideNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric) extends RowIterator {
+
+  // A row to store the joined result, reused many times
+  protected[this] val joinedRow: JoinedRow = new JoinedRow()
+
+  // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
+  private[this] var bufferIndex: Int = 0
+
+  // This iterator is initialized lazily so there should be no matches initially
   assert(smjScanner.getBufferedMatches.length == 0)
 
-  private def advanceRight(): Boolean = {
-    leftIdx = 0
+  // Set output methods to be overridden by subclasses
+  protected def setStreamSideOutput(row: InternalRow): Unit
+  protected def setBufferedSideOutput(row: InternalRow): Unit
+
+  /**
+   * Advance to the next row on the stream side and populate the buffer with matches.
+   * @return whether there are more rows in the stream to consume.
+   */
+  private def advanceStream(): Boolean = {
+    bufferIndex = 0
     if (smjScanner.findNextOuterJoinRows()) {
-      joinedRow.withRight(smjScanner.getStreamedRow)
+      setStreamSideOutput(smjScanner.getStreamedRow)
       if (smjScanner.getBufferedMatches.isEmpty) {
-        // There are no matching left rows, so return nulls for the left row
-        joinedRow.withLeft(leftNullRow)
+        // There are no matching rows in the buffer, so return the null row
+        setBufferedSideOutput(bufferedSideNullRow)
       } else {
-        // Find the next row from the left input that satisfied the bound condition
-        if (!advanceLeftUntilBoundConditionSatisfied()) {
-          joinedRow.withLeft(leftNullRow)
+        // Find the next row in the buffer that satisfied the bound condition
+        if (!advanceBufferUntilBoundConditionSatisfied()) {
+          setBufferedSideOutput(bufferedSideNullRow)
         }
       }
       true
     } else {
-      // Right input has been exhausted
+      // Stream has been exhausted
       false
     }
   }
 
-  private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+  /**
+   * Advance to the next row in the buffer that satisfies the bound condition.
+   * @return whether there is such a row in the current buffer.
+   */
+  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
     var foundMatch: Boolean = false
-    while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
-      foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
-      leftIdx += 1
+    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
+      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+      foundMatch = boundCondition(joinedRow)
+      bufferIndex += 1
     }
     foundMatch
   }
 
   override def advanceNext(): Boolean = {
-    val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
-    if (r) numRows += 1
+    val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
+    if (r) numOutputRows += 1
     r
   }