From 70fe6ce52f26904aa53bd20409db69b52bccf315 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Tue, 5 Jan 2016 18:46:52 -0800
Subject: [PATCH] [SPARK-12659] fix NPE in UnsafeExternalSorter (used by
 cartesian product)

Cartesian product use UnsafeExternalSorter without comparator to do spilling, it will NPE if spilling happens.

This bug also hitted by #10605

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #10606 from davies/fix_spilling.
---
 .../unsafe/sort/UnsafeExternalSorter.java     |  7 +++--
 .../unsafe/sort/UnsafeInMemorySorter.java     | 17 +++++-----
 .../sort/UnsafeExternalSorterSuite.java       | 31 +++++++++++++++++++
 project/MimaExcludes.scala                    |  1 +
 4 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 79d74b23ce..77d0b70bb8 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -400,6 +400,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    * after consuming this iterator.
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
+    assert(recordComparator != null);
     if (spillWriters.isEmpty()) {
       assert(inMemSorter != null);
       readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -531,18 +532,20 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    *
    * It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
+   *
+   * TODO: support forced spilling
    */
   public UnsafeSorterIterator getIterator() throws IOException {
     if (spillWriters.isEmpty()) {
       assert(inMemSorter != null);
-      return inMemSorter.getIterator();
+      return inMemSorter.getSortedIterator();
     } else {
       LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
         queue.add(spillWriter.getReader(blockManager));
       }
       if (inMemSorter != null) {
-        queue.add(inMemSorter.getIterator());
+        queue.add(inMemSorter.getSortedIterator());
       }
       return new ChainedIterator(queue);
     }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index c16cbce9a0..b7ab45675e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -99,7 +99,11 @@ public final class UnsafeInMemorySorter {
     this.consumer = consumer;
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
-    this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    if (recordComparator != null) {
+      this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    } else {
+      this.sortComparator = null;
+    }
     this.array = array;
   }
 
@@ -223,14 +227,9 @@ public final class UnsafeInMemorySorter {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    sorter.sort(array, 0, pos / 2, sortComparator);
-    return new SortedIterator(pos / 2);
-  }
-
-  /**
-   * Returns an iterator over record pointers in original order (inserted).
-   */
-  public SortedIterator getIterator() {
+    if (sortComparator != null) {
+      sorter.sort(array, 0, pos / 2, sortComparator);
+    }
     return new SortedIterator(pos / 2);
   }
 }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index e0ee281e98..32f5a1a7e6 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -369,6 +369,37 @@ public class UnsafeExternalSorterSuite {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void forcedSpillingWithoutComparator() throws Exception {
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
+      blockManager,
+      taskContext,
+      null,
+      null,
+      /* initialSize */ 1024,
+      pageSizeBytes);
+    long[] record = new long[100];
+    int recordSize = record.length * 8;
+    int n = (int) pageSizeBytes / recordSize * 3;
+    int batch = n / 4;
+    for (int i = 0; i < n; i++) {
+      record[0] = (long) i;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+      if (i % batch == batch - 1) {
+        sorter.spill();
+      }
+    }
+    UnsafeSorterIterator iter = sorter.getIterator();
+    for (int i = 0; i < n; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void testPeakMemoryUsed() throws Exception {
     final long recordLengthBytes = 8;
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8c3a40d241..940fedfa2a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,6 +40,7 @@ object MimaExcludes {
         excludePackage("org.apache.spark.rpc"),
         excludePackage("org.spark-project.jetty"),
         excludePackage("org.apache.spark.unused"),
+        excludePackage("org.apache.spark.util.collection.unsafe"),
         excludePackage("org.apache.spark.sql.catalyst"),
         excludePackage("org.apache.spark.sql.execution"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"),
-- 
GitLab