Commit a4aed717 authored by yzhou2001, committed by Davies Liu

[SPARK-14521] [SQL] StackOverflowError in Kryo when executing TPC-DS

## What changes were proposed in this pull request?

A StackOverflowError was observed in Kryo when executing TPC-DS Query 27. The Spark Thrift server disables Kryo reference tracking (if not specified in the conf); when "spark.kryo.referenceTracking" is explicitly set to true in spark-defaults.conf, the query executes successfully. The root cause is that the TaskMemoryManager inside MemoryConsumer and LongToUnsafeRowMap was not transient, so it was serialized and broadcast from within LongHashedRelation, which can create a circular reference inside Kryo. But the TaskMemoryManager is per-task and should not be passed around in the first place. This fix makes it transient.
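As an illustration (a minimal, hypothetical sketch, not the actual Spark classes): with reference tracking off, Kryo does not remember objects it has already serialized, so a cycle in the object graph recurses until the stack overflows; marking the per-task back-reference `@transient` keeps it out of the serialized graph entirely.

```scala
import com.esotericsoftware.kryo.Kryo

// Hypothetical stand-ins for MemoryConsumer / TaskMemoryManager.
class TaskState {
  var consumer: Consumer = _ // back-reference completes the cycle
}
class Consumer(val state: TaskState) // state -> consumer -> state -> ...

// The shape of the fix: a transient field is skipped by both Java
// serialization and Kryo's default FieldSerializer, so no cycle is written.
class FixedConsumer(@transient val state: TaskState)

object Repro extends App {
  val kryo = new Kryo()
  kryo.setReferences(false) // what spark.kryo.referenceTracking=false means
  // With the cycle present, kryo.writeObject(output, consumer) would recurse
  // until it throws StackOverflowError; FixedConsumer serializes fine.
}
```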

## How was this patch tested?
Ran core/test, hive/test, sql/test, catalyst/test, dev/lint-scala, dev/scalastyle, and org.apache.spark.sql.hive.execution.HiveCompatibilitySuite, plus a manual test of TPC-DS Query 27 with 1 GB of data but without the "limit 100", which would otherwise cause an NPE due to SPARK-14752.

Author: yzhou2001 <yzhou_1999@yahoo.com>

Closes #12598 from yzhou2001/master.
parent 659f635d
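For context on the diff below: the patch keeps a single encoding per relation and feeds it either Java's or Kryo's primitives as function parameters, so `writeExternal`/`readExternal` and `KryoSerializable.write`/`read` cannot drift apart. A condensed sketch of that pattern with a hypothetical `PayloadMap` class (the real implementations are `UnsafeHashedRelation` and `LongToUnsafeRowMap` below):

```scala
import java.io.{Externalizable, ObjectInput, ObjectOutput}

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

// Hypothetical class; mirrors the delegation pattern used by the patch.
class PayloadMap extends Externalizable with KryoSerializable {
  private var numFields: Int = 0
  private var bytes: Array[Byte] = Array.emptyByteArray

  // Single encoding implementation, parameterized over the output primitives.
  private def write(
      writeInt: (Int) => Unit,
      writeLong: (Long) => Unit,
      writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
    writeInt(numFields)
    writeLong(bytes.length.toLong)
    writeBuffer(bytes, 0, bytes.length)
  }

  private def read(
      readInt: () => Int,
      readLong: () => Long,
      readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
    numFields = readInt()
    bytes = new Array[Byte](readLong().toInt)
    readBuffer(bytes, 0, bytes.length)
  }

  // Java serialization path.
  override def writeExternal(out: ObjectOutput): Unit =
    write(out.writeInt, out.writeLong, out.write)
  override def readExternal(in: ObjectInput): Unit =
    read(in.readInt, in.readLong, in.readFully)

  // Kryo path: the same body, fed Kryo's primitives instead.
  override def write(kryo: Kryo, out: Output): Unit =
    write(out.writeInt, out.writeLong, out.write)
  override def read(kryo: Kryo, in: Input): Unit =
    read(in.readInt, in.readLong, in.readBytes)
}
```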
@@ -17,7 +17,10 @@
 package org.apache.spark.sql.execution.joins
 
-import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.io._
+
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
 
 import org.apache.spark.{SparkConf, SparkEnv, SparkException}
 import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
@@ -116,7 +119,7 @@ private[execution] object HashedRelation {
 private[joins] class UnsafeHashedRelation(
     private var numFields: Int,
     private var binaryMap: BytesToBytesMap)
-  extends HashedRelation with Externalizable {
+  extends HashedRelation with Externalizable with KryoSerializable {
 
   private[joins] def this() = this(0, null)  // Needed for serialization
@@ -171,10 +174,21 @@ private[joins] class UnsafeHashedRelation(
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(numFields)
+    write(out.writeInt, out.writeLong, out.write)
+  }
+
+  override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
+    write(out.writeInt, out.writeLong, out.write)
+  }
+
+  private def write(
+      writeInt: (Int) => Unit,
+      writeLong: (Long) => Unit,
+      writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+    writeInt(numFields)
     // TODO: move these into BytesToBytesMap
-    out.writeLong(binaryMap.numKeys())
-    out.writeLong(binaryMap.numValues())
+    writeLong(binaryMap.numKeys())
+    writeLong(binaryMap.numValues())
 
     var buffer = new Array[Byte](64)
     def write(base: Object, offset: Long, length: Int): Unit = {
@@ -182,25 +196,32 @@ private[joins] class UnsafeHashedRelation(
         buffer = new Array[Byte](length)
       }
       Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
-      out.write(buffer, 0, length)
+      writeBuffer(buffer, 0, length)
     }
 
     val iter = binaryMap.iterator()
     while (iter.hasNext) {
       val loc = iter.next()
       // [key size] [values size] [key bytes] [value bytes]
-      out.writeInt(loc.getKeyLength)
-      out.writeInt(loc.getValueLength)
+      writeInt(loc.getKeyLength)
+      writeInt(loc.getValueLength)
       write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
       write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    numFields = in.readInt()
+    read(in.readInt, in.readLong, in.readFully)
+  }
+
+  private def read(
+      readInt: () => Int,
+      readLong: () => Long,
+      readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+    numFields = readInt()
     resultRow = new UnsafeRow(numFields)
-    val nKeys = in.readLong()
-    val nValues = in.readLong()
+    val nKeys = readLong()
+    val nValues = readLong()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -227,16 +248,16 @@ private[joins] class UnsafeHashedRelation(
     var keyBuffer = new Array[Byte](1024)
     var valuesBuffer = new Array[Byte](1024)
     while (i < nValues) {
-      val keySize = in.readInt()
-      val valuesSize = in.readInt()
+      val keySize = readInt()
+      val valuesSize = readInt()
       if (keySize > keyBuffer.length) {
         keyBuffer = new Array[Byte](keySize)
       }
-      in.readFully(keyBuffer, 0, keySize)
+      readBuffer(keyBuffer, 0, keySize)
       if (valuesSize > valuesBuffer.length) {
         valuesBuffer = new Array[Byte](valuesSize)
       }
-      in.readFully(valuesBuffer, 0, valuesSize)
+      readBuffer(valuesBuffer, 0, valuesSize)
 
       val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
       val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
@@ -248,6 +269,10 @@ private[joins] class UnsafeHashedRelation(
       i += 1
     }
   }
+
+  override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
+    read(in.readInt, in.readLong, in.readBytes)
+  }
 }
 
 private[joins] object UnsafeHashedRelation {
@@ -324,8 +349,8 @@ private[joins] object UnsafeHashedRelation {
  *
  * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
  */
-private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
-  extends MemoryConsumer(mm) with Externalizable {
+private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, capacity: Int)
+  extends MemoryConsumer(mm) with Externalizable with KryoSerializable {
 
   // Whether the keys are stored in dense mode or not.
   private var isDense = false
@@ -624,58 +649,85 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }
   }
 
-  private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+  private def writeLongArray(
+      writeBuffer: (Array[Byte], Int, Int) => Unit,
+      arr: Array[Long],
+      len: Int): Unit = {
     val buffer = new Array[Byte](4 << 10)
     var offset: Long = Platform.LONG_ARRAY_OFFSET
     val end = len * 8L + Platform.LONG_ARRAY_OFFSET
     while (offset < end) {
       val size = Math.min(buffer.length, (end - offset).toInt)
       Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
-      out.write(buffer, 0, size)
+      writeBuffer(buffer, 0, size)
       offset += size
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = {
-    out.writeBoolean(isDense)
-    out.writeLong(minKey)
-    out.writeLong(maxKey)
-    out.writeLong(numKeys)
-    out.writeLong(numValues)
-    out.writeLong(array.length)
-    writeLongArray(out, array, array.length)
+  private def write(
+      writeBoolean: (Boolean) => Unit,
+      writeLong: (Long) => Unit,
+      writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+    writeBoolean(isDense)
+    writeLong(minKey)
+    writeLong(maxKey)
+    writeLong(numKeys)
+    writeLong(numValues)
+    writeLong(array.length)
+    writeLongArray(writeBuffer, array, array.length)
     val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
-    out.writeLong(used)
-    writeLongArray(out, page, used)
+    writeLong(used)
+    writeLongArray(writeBuffer, page, used)
+  }
+
+  override def writeExternal(output: ObjectOutput): Unit = {
+    write(output.writeBoolean, output.writeLong, output.write)
+  }
+
+  override def write(kryo: Kryo, out: Output): Unit = {
+    write(out.writeBoolean, out.writeLong, out.write)
   }
 
-  private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+  private def readLongArray(
+      readBuffer: (Array[Byte], Int, Int) => Unit,
+      length: Int): Array[Long] = {
     val array = new Array[Long](length)
     val buffer = new Array[Byte](4 << 10)
    var offset: Long = Platform.LONG_ARRAY_OFFSET
     val end = length * 8L + Platform.LONG_ARRAY_OFFSET
     while (offset < end) {
       val size = Math.min(buffer.length, (end - offset).toInt)
-      in.readFully(buffer, 0, size)
+      readBuffer(buffer, 0, size)
       Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
       offset += size
     }
     array
   }
 
-  override def readExternal(in: ObjectInput): Unit = {
-    isDense = in.readBoolean()
-    minKey = in.readLong()
-    maxKey = in.readLong()
-    numKeys = in.readLong()
-    numValues = in.readLong()
-    val length = in.readLong().toInt
+  private def read(
+      readBoolean: () => Boolean,
+      readLong: () => Long,
+      readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
+    isDense = readBoolean()
+    minKey = readLong()
+    maxKey = readLong()
+    numKeys = readLong()
+    numValues = readLong()
+    val length = readLong().toInt
     mask = length - 2
-    array = readLongArray(in, length)
-    val pageLength = in.readLong().toInt
-    page = readLongArray(in, pageLength)
+    array = readLongArray(readBuffer, length)
+    val pageLength = readLong().toInt
+    page = readLongArray(readBuffer, pageLength)
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    read(in.readBoolean, in.readLong, in.readFully)
+  }
+
+  override def read(kryo: Kryo, in: Input): Unit = {
+    read(in.readBoolean, in.readLong, in.readBytes)
+  }
 }
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream,
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -151,6 +152,40 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
+  test("Spark-14521") {
+    val ser = new KryoSerializer(
+      (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
+    val key = Seq(BoundReference(0, IntegerType, false))
+
+    // Testing Kryo serialization of HashedRelation
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+    val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+    val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
+    val longRelation2 = ser.deserialize[LongHashedRelation](ser.serialize(longRelation))
+    (0 until 100).foreach { i =>
+      val rows = longRelation2.get(i).toArray
+      assert(rows.length === 2)
+      assert(rows(0).getInt(0) === i)
+      assert(rows(0).getInt(1) === i + 1)
+      assert(rows(1).getInt(0) === i)
+      assert(rows(1).getInt(1) === i + 1)
+    }
+
+    // Testing Kryo serialization of UnsafeHashedRelation
+    val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm)
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+    out.flush()
+    val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    unsafeHashed2.writeExternal(out2)
+    out2.flush()
+
+    assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
+  }
+
   // This test require 4G heap to run, should run it manually
   ignore("build HashedRelation that is larger than 1G") {
     val unsafeProj = UnsafeProjection.create(
......
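A side note on the `writeLongArray`/`readLongArray` helpers in the diff above: they stream a `long[]` through a reusable 4 KB byte buffer rather than writing one element at a time. A self-contained sketch of the same chunking idea, substituting `java.nio.ByteBuffer` for Spark's internal `Platform.copyMemory` (names here are illustrative, not Spark APIs):

```scala
import java.io.{DataInput, DataOutput}
import java.nio.ByteBuffer

object LongArrayIO {
  private val ChunkBytes = 4 << 10 // 4 KB staging buffer, as in the patch

  def writeLongArray(out: DataOutput, arr: Array[Long], len: Int): Unit = {
    val buffer = new Array[Byte](ChunkBytes)
    val bb = ByteBuffer.wrap(buffer) // big-endian; Spark copies in native order
    var i = 0
    while (i < len) {
      bb.clear()
      val n = math.min(ChunkBytes / 8, len - i) // longs that fit in one chunk
      var j = 0
      while (j < n) { bb.putLong(arr(i + j)); j += 1 }
      out.write(buffer, 0, n * 8)
      i += n
    }
  }

  def readLongArray(in: DataInput, length: Int): Array[Long] = {
    val arr = new Array[Long](length)
    val buffer = new Array[Byte](ChunkBytes)
    var i = 0
    while (i < length) {
      val n = math.min(ChunkBytes / 8, length - i)
      in.readFully(buffer, 0, n * 8) // refill the staging buffer
      val bb = ByteBuffer.wrap(buffer, 0, n * 8)
      var j = 0
      while (j < n) { arr(i + j) = bb.getLong(); j += 1 }
      i += n
    }
    arr
  }
}
```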