From f7e26d788757f917b32749856bb29feb7b4c2987 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Tue, 6 Sep 2016 10:46:31 -0700
Subject: [PATCH] [SPARK-16922] [SPARK-17211] [SQL] make the address of values
 portable in LongToUnsafeRowMap

## What changes were proposed in this pull request?

In LongToUnsafeRowMap, we use offset of a value as pointer, stored in a array also in the page for chained values. The offset is not portable, because Platform.LONG_ARRAY_OFFSET will be different with different JVM Heap size, then the deserialized LongToUnsafeRowMap will be corrupt.

This PR will change to use portable address (without Platform.LONG_ARRAY_OFFSET).

## How was this patch tested?

Added a test case with random generated keys, to improve the coverage. But this test is not a regression test, that could require a Spark cluster that have at least 32G heap in driver or executor.

Author: Davies Liu <davies@databricks.com>

Closes #14927 from davies/longmap.
---
 .../sql/execution/joins/HashedRelation.scala  | 27 ++++++---
 .../execution/joins/HashedRelationSuite.scala | 56 +++++++++++++++++++
 2 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 08975733ff..8821c0dea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -447,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   private def nextSlot(pos: Int): Int = (pos + 2) & mask
 
+  private[this] def toAddress(offset: Long, size: Int): Long = {
+    ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size
+  }
+
+  private[this] def toOffset(address: Long): Long = {
+    (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET
+  }
+
+  private[this] def toSize(address: Long): Int = {
+    (address & SIZE_MASK).toInt
+  }
+
   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> SIZE_BITS
-    val size = address & SIZE_MASK
-    resultRow.pointTo(page, offset, size.toInt)
+    resultRow.pointTo(page, toOffset(address), toSize(address))
     resultRow
   }
 
@@ -485,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> SIZE_BITS
-        val size = addr & SIZE_MASK
-        resultRow.pointTo(page, offset, size.toInt)
+        val offset = toOffset(addr)
+        val size = toSize(addr)
+        resultRow.pointTo(page, offset, size)
         addr = Platform.getLong(page, offset + size)
         resultRow
       }
@@ -554,7 +564,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
+    updateIndex(key, toAddress(offset, row.getSizeInBytes))
   }
 
   /**
@@ -562,6 +572,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   private def updateIndex(key: Long, address: Long): Unit = {
     var pos = firstSlot(key)
+    assert(numKeys < array.length / 2)
     while (array(pos) != key && array(pos + 1) != 0) {
       pos = nextSlot(pos)
     }
@@ -582,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       // there are some values for this key, put the address in the front of them.
-      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
+      val pointer = toOffset(address) + toSize(address)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 1196f5ec7b..ede63fea96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
 
+import scala.util.Random
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.serializer.KryoSerializer
@@ -197,6 +199,60 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
+  test("LongToUnsafeRowMap with random keys") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+    val N = 1000000
+    val rand = new Random
+    val keys = (0 to N).map(x => rand.nextLong()).toArray
+
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+    keys.foreach { k =>
+      map.append(k, unsafeProj(InternalRow(k)))
+    }
+    map.optimize()
+
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    map.writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    map2.readExternal(in)
+
+    val row = unsafeProj(InternalRow(0L)).copy()
+    keys.foreach { k =>
+      val r = map2.get(k, row)
+      assert(r.hasNext)
+      var c = 0
+      while (r.hasNext) {
+        val rr = r.next()
+        assert(rr.getLong(0) === k)
+        c += 1
+      }
+    }
+    var i = 0
+    while (i < N * 10) {
+      val k = rand.nextLong()
+      val r = map2.get(k, row)
+      if (r != null) {
+        assert(r.hasNext)
+        while (r.hasNext) {
+          assert(r.next().getLong(0) === k)
+        }
+      }
+      i += 1
+    }
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
-- 
GitLab