Commit c50f97da authored by Davies Liu, committed by Davies Liu

[SPARK-9943] [SQL] deserialized UnsafeHashedRelation should be serializable

When free memory in an executor runs low, cached broadcast objects need to be serialized to disk, but a deserialized UnsafeHashedRelation currently cannot be serialized again and fails with an NPE. This PR fixes that.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #8174 from davies/serialize_hashed.
parent 693949ba
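
The failure mode is a broadcast round trip: the relation is serialized on the driver, deserialized on an executor, and later has to be serialized again when the block manager spills the cached broadcast to disk. The sketch below uses only the plain JDK Externalizable machinery, no Spark classes; Payload and RoundTrip are illustrative names, not part of this patch. It shows the write, read, write-again pattern that UnsafeHashedRelation must now survive:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable,
  ObjectInput, ObjectInputStream, ObjectOutput, ObjectOutputStream}

// Illustrative stand-in for an Externalizable value such as UnsafeHashedRelation.
// The point is only that writeExternal must also work on a *deserialized* instance.
class Payload(var data: Array[Byte]) extends Externalizable {
  def this() = this(new Array[Byte](0)) // no-arg constructor required by Externalizable

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeInt(data.length)
    out.write(data)
  }

  override def readExternal(in: ObjectInput): Unit = {
    val n = in.readInt()
    data = new Array[Byte](n)
    in.readFully(data)
  }
}

object RoundTrip {
  private def serialize(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    bos.toByteArray
  }

  private def deserialize(bytes: Array[Byte]): AnyRef = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject() finally ois.close()
  }

  def main(args: Array[String]): Unit = {
    val first  = serialize(new Payload(Array[Byte](1, 2, 3))) // driver broadcasts
    val copy   = deserialize(first)                           // executor materializes it
    val second = serialize(copy)                              // later spilled to disk again; must not throw
    assert(java.util.Arrays.equals(first, second), "round trip must be stable")
  }
}

The new tests in HashedRelationSuite (below) assert exactly this: serializing the deserialized relation reproduces the original byte stream.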
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
@@ -247,40 +247,67 @@ private[joins] final class UnsafeHashedRelation(
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(hashTable.size())
-
-    val iter = hashTable.entrySet().iterator()
-    while (iter.hasNext) {
-      val entry = iter.next()
-      val key = entry.getKey
-      val values = entry.getValue
-
-      // write all the values as single byte array
-      var totalSize = 0L
-      var i = 0
-      while (i < values.length) {
-        totalSize += values(i).getSizeInBytes + 4 + 4
-        i += 1
+    if (binaryMap != null) {
+      // This could happen when a cached broadcast object need to be dumped into disk to free memory
+      out.writeInt(binaryMap.numElements())
+
+      var buffer = new Array[Byte](64)
+      def write(addr: MemoryLocation, length: Int): Unit = {
+        if (buffer.length < length) {
+          buffer = new Array[Byte](length)
+        }
+        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
+          buffer, Platform.BYTE_ARRAY_OFFSET, length)
+        out.write(buffer, 0, length)
       }
-      assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-      // [key size] [values size] [key bytes] [values bytes]
-      out.writeInt(key.getSizeInBytes)
-      out.writeInt(totalSize.toInt)
-      out.write(key.getBytes)
-
-      i = 0
-      while (i < values.length) {
-        // [num of fields] [num of bytes] [row bytes]
-        // write the integer in native order, so they can be read by UNSAFE.getInt()
-        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-          out.writeInt(values(i).numFields())
-          out.writeInt(values(i).getSizeInBytes)
-        } else {
-          out.writeInt(Integer.reverseBytes(values(i).numFields()))
-          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+
+      val iter = binaryMap.iterator()
+      while (iter.hasNext) {
+        val loc = iter.next()
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(loc.getKeyLength)
+        out.writeInt(loc.getValueLength)
+        write(loc.getKeyAddress, loc.getKeyLength)
+        write(loc.getValueAddress, loc.getValueLength)
+      }
+
+    } else {
+      assert(hashTable != null)
+      out.writeInt(hashTable.size())
+
+      val iter = hashTable.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val key = entry.getKey
+        val values = entry.getValue
+
+        // write all the values as single byte array
+        var totalSize = 0L
+        var i = 0
+        while (i < values.length) {
+          totalSize += values(i).getSizeInBytes + 4 + 4
+          i += 1
+        }
+        assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(key.getSizeInBytes)
+        out.writeInt(totalSize.toInt)
+        out.write(key.getBytes)
+
+        i = 0
+        while (i < values.length) {
+          // [num of fields] [num of bytes] [row bytes]
+          // write the integer in native order, so they can be read by UNSAFE.getInt()
+          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+            out.writeInt(values(i).numFields())
+            out.writeInt(values(i).getSizeInBytes)
+          } else {
+            out.writeInt(Integer.reverseBytes(values(i).numFields()))
+            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+          }
+          out.write(values(i).getBytes)
+          i += 1
         }
-        out.write(values(i).getBytes)
-        i += 1
       }
     }
   }
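
Side note on the Integer.reverseBytes calls retained in the else branch above: DataOutput.writeInt always emits big-endian bytes, while the matching reader uses UNSAFE.getInt in native byte order, so on little-endian machines the writer pre-swaps each int. A standalone, JDK-only illustration of that trick (NativeOrderDemo is an illustrative name, not Spark code):

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.{ByteBuffer, ByteOrder}

object NativeOrderDemo {
  def main(args: Array[String]): Unit = {
    val value = 0x0A0B0C0D

    // DataOutputStream.writeInt is specified to write big-endian.
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
      out.writeInt(value)                        // already matches native order
    } else {
      out.writeInt(Integer.reverseBytes(value))  // pre-swap so the bytes land in native order
    }
    out.flush()

    // A native-order reader (standing in for UNSAFE.getInt / Platform.getInt)
    // now recovers the original value without any further swapping.
    val readBack = ByteBuffer.wrap(bos.toByteArray).order(ByteOrder.nativeOrder()).getInt()
    assert(readBack == value)
  }
}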
@@ -102,6 +102,14 @@ class HashedRelationSuite extends SparkFunSuite {
     assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
     assert(numDataRows.value.value === data.length)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2)
+    out2.flush()
+    // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same
+    // as they are inserted
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 
   test("test serialization empty hash map") {
@@ -119,5 +127,11 @@
     val toUnsafe = UnsafeProjection.create(schema)
     val row = toUnsafe(InternalRow(0))
     assert(hashed2.get(row) === null)
+
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.writeExternal(out2)
+    out2.flush()
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 }