diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8f65672d5a8399dd1198e5da487c4ee4bba84a0f..a85f87aece45b83c19a6f85bb25583b3c410f26e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -715,6 +716,27 @@ object SQLConf {
       .stringConf
       .createWithDefault(TimeZone.getDefault().getID())
 
+  val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.windowExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in window operator")
+      .intConf
+      .createWithDefault(4096)
+
+  val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in sort merge join operator")
+      .intConf
+      .createWithDefault(Int.MaxValue)
+
+  val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in cartesian product operator")
+      .intConf
+      .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
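+
+  // Usage sketch (illustrative values, not recommendations): although these configs are
+  // internal, they can still be set on a session at runtime, e.g.
+  //   spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", "8192")
+  // lets the window operator buffer more rows in memory before spilling to disk.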
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -945,6 +967,14 @@ class SQLConf extends Serializable with Logging {
 
   def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD)
 
+  def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
+
+  def sortMergeJoinExecBufferSpillThreshold: Int =
+    getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
+
+  def cartesianProductExecBufferSpillThreshold: Int =
+    getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
+
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
   /** ********************** SQLConf functionality methods ************ */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
new file mode 100644
index 0000000000000000000000000000000000000000..458ac4ba3637ceebf1174805fa93c0b789de783e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * An append-only array for [[UnsafeRow]]s that spills its content to disk once a predefined
+ * threshold of rows is reached.
+ *
+ * Setting the spill threshold involves the following trade-off:
+ *
+ * - If the spill threshold is too high, the in-memory array may occupy more memory than is
+ *   available, resulting in OOM.
+ * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
+ *   This may lead to a performance regression compared to the normal case of using an
+ *   [[ArrayBuffer]] or [[Array]].
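+ *
+ * A minimal usage sketch (assumes an active Spark task, since the convenience constructor pulls
+ * [[TaskContext]] and [[SparkEnv]]; `inputRows` and `process` are placeholders):
+ * {{{
+ *   val buffer = new ExternalAppendOnlyUnsafeRowArray(numRowsSpillThreshold = 4096)
+ *   inputRows.foreach(row => buffer.add(row))    // rows are copied (or spilled) internally
+ *   val iterator = buffer.generateIterator()
+ *   while (iterator.hasNext) { process(iterator.next()) }
+ *   buffer.clear()                               // release in-memory and on-disk resources
+ * }}}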
+ */
+private[sql] class ExternalAppendOnlyUnsafeRowArray(
+    taskMemoryManager: TaskMemoryManager,
+    blockManager: BlockManager,
+    serializerManager: SerializerManager,
+    taskContext: TaskContext,
+    initialSize: Int,
+    pageSizeBytes: Long,
+    numRowsSpillThreshold: Int) extends Logging {
+
+  def this(numRowsSpillThreshold: Int) {
+    this(
+      TaskContext.get().taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get(),
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes,
+      numRowsSpillThreshold)
+  }
+
+  private val initialSizeOfInMemoryBuffer =
+    Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+
+  private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+    new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
+  } else {
+    null
+  }
+
+  private var spillableArray: UnsafeExternalSorter = _
+  private var numRows = 0
+
+  // A counter to keep track of the total number of modifications made to this array since its
+  // creation. It is used to invalidate iterators whenever the backing array is modified.
+  private var modificationsCount: Long = 0
+
+  private var numFieldsPerRow = 0
+
+  def length: Int = numRows
+
+  def isEmpty: Boolean = numRows == 0
+
+  /**
+   * Clears up resources (e.g. memory) held by the backing storage.
+   */
+  def clear(): Unit = {
+    if (spillableArray != null) {
+      // The last `spillableArray` of this task will be cleaned up via task completion listener
+      // inside `UnsafeExternalSorter`
+      spillableArray.cleanupResources()
+      spillableArray = null
+    } else if (inMemoryBuffer != null) {
+      inMemoryBuffer.clear()
+    }
+    numFieldsPerRow = 0
+    numRows = 0
+    modificationsCount += 1
+  }
+
+  def add(unsafeRow: UnsafeRow): Unit = {
+    if (numRows < numRowsSpillThreshold) {
+      inMemoryBuffer += unsafeRow.copy()
+    } else {
+      if (spillableArray == null) {
+        logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+          s"${classOf[UnsafeExternalSorter].getName}")
+
+        // We will not sort the rows, so prefixComparator and recordComparator are null
+        spillableArray = UnsafeExternalSorter.create(
+          taskMemoryManager,
+          blockManager,
+          serializerManager,
+          taskContext,
+          null,
+          null,
+          initialSize,
+          pageSizeBytes,
+          numRowsSpillThreshold,
+          false)
+
+        // populate with existing in-memory buffered rows
+        if (inMemoryBuffer != null) {
+          inMemoryBuffer.foreach(existingUnsafeRow =>
+            spillableArray.insertRecord(
+              existingUnsafeRow.getBaseObject,
+              existingUnsafeRow.getBaseOffset,
+              existingUnsafeRow.getSizeInBytes,
+              0,
+              false)
+          )
+          inMemoryBuffer.clear()
+        }
+        numFieldsPerRow = unsafeRow.numFields()
+      }
+
+      spillableArray.insertRecord(
+        unsafeRow.getBaseObject,
+        unsafeRow.getBaseOffset,
+        unsafeRow.getSizeInBytes,
+        0,
+        false)
+    }
+
+    numRows += 1
+    modificationsCount += 1
+  }
+
+  /**
+   * Creates an [[Iterator]] over the current rows in the array, starting from a user-provided
+   * index.
+   *
+   * If [[add()]] or [[clear()]] is called on this array after the iterator has been created, the
+   * iterator is invalidated, so clients cannot mistakenly believe they have read all the data
+   * while new rows were being added to the array.
+   */
+  def generateIterator(startIndex: Int): Iterator[UnsafeRow] = {
+    if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) {
+      throw new ArrayIndexOutOfBoundsException(
+        "Invalid `startIndex` provided for generating iterator over the array. " +
+          s"Total elements: $numRows, requested `startIndex`: $startIndex")
+    }
+
+    if (spillableArray == null) {
+      new InMemoryBufferIterator(startIndex)
+    } else {
+      new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex)
+    }
+  }
+
+  def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0)
+
+  private[this]
+  abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] {
+    private val expectedModificationsCount = modificationsCount
+
+    protected def isModified(): Boolean = expectedModificationsCount != modificationsCount
+
+    protected def throwExceptionIfModified(): Unit = {
+      if (expectedModificationsCount != modificationsCount) {
+        throw new ConcurrentModificationException(
+          s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " +
+            s"since the creation of this Iterator")
+      }
+    }
+  }
+
+  private[this] class InMemoryBufferIterator(startIndex: Int)
+    extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+    private var currentIndex = startIndex
+
+    override def hasNext(): Boolean = !isModified() && currentIndex < numRows
+
+    override def next(): UnsafeRow = {
+      throwExceptionIfModified()
+      val result = inMemoryBuffer(currentIndex)
+      currentIndex += 1
+      result
+    }
+  }
+
+  private[this] class SpillableArrayIterator(
+      iterator: UnsafeSorterIterator,
+      numFieldPerRow: Int,
+      startIndex: Int)
+    extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+    private val currentRow = new UnsafeRow(numFieldPerRow)
+
+    def init(): Unit = {
+      var i = 0
+      while (i < startIndex) {
+        if (iterator.hasNext) {
+          iterator.loadNext()
+        } else {
+          throw new ArrayIndexOutOfBoundsException(
+            "Invalid `startIndex` provided for generating iterator over the array. " +
+              s"Total elements: $numRows, requested `startIndex`: $startIndex")
+        }
+        i += 1
+      }
+    }
+
+    // Traverse up to the given [[startIndex]]
+    init()
+
+    override def hasNext(): Boolean = !isModified() && iterator.hasNext
+
+    override def next(): UnsafeRow = {
+      throwExceptionIfModified()
+      iterator.loadNext()
+      currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength)
+      currentRow
+    }
+  }
+}
+
+private[sql] object ExternalAppendOnlyUnsafeRowArray {
+  val DefaultInitialSizeOfInMemoryBuffer = 128
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8341fe2ffd078234849e38166ed3f7ade28435db..f38098695131741378149c1c9c0da3a7ad110220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 /**
  * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
  * will be much faster than building the right partition for every row in left RDD, it also
  * materialize the right RDD (in case of the right RDD is nondeterministic).
  */
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+class UnsafeCartesianRDD(
+    left : RDD[UnsafeRow],
+    right : RDD[UnsafeRow],
+    numFieldsOfRight: Int,
+    spillThreshold: Int)
   extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
 
   override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
-    // We will not sort the rows, so prefixComparator and recordComparator are null.
-    val sorter = UnsafeExternalSorter.create(
-      context.taskMemoryManager(),
-      SparkEnv.get.blockManager,
-      SparkEnv.get.serializerManager,
-      context,
-      null,
-      null,
-      1024,
-      SparkEnv.get.memoryManager.pageSizeBytes,
-      SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
-      false)
+    val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
 
     val partition = split.asInstanceOf[CartesianPartition]
-    for (y <- rdd2.iterator(partition.s2, context)) {
-      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
-    }
+    rdd2.iterator(partition.s2, context).foreach(rowArray.add)
 
-    // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
-    def createIter(): Iterator[UnsafeRow] = {
-      val iter = sorter.getIterator
-      val unsafeRow = new UnsafeRow(numFieldsOfRight)
-      new Iterator[UnsafeRow] {
-        override def hasNext: Boolean = {
-          iter.hasNext
-        }
-        override def next(): UnsafeRow = {
-          iter.loadNext()
-          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
-          unsafeRow
-        }
-      }
-    }
+    // Create an iterator from rowArray
+    def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
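+
+    // Note: `createIter()` is invoked once per row of the left partition below, so the buffered
+    // right-side rows are re-scanned (from memory or spill files) for every left-side row.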
 
     val resultIter =
       for (x <- rdd1.iterator(partition.s1, context);
            y <- createIter()) yield (x, y)
     CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
-      resultIter, sorter.cleanupResources())
+      resultIter, rowArray.clear())
   }
 }
 
@@ -97,7 +71,9 @@ case class CartesianProductExec(
     val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
-    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+    val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
+
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
     pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ca9c0ed8cec32e01e836a765542876164121cf30..bcdc4dcdf7d9904498a4291ed320a686678c28f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
+  ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.collection.BitSet
 
@@ -95,9 +96,13 @@ case class SortMergeJoinExec(
   private def createRightKeyGenerator(): Projection =
     UnsafeProjection.create(rightKeys, right.output)
 
+  private def getSpillThreshold: Int = {
+    sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-
+    val spillThreshold = getSpillThreshold
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
@@ -115,39 +120,39 @@ case class SortMergeJoinExec(
         case _: InnerLike =>
           new RowIterator {
             private[this] var currentLeftRow: InternalRow = _
-            private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
-            private[this] var currentMatchIdx: Int = -1
+            private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
+            private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
             private[this] val smjScanner = new SortMergeJoinScanner(
               createLeftKeyGenerator(),
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
             if (smjScanner.findNextInnerJoinRows()) {
               currentRightMatches = smjScanner.getBufferedMatches
               currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
+              rightMatchesIterator = currentRightMatches.generateIterator()
             }
 
             override def advanceNext(): Boolean = {
-              while (currentMatchIdx >= 0) {
-                if (currentMatchIdx == currentRightMatches.length) {
+              while (rightMatchesIterator != null) {
+                if (!rightMatchesIterator.hasNext) {
                   if (smjScanner.findNextInnerJoinRows()) {
                     currentRightMatches = smjScanner.getBufferedMatches
                     currentLeftRow = smjScanner.getStreamedRow
-                    currentMatchIdx = 0
+                    rightMatchesIterator = currentRightMatches.generateIterator()
                   } else {
                     currentRightMatches = null
                     currentLeftRow = null
-                    currentMatchIdx = -1
+                    rightMatchesIterator = null
                     return false
                   }
                 }
-                joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
-                currentMatchIdx += 1
+                joinRow(currentLeftRow, rightMatchesIterator.next())
                 if (boundCondition(joinRow)) {
                   numOutputRows += 1
                   return true
@@ -165,7 +170,8 @@ case class SortMergeJoinExec(
             bufferedKeyGenerator = createRightKeyGenerator(),
             keyOrdering,
             streamedIter = RowIterator.fromScala(leftIter),
-            bufferedIter = RowIterator.fromScala(rightIter)
+            bufferedIter = RowIterator.fromScala(rightIter),
+            spillThreshold
           )
           val rightNullRow = new GenericInternalRow(right.output.length)
           new LeftOuterIterator(
@@ -177,7 +183,8 @@ case class SortMergeJoinExec(
             bufferedKeyGenerator = createLeftKeyGenerator(),
             keyOrdering,
             streamedIter = RowIterator.fromScala(rightIter),
-            bufferedIter = RowIterator.fromScala(leftIter)
+            bufferedIter = RowIterator.fromScala(leftIter),
+            spillThreshold
           )
           val leftNullRow = new GenericInternalRow(left.output.length)
           new RightOuterIterator(
@@ -209,7 +216,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -217,14 +225,15 @@ case class SortMergeJoinExec(
               while (smjScanner.findNextInnerJoinRows()) {
                 val currentRightMatches = smjScanner.getBufferedMatches
                 currentLeftRow = smjScanner.getStreamedRow
-                var i = 0
-                while (i < currentRightMatches.length) {
-                  joinRow(currentLeftRow, currentRightMatches(i))
-                  if (boundCondition(joinRow)) {
-                    numOutputRows += 1
-                    return true
+                if (currentRightMatches != null && currentRightMatches.length > 0) {
+                  val rightMatchesIterator = currentRightMatches.generateIterator()
+                  while (rightMatchesIterator.hasNext) {
+                    joinRow(currentLeftRow, rightMatchesIterator.next())
+                    if (boundCondition(joinRow)) {
+                      numOutputRows += 1
+                      return true
+                    }
                   }
-                  i += 1
                 }
               }
               false
@@ -241,7 +250,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -249,17 +259,16 @@ case class SortMergeJoinExec(
               while (smjScanner.findNextOuterJoinRows()) {
                 currentLeftRow = smjScanner.getStreamedRow
                 val currentRightMatches = smjScanner.getBufferedMatches
-                if (currentRightMatches == null) {
+                if (currentRightMatches == null || currentRightMatches.length == 0) {
                   return true
                 }
-                var i = 0
                 var found = false
-                while (!found && i < currentRightMatches.length) {
-                  joinRow(currentLeftRow, currentRightMatches(i))
+                val rightMatchesIterator = currentRightMatches.generateIterator()
+                while (!found && rightMatchesIterator.hasNext) {
+                  joinRow(currentLeftRow, rightMatchesIterator.next())
                   if (boundCondition(joinRow)) {
                     found = true
                   }
-                  i += 1
                 }
                 if (!found) {
                   numOutputRows += 1
@@ -281,7 +290,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -290,14 +300,13 @@ case class SortMergeJoinExec(
                 currentLeftRow = smjScanner.getStreamedRow
                 val currentRightMatches = smjScanner.getBufferedMatches
                 var found = false
-                if (currentRightMatches != null) {
-                  var i = 0
-                  while (!found && i < currentRightMatches.length) {
-                    joinRow(currentLeftRow, currentRightMatches(i))
+                if (currentRightMatches != null && currentRightMatches.length > 0) {
+                  val rightMatchesIterator = currentRightMatches.generateIterator()
+                  while (!found && rightMatchesIterator.hasNext) {
+                    joinRow(currentLeftRow, rightMatchesIterator.next())
                     if (boundCondition(joinRow)) {
                       found = true
                     }
-                    i += 1
                   }
                 }
                 result.setBoolean(0, found)
@@ -376,8 +385,11 @@ case class SortMergeJoinExec(
 
     // A list to hold all matched rows from right side.
     val matches = ctx.freshName("matches")
-    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
-    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+    val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
+
+    val spillThreshold = getSpillThreshold
+
+    ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
     // Copy the left keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)
 
@@ -428,7 +440,7 @@ case class SortMergeJoinExec(
          |        }
          |        $leftRow = null;
          |      } else {
-         |        $matches.add($rightRow.copy());
+         |        $matches.add((UnsafeRow) $rightRow);
          |        $rightRow = null;;
          |      }
          |    } while ($leftRow != null);
@@ -517,8 +529,7 @@ case class SortMergeJoinExec(
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
 
-    val size = ctx.freshName("size")
-    val i = ctx.freshName("i")
+    val iterator = ctx.freshName("iterator")
     val numOutput = metricTerm(ctx, "numOutputRows")
     val (beforeLoop, condCheck) = if (condition.isDefined) {
       // Split the code of creating variables based on whether it's used by condition or not.
@@ -551,10 +562,10 @@ case class SortMergeJoinExec(
 
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
-       |  int $size = $matches.size();
        |  ${beforeLoop.trim}
-       |  for (int $i = 0; $i < $size; $i ++) {
-       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
+       |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+       |  while ($iterator.hasNext()) {
+       |    InternalRow $rightRow = (InternalRow) $iterator.next();
        |    ${condCheck.trim}
        |    $numOutput.add(1);
        |    ${consume(ctx, leftVars ++ rightVars)}
@@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner(
     bufferedKeyGenerator: Projection,
     keyOrdering: Ordering[InternalRow],
     streamedIter: RowIterator,
-    bufferedIter: RowIterator) {
+    bufferedIter: RowIterator,
+    bufferThreshold: Int) {
   private[this] var streamedRow: InternalRow = _
   private[this] var streamedRowKey: InternalRow = _
   private[this] var bufferedRow: InternalRow = _
@@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner(
    */
   private[this] var matchJoinKey: InternalRow = _
   /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
-  private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
 
   // Initialization (note: do _not_ want to advance streamed here).
   advancedBufferedToRowWithNullFreeJoinKey()
@@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner(
 
   def getStreamedRow: InternalRow = streamedRow
 
-  def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+  def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
 
   /**
    * Advances both input iterators, stopping when we have found rows with matching join keys.
@@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner(
     matchJoinKey = streamedRowKey.copy()
     bufferedMatches.clear()
     do {
-      bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+      bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
       advancedBufferedToRowWithNullFreeJoinKey()
     } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
   }
@@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator(
   protected[this] val joinedRow: JoinedRow = new JoinedRow()
 
   // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
-  private[this] var bufferIndex: Int = 0
+  private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
 
   // This iterator is initialized lazily so there should be no matches initially
   assert(smjScanner.getBufferedMatches.length == 0)
@@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator(
    * @return whether there are more rows in the stream to consume.
    */
   private def advanceStream(): Boolean = {
-    bufferIndex = 0
+    rightMatchesIterator = null
     if (smjScanner.findNextOuterJoinRows()) {
       setStreamSideOutput(smjScanner.getStreamedRow)
       if (smjScanner.getBufferedMatches.isEmpty) {
@@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator(
    */
   private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
     var foundMatch: Boolean = false
-    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
-      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+    if (rightMatchesIterator == null) {
+      rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator()
+    }
+
+    while (!foundMatch && rightMatchesIterator.hasNext) {
+      setBufferedSideOutput(rightMatchesIterator.next())
       foundMatch = boundCondition(joinedRow)
-      bufferIndex += 1
     }
     foundMatch
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
deleted file mode 100644
index ee36c8425151979b874c9427c1c584082c9da7e2..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.window
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-
-
-/**
- * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the
- * row buffer is used to materialize a partition of rows since we need to repeatedly scan these
- * rows in window function processing.
- */
-private[window] abstract class RowBuffer {
-
-  /** Number of rows. */
-  def size: Int
-
-  /** Return next row in the buffer, null if no more left. */
-  def next(): InternalRow
-
-  /** Skip the next `n` rows. */
-  def skip(n: Int): Unit
-
-  /** Return a new RowBuffer that has the same rows. */
-  def copy(): RowBuffer
-}
-
-/**
- * A row buffer based on ArrayBuffer (the number of rows is limited).
- */
-private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
-
-  private[this] var cursor: Int = -1
-
-  /** Number of rows. */
-  override def size: Int = buffer.length
-
-  /** Return next row in the buffer, null if no more left. */
-  override def next(): InternalRow = {
-    cursor += 1
-    if (cursor < buffer.length) {
-      buffer(cursor)
-    } else {
-      null
-    }
-  }
-
-  /** Skip the next `n` rows. */
-  override def skip(n: Int): Unit = {
-    cursor += n
-  }
-
-  /** Return a new RowBuffer that has the same rows. */
-  override def copy(): RowBuffer = {
-    new ArrayRowBuffer(buffer)
-  }
-}
-
-/**
- * An external buffer of rows based on UnsafeExternalSorter.
- */
-private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
-  extends RowBuffer {
-
-  private[this] val iter: UnsafeSorterIterator = sorter.getIterator
-
-  private[this] val currentRow = new UnsafeRow(numFields)
-
-  /** Number of rows. */
-  override def size: Int = iter.getNumRecords()
-
-  /** Return next row in the buffer, null if no more left. */
-  override def next(): InternalRow = {
-    if (iter.hasNext) {
-      iter.loadNext()
-      currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
-      currentRow
-    } else {
-      null
-    }
-  }
-
-  /** Skip the next `n` rows. */
-  override def skip(n: Int): Unit = {
-    var i = 0
-    while (i < n && iter.hasNext) {
-      iter.loadNext()
-      i += 1
-    }
-  }
-
-  /** Return a new RowBuffer that has the same rows. */
-  override def copy(): RowBuffer = {
-    new ExternalRowBuffer(sorter, numFields)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 80b87d5ffa7976f16b05f05ed6c31895fe0a44e7..950a6794a74a3ebd90d8d8169c0dcde95c2fd123 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 /**
  * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -284,6 +282,7 @@ case class WindowExec(
     // Unwrap the expressions and factories from the map.
     val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
     val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+    val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
 
     // Start processing.
     child.execute().mapPartitions { stream =>
@@ -310,10 +309,12 @@ case class WindowExec(
         fetchNextRow()
 
         // Manage the current partition.
-        val rows = ArrayBuffer.empty[UnsafeRow]
         val inputFields = child.output.length
-        var sorter: UnsafeExternalSorter = null
-        var rowBuffer: RowBuffer = null
+
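+        // All rows of the current partition are materialized into one spillable
+        // ExternalAppendOnlyUnsafeRowArray; `bufferIterator` walks that buffer while output
+        // rows for the partition are produced.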
+        val buffer: ExternalAppendOnlyUnsafeRowArray =
+          new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+        var bufferIterator: Iterator[UnsafeRow] = _
+
         val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
         val frames = factories.map(_(windowFunctionResult))
         val numFrames = frames.length
@@ -323,78 +324,43 @@ case class WindowExec(
           val currentGroup = nextGroup.copy()
 
           // clear last partition
-          if (sorter != null) {
-            // the last sorter of this task will be cleaned up via task completion listener
-            sorter.cleanupResources()
-            sorter = null
-          } else {
-            rows.clear()
-          }
+          buffer.clear()
 
           while (nextRowAvailable && nextGroup == currentGroup) {
-            if (sorter == null) {
-              rows += nextRow.copy()
-
-              if (rows.length >= 4096) {
-                // We will not sort the rows, so prefixComparator and recordComparator are null.
-                sorter = UnsafeExternalSorter.create(
-                  TaskContext.get().taskMemoryManager(),
-                  SparkEnv.get.blockManager,
-                  SparkEnv.get.serializerManager,
-                  TaskContext.get(),
-                  null,
-                  null,
-                  1024,
-                  SparkEnv.get.memoryManager.pageSizeBytes,
-                  SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-                    UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
-                  false)
-                rows.foreach { r =>
-                  sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
-                }
-                rows.clear()
-              }
-            } else {
-              sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
-                nextRow.getSizeInBytes, 0, false)
-            }
+            buffer.add(nextRow)
             fetchNextRow()
           }
-          if (sorter != null) {
-            rowBuffer = new ExternalRowBuffer(sorter, inputFields)
-          } else {
-            rowBuffer = new ArrayRowBuffer(rows)
-          }
 
           // Setup the frames.
           var i = 0
           while (i < numFrames) {
-            frames(i).prepare(rowBuffer.copy())
+            frames(i).prepare(buffer)
             i += 1
           }
 
           // Setup iteration
           rowIndex = 0
-          rowsSize = rowBuffer.size
+          bufferIterator = buffer.generateIterator()
         }
 
         // Iteration
         var rowIndex = 0
-        var rowsSize = 0L
 
-        override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+        override final def hasNext: Boolean =
+          (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
 
         val join = new JoinedRow
         override final def next(): InternalRow = {
           // Load the next partition if we need to.
-          if (rowIndex >= rowsSize && nextRowAvailable) {
+          if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
             fetchNextPartition()
           }
 
-          if (rowIndex < rowsSize) {
+          if (bufferIterator.hasNext) {
+            val current = bufferIterator.next()
+
             // Get the results for the window frames.
             var i = 0
-            val current = rowBuffer.next()
             while (i < numFrames) {
               frames(i).write(rowIndex, current)
               i += 1
@@ -406,7 +372,9 @@ case class WindowExec(
 
             // Return the projection.
             result(join)
-          } else throw new NoSuchElementException
+          } else {
+            throw new NoSuchElementException
+          }
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 70efc0f78ddb0bc22c6c57f9e3e63c8f0e31f01d..af2b4fb92062bf5b3c833ee3ff7d7c893fa12d2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -22,6 +22,7 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
 
 /**
@@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame {
    *
    * @param rows to calculate the frame results for.
    */
-  def prepare(rows: RowBuffer): Unit
+  def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit
 
   /**
    * Write the current results to the target row.
@@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame {
   def write(index: Int, current: InternalRow): Unit
 }
 
+object WindowFunctionFrame {
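+  /** Return the next row from `iterator`, or null if there are no rows left. */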
+  def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = {
+    if (iterator.hasNext) iterator.next() else null
+  }
+}
+
 /**
  * The offset window frame calculates frames containing LEAD/LAG statements.
  *
@@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** Index of the input row currently used for output. */
   private[this] var inputIndex = 0
@@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame(
     newMutableProjection(boundExpressions, Nil).target(target)
   }
 
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
+    inputIterator = input.generateIterator()
     // drain the first few rows if offset is larger than zero
     inputIndex = 0
     while (inputIndex < offset) {
-      input.next()
+      if (inputIterator.hasNext) inputIterator.next()
       inputIndex += 1
     }
     inputIndex = offset
   }
 
   override def write(index: Int, current: InternalRow): Unit = {
-    if (inputIndex >= 0 && inputIndex < input.size) {
-      val r = input.next()
+    if (inputIndex >= 0 && inputIndex < input.length) {
+      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
       projection(r)
     } else {
       // Use default values since the offset row does not exist.
@@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** The next row from `input`. */
   private[this] var nextRow: InternalRow = null
@@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame(
   private[this] var inputLowIndex = 0
 
   /** Prepare the frame for calculating a new partition. Reset all variables. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
-    nextRow = rows.next()
+    inputIterator = input.generateIterator()
+    nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
     inputHighIndex = 0
     inputLowIndex = 0
     buffer.clear()
@@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame(
     // the output row upper bound.
     while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
       buffer.add(nextRow.copy())
-      nextRow = input.next()
+      nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
       inputHighIndex += 1
       bufferUpdated = true
     }
@@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame(
 
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
-      processor.initialize(input.size)
+      processor.initialize(input.length)
       val iter = buffer.iterator()
       while (iter.hasNext) {
         processor.update(iter.next())
@@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Prepare the frame for calculating a new partition. Process all rows eagerly. */
-  override def prepare(rows: RowBuffer): Unit = {
-    val size = rows.size
-    processor.initialize(size)
-    var i = 0
-    while (i < size) {
-      processor.update(rows.next())
-      i += 1
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+    processor.initialize(rows.length)
+
+    val iterator = rows.generateIterator()
+    while (iterator.hasNext) {
+      processor.update(iterator.next())
     }
   }
 
@@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** The next row from `input`. */
   private[this] var nextRow: InternalRow = null
@@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
   private[this] var inputIndex = 0
 
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
-    nextRow = rows.next()
     inputIndex = 0
-    processor.initialize(input.size)
+    inputIterator = input.generateIterator()
+    if (inputIterator.hasNext) {
+      nextRow = inputIterator.next()
+    }
+
+    processor.initialize(input.length)
   }
 
   /** Write the frame columns for the current row to the given target row. */
@@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
     // the output row upper bound.
     while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
       processor.update(nextRow)
-      nextRow = input.next()
+      nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
       inputIndex += 1
       bufferUpdated = true
     }
@@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
 
   /**
    * Index of the first input row with a value equal to or greater than the lower bound of the
@@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   private[this] var inputIndex = 0
 
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
     inputIndex = 0
   }
@@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   override def write(index: Int, current: InternalRow): Unit = {
     var bufferUpdated = index == 0
 
-    // Duplicate the input to have a new iterator
-    val tmp = input.copy()
-
-    // Drop all rows from the buffer for which the input row value is smaller than
+    // Ignore all the rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    tmp.skip(inputIndex)
-    var nextRow = tmp.next()
+    val iterator = input.generateIterator(startIndex = inputIndex)
+
+    var nextRow = WindowFunctionFrame.getNextOrNull(iterator)
     while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
-      nextRow = tmp.next()
       inputIndex += 1
       bufferUpdated = true
+      nextRow = WindowFunctionFrame.getNextOrNull(iterator)
     }
 
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
-      processor.initialize(input.size)
-      while (nextRow != null) {
+      processor.initialize(input.length)
+      if (nextRow != null) {
         processor.update(nextRow)
-        nextRow = tmp.next()
+      }
+      while (iterator.hasNext) {
+        processor.update(iterator.next())
       }
       processor.evaluate(target)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2e006735d123e21eff90902a364a4594e22381ae..1a66aa85f5a02bfeea2c32c0aac0f0d6191261e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable.ListBuffer
 import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 
 class JoinSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     cartesianQueries.foreach(checkCartesianDetection)
   }
+
+  test("test SortMergeJoin (without spill)") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) {
+
+      assertNotSpilled(sparkContext, "inner join") {
+        checkAnswer(
+          sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+          Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+        )
+      }
+
+      val expected = new ListBuffer[Row]()
+      expected.append(
+        Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+        Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+        Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+      )
+      for (i <- 4 to 100) {
+        expected.append(Row(i, i.toString, null, null))
+      }
+
+      assertNotSpilled(sparkContext, "left outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData big
+              |LEFT OUTER JOIN
+              |  testData2 small
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      assertNotSpilled(sparkContext, "right outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |RIGHT OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+    }
+  }
+
+  test("test SortMergeJoin (with spill)") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+
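+      // With the spill threshold set to 0, every buffered right-side match takes the spillable
+      // code path, so each of these joins is expected to report at least one spill.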
+      assertSpilled(sparkContext, "inner join") {
+        checkAnswer(
+          sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+          Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+        )
+      }
+
+      val expected = new ListBuffer[Row]()
+      expected.append(
+        Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+        Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+        Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+      )
+      for (i <- 4 to 100) {
+        expected.append(Row(i, i.toString, null, null))
+      }
+
+      assertSpilled(sparkContext, "left outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData big
+              |LEFT OUTER JOIN
+              |  testData2 small
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      assertSpilled(sparkContext, "right outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |RIGHT OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]],
+      // so it should not cause any spill.
+      assertNotSpilled(sparkContext, "full outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |FULL OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
new file mode 100644
index 0000000000000000000000000000000000000000..00c5f2550cbb1e0aa27bf407087afb16068fbc8d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
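+/**
+ * Micro-benchmarks comparing [[ExternalAppendOnlyUnsafeRowArray]] against a plain
+ * [[ArrayBuffer]] (in-memory only) and against a raw [[UnsafeExternalSorter]].
+ */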
+object ExternalAppendOnlyUnsafeRowArrayBenchmark {
+
+  def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = {
+    val random = new java.util.Random()
+    val rows = (1 to numRows).map(_ => {
+      val row = new UnsafeRow(1)
+      row.pointTo(new Array[Byte](64), 16)
+      row.setLong(0, random.nextLong())
+      row
+    })
+
+    val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows)
+
+    // Internally, `ExternalAppendOnlyUnsafeRowArray` creates an in-memory buffer of size
+    // `min(DefaultInitialSizeOfInMemoryBuffer, numSpillThreshold)`. Mimic that here.
+    val initialSize =
+      Math.min(
+        ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+        numSpillThreshold)
+
+    benchmark.addCase("ArrayBuffer") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ArrayBuffer[UnsafeRow](initialSize)
+
+        // Internally, `ExternalAppendOnlyUnsafeRowArray` copies each row that is added.
+        // Mimic that here.
+        rows.foreach(x => array += x.copy())
+
+        var i = 0
+        val n = array.length
+        while (i < n) {
+          sum = sum + array(i).getLong(0)
+          i += 1
+        }
+        array.clear()
+      }
+    }
+
+    benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+        rows.foreach(x => array.add(x))
+
+        val iterator = array.generateIterator()
+        while (iterator.hasNext) {
+          sum = sum + iterator.next().getLong(0)
+        }
+        array.clear()
+      }
+    }
+
+    val conf = new SparkConf(false)
+    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+    conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+    val sc = new SparkContext("local", "test", conf)
+    val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+    benchmark.run()
+    sc.stop()
+  }
+
+  def testAgainstRawUnsafeExternalSorter(
+      numSpillThreshold: Int,
+      numRows: Int,
+      iterations: Int): Unit = {
+
+    val random = new java.util.Random()
+    val rows = (1 to numRows).map(_ => {
+      val row = new UnsafeRow(1)
+      row.pointTo(new Array[Byte](64), 16)
+      row.setLong(0, random.nextLong())
+      row
+    })
+
+    val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows)
+
+    benchmark.addCase("UnsafeExternalSorter") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
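+        // Build a bare UnsafeExternalSorter: the record/prefix comparators are null
+        // because rows are only appended and read back in insertion order (no sorting)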
+        val array = UnsafeExternalSorter.create(
+          TaskContext.get().taskMemoryManager(),
+          SparkEnv.get.blockManager,
+          SparkEnv.get.serializerManager,
+          TaskContext.get(),
+          null,
+          null,
+          1024,
+          SparkEnv.get.memoryManager.pageSizeBytes,
+          numSpillThreshold,
+          false)
+
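+        // Insert each row's raw bytes; the prefix arguments are unused since nothing is sorted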
+        rows.foreach(x =>
+          array.insertRecord(
+            x.getBaseObject,
+            x.getBaseOffset,
+            x.getSizeInBytes,
+            0,
+            false))
+
+        val unsafeRow = new UnsafeRow(1)
+        val iter = array.getIterator
+        while (iter.hasNext) {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+          sum = sum + unsafeRow.getLong(0)
+        }
+        array.cleanupResources()
+      }
+    }
+
+    benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+        rows.foreach(x => array.add(x))
+
+        val iterator = array.generateIterator()
+        while (iterator.hasNext) {
+          sum = sum + iterator.next().getLong(0)
+        }
+        array.clear()
+      }
+    }
+
+    val conf = new SparkConf(false)
+    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+    conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+    val sc = new SparkContext("local", "test", conf)
+    val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+    benchmark.run()
+    sc.stop()
+  }
+
+  def main(args: Array[String]): Unit = {
+
+    // ========================================================================================= //
+    // WITHOUT SPILL
+    // ========================================================================================= //
+
+    val spillThreshold = 100 * 1000
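+    // High enough that the cases below never switch to the spill-backed path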
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 1000 rows:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                   7821 / 7941         33.5          29.8       1.0X
+    ExternalAppendOnlyUnsafeRowArray              8798 / 8819         29.8          33.6       0.9X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 30000 rows:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                 19200 / 19206         25.6          39.1       1.0X
+    ExternalAppendOnlyUnsafeRowArray            19558 / 19562         25.1          39.8       1.0X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 100000 rows:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                   5949 / 6028         17.2          58.1       1.0X
+    ExternalAppendOnlyUnsafeRowArray              6078 / 6138         16.8          59.4       1.0X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10)
+
+    // ========================================================================================= //
+    // WITH SPILL
+    // ========================================================================================= //
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Spilling with 1000 rows:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    UnsafeExternalSorter                          9239 / 9470         28.4          35.2       1.0X
+    ExternalAppendOnlyUnsafeRowArray              8857 / 8909         29.6          33.8       1.0X
+    */
+    testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Spilling with 10000 rows:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    UnsafeExternalSorter                             4 /    5         39.3          25.5       1.0X
+    ExternalAppendOnlyUnsafeRowArray                 5 /    6         29.8          33.5       0.8X
+    */
+    testAgainstRawUnsafeExternalSorter(
+      UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..53c41639942b48da28e4799314398199ed915932
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext {
+  private val random = new java.util.Random()
+  private var taskContext: TaskContext = _
+
+  override def afterAll(): Unit = TaskContext.unset()
+
+  private def withExternalArray(spillThreshold: Int)
+                               (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+
+    taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+
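+    // Construct the array directly with the fake task context's memory manager so that
+    // each test can control the spill threshold explicitly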
+    val array = new ExternalAppendOnlyUnsafeRowArray(
+      taskContext.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      taskContext,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes,
+      spillThreshold)
+    try f(array) finally {
+      array.clear()
+    }
+  }
+
+  private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = {
+    val valueInserted = random.nextLong()
+
+    val row = new UnsafeRow(1)
+    row.pointTo(new Array[Byte](64), 16)
+    row.setLong(0, valueInserted)
+    array.add(row)
+    valueInserted
+  }
+
+  private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = {
+    assert(iterator.hasNext)
+    val actualRow = iterator.next()
+    assert(actualRow.getLong(0) == expectedValue)
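+    // A single-column UnsafeRow occupies 16 bytes: an 8-byte null bit set plus the 8-byte long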
+    assert(actualRow.getSizeInBytes == 16)
+  }
+
+  private def validateData(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = {
+    val iterator = array.generateIterator()
+    for (value <- expectedValues) {
+      checkIfValueExists(iterator, value)
+    }
+
+    assert(!iterator.hasNext)
+    iterator
+  }
+
+  private def populateRows(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      numRowsToBePopulated: Int): ArrayBuffer[Long] = {
+    val populatedValues = new ArrayBuffer[Long]
+    populateRows(array, numRowsToBePopulated, populatedValues)
+  }
+
+  private def populateRows(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      numRowsToBePopulated: Int,
+      populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = {
+    for (_ <- 0 until numRowsToBePopulated) {
+      populatedValues.append(insertRow(array))
+    }
+    populatedValues
+  }
+
+  private def getNumBytesSpilled: Long = {
+    TaskContext.get().taskMetrics().memoryBytesSpilled
+  }
+
+  private def assertNoSpill(): Unit = {
+    assert(getNumBytesSpilled == 0)
+  }
+
+  private def assertSpill(): Unit = {
+    assert(getNumBytesSpilled > 0)
+  }
+
+  test("insert rows less than the spillThreshold") {
+    val spillThreshold = 100
+    withExternalArray(spillThreshold) { array =>
+      assert(array.isEmpty)
+
+      val expectedValues = populateRows(array, 1)
+      assert(!array.isEmpty)
+      assert(array.length == 1)
+
+      val iterator1 = validateData(array, expectedValues)
+
+      // Add more rows (but not enough to trigger a switch to [[UnsafeExternalSorter]])
+      // Verify that NO spill has happened
+      populateRows(array, spillThreshold - 1, expectedValues)
+      assert(array.length == spillThreshold)
+      assertNoSpill()
+
+      val iterator2 = validateData(array, expectedValues)
+
+      assert(!iterator1.hasNext)
+      assert(!iterator2.hasNext)
+    }
+  }
+
+  test("insert rows more than the spillThreshold to force spill") {
+    val spillThreshold = 100
+    withExternalArray(spillThreshold) { array =>
+      val numValuesInserted = 20 * spillThreshold
+
+      assert(array.isEmpty)
+      val expectedValues = populateRows(array, 1)
+      assert(array.length == 1)
+
+      val iterator1 = validateData(array, expectedValues)
+
+      // Populate more rows to trigger spill. Verify that spill has happened
+      populateRows(array, numValuesInserted - 1, expectedValues)
+      assert(array.length == numValuesInserted)
+      assertSpill()
+
+      val iterator2 = validateData(array, expectedValues)
+      assert(!iterator2.hasNext)
+
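+      // iterator1 was created before the extra rows were added, so it is now stale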
+      assert(!iterator1.hasNext)
+      intercept[ConcurrentModificationException](iterator1.next())
+    }
+  }
+
+  test("iterator on an empty array should be empty") {
+    withExternalArray(spillThreshold = 10) { array =>
+      val iterator = array.generateIterator()
+      assert(array.isEmpty)
+      assert(array.length == 0)
+      assert(!iterator.hasNext)
+    }
+  }
+
+  test("generate iterator with negative start index") {
+    withExternalArray(spillThreshold = 2) { array =>
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
+
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array")
+      )
+    }
+  }
+
+  test("generate iterator with start index exceeding array's size (without spill)") {
+    val spillThreshold = 2
+    withExternalArray(spillThreshold) { array =>
+      populateRows(array, spillThreshold / 2)
+
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](
+          array.generateIterator(startIndex = spillThreshold * 10))
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array"))
+    }
+  }
+
+  test("generate iterator with start index exceeding array's size (with spill)") {
+    val spillThreshold = 2
+    withExternalArray(spillThreshold) { array =>
+      populateRows(array, spillThreshold * 2)
+
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](
+          array.generateIterator(startIndex = spillThreshold * 10))
+
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array"))
+    }
+  }
+
+  test("generate iterator with custom start index (without spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      val expectedValues = populateRows(array, spillThreshold)
+      val startIndex = spillThreshold / 2
+      val iterator = array.generateIterator(startIndex = startIndex)
+      for (i <- startIndex until expectedValues.length) {
+        checkIfValueExists(iterator, expectedValues(i))
+      }
+    }
+  }
+
+  test("generate iterator with custom start index (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      val expectedValues = populateRows(array, spillThreshold * 10)
+      val startIndex = spillThreshold * 2
+      val iterator = array.generateIterator(startIndex = startIndex)
+      for (i <- startIndex until expectedValues.length) {
+        checkIfValueExists(iterator, expectedValues(i))
+      }
+    }
+  }
+
+  test("test iterator invalidation (without spill)") {
+    withExternalArray(spillThreshold = 10) { array =>
+      // insert 2 rows, iterate until the first row
+      populateRows(array, 2)
+
+      var iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      // Adding more row(s) should invalidate any old iterators
+      populateRows(array, 1)
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+
+      // Clearing the array should also invalidate any old iterators
+      iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      array.clear()
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("test iterator invalidation (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate enough rows so that a spill happens
+      populateRows(array, spillThreshold * 2)
+      assertSpill()
+
+      var iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      // Adding more row(s) should invalidate any old iterators
+      populateRows(array, 1)
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+
+      // Clearing the array should also invalidate any old iterators
+      iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      array.clear()
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("clear on an empty array") {
+    withExternalArray(spillThreshold = 2) { array =>
+      val iterator = array.generateIterator()
+      assert(!iterator.hasNext)
+
+      // Clearing multiple times should not have any side effects
+      array.clear()
+      array.clear()
+      array.clear()
+      assert(array.isEmpty)
+      assert(array.length == 0)
+
+      // Clearing an empty array should also invalidate any old iterators
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("clear array (without spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate rows ... but not enough to trigger spill
+      populateRows(array, spillThreshold / 2)
+      assertNoSpill()
+
+      // Clear the array
+      array.clear()
+      assert(array.isEmpty)
+
+      // Re-populate a few rows so that there is no spill
+      // Verify the data. Verify that there was no spill
+      val expectedValues = populateRows(array, spillThreshold / 3)
+      validateData(array, expectedValues)
+      assertNoSpill()
+
+      // Populate more rows ... but still not enough to trigger a spill.
+      // Verify the data. Verify that there was no spill
+      populateRows(array, spillThreshold / 3, expectedValues)
+      validateData(array, expectedValues)
+      assertNoSpill()
+    }
+  }
+
+  test("clear array (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate enough rows to trigger spill
+      populateRows(array, spillThreshold * 2)
+      val bytesSpilled = getNumBytesSpilled
+      assert(bytesSpilled > 0)
+
+      // Clear the array
+      array.clear()
+      assert(array.isEmpty)
+
+      // Re-populate the array ... but NOT up to the point of spilling.
+      // Verify data. Verify that there was NO "extra" spill
+      val expectedValues = populateRows(array, spillThreshold / 2)
+      validateData(array, expectedValues)
+      assert(getNumBytesSpilled == bytesSpilled)
+
+      // Populate more rows to trigger spill
+      // Verify the data. Verify that there was "extra" spill
+      populateRows(array, spillThreshold * 2, expectedValues)
+      validateData(array, expectedValues)
+      assert(getNumBytesSpilled > bytesSpilled)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index afd47897ed4b22c189164ccd09e2927ae9e4d260..52e4f047225defcdcf43093c9657fac6b10f386e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.TestUtils.assertSpilled
 
 case class WindowData(month: Int, area: String, product: Int)
 
@@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
       """.stripMargin),
       Row(1, 3, null) :: Row(2, null, 4) :: Nil)
   }
+
+  test("test with low buffer spill threshold") {
+    val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+    nums.createOrReplaceTempView("nums")
+
+    val expected =
+      Row(1, 1, 1) ::
+        Row(0, 2, 3) ::
+        Row(1, 3, 6) ::
+        Row(0, 4, 10) ::
+        Row(1, 5, 15) ::
+        Row(0, 6, 21) ::
+        Row(1, 7, 28) ::
+        Row(0, 8, 36) ::
+        Row(1, 9, 45) ::
+        Row(0, 10, 55) :: Nil
+
+    val actual = sql(
+      """
+        |SELECT y, x, sum(x) OVER w1 AS running_sum
+        |FROM nums
+        |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+      """.stripMargin)
+
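+    // With the threshold set to 1, the window operator's row buffer should fall back to
+    // the spill-backed path almost immediately, so the query is expected to spill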
+    withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+      assertSpilled(sparkContext, "test with low buffer spill threshold") {
+        checkAnswer(actual, expected)
+      }
+    }
+
+    spark.catalog.dropTempView("nums")
+  }
 }