diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0427db4e3bf2550818a2fd401f9e8b71fb65c9ce..b280c76c70a61a5c492bc2b132d18bb4294b9077 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -173,8 +173,8 @@ private[joins] class UnsafeHashedRelation(
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeInt(numFields)
     // TODO: move these into BytesToBytesMap
-    out.writeInt(binaryMap.numKeys())
-    out.writeInt(binaryMap.numValues())
+    out.writeLong(binaryMap.numKeys())
+    out.writeLong(binaryMap.numValues())

     var buffer = new Array[Byte](64)
     def write(base: Object, offset: Long, length: Int): Unit = {
@@ -199,8 +199,8 @@ private[joins] class UnsafeHashedRelation(
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     numFields = in.readInt()
     resultRow = new UnsafeRow(numFields)
-    val nKeys = in.readInt()
-    val nValues = in.readInt()
+    val nKeys = in.readLong()
+    val nValues = in.readLong()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -345,16 +345,20 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap

   // The page to store all bytes of UnsafeRow and the pointer to next rows.
   // [row1][pointer1] [row2][pointer2]
-  private var page: Array[Byte] = null
+  private var page: Array[Long] = null

   // Current write cursor in the page.
-  private var cursor = Platform.BYTE_ARRAY_OFFSET
+  private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+
+  // The number of bits for size in address
+  private val SIZE_BITS = 28
+  private val SIZE_MASK = 0xfffffff

   // The total number of values of all keys.
-  private var numValues = 0
+  private var numValues = 0L

   // The number of unique keys.
-  private var numKeys = 0
+  private var numKeys = 0L

   // needed by serializer
   def this() = {
@@ -390,7 +394,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       acquireMemory(n * 2 * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
-      page = new Array[Byte](1 << 20) // 1M bytes
+      page = new Array[Long](1 << 17) // 1M bytes
     }
   }

@@ -406,7 +410,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   /**
    * Returns total memory consumption.
    */
-  def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+  def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L

   /**
    * Returns the first slot of array that store the keys (sparse mode).
@@ -422,8 +426,8 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   private def nextSlot(pos: Int): Int = (pos + 2) & mask

   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> 32
-    val size = address & 0xffffffffL
+    val offset = address >>> SIZE_BITS
+    val size = address & SIZE_MASK
     resultRow.pointTo(page, offset, size.toInt)
     resultRow
   }
@@ -450,15 +454,15 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   }

   /**
-   * Returns an interator of UnsafeRow for multiple linked values.
+   * Returns an iterator of UnsafeRow for multiple linked values.
    */
   private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     new Iterator[UnsafeRow] {
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> 32
-        val size = addr & 0xffffffffL
+        val offset = addr >>> SIZE_BITS
+        val size = addr & SIZE_MASK
         resultRow.pointTo(page, offset, size.toInt)
         addr = Platform.getLong(page, offset + size)
         resultRow
@@ -491,6 +495,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    * Appends the key and row into this map.
    */
   def append(key: Long, row: UnsafeRow): Unit = {
+    val sizeInBytes = row.getSizeInBytes
+    if (sizeInBytes >= (1 << SIZE_BITS)) {
+      sys.error("Does not support a row that is larger than 256M")
+    }
+
     if (key < minKey) {
       minKey = key
     }
@@ -499,16 +508,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }

     // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
       val used = page.length
-      if (used * 2L > (1L << 31)) {
-        sys.error("Can't allocate a page that is larger than 2G")
+      if (used >= (1 << 30)) {
+        sys.error("Cannot build a HashedRelation that is larger than 8G")
       }
-      acquireMemory(used * 2)
-      val newPage = new Array[Byte](used * 2)
-      System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+      acquireMemory(used * 8L * 2)
+      val newPage = new Array[Long](used * 2)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used)
+      freeMemory(used * 8)
     }

     // copy the bytes of UnsafeRow
@@ -518,7 +528,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
   }

   /**
@@ -536,11 +546,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       numKeys += 1
       if (numKeys * 4 > array.length) {
         // reach half of the capacity
-        growArray()
+        if (array.length < (1 << 30)) {
+          // Cannot allocate an array with 2G elements
+          growArray()
+        } else if (numKeys > array.length / 2 * 0.75) {
+          // The fill ratio should be less than 0.75
+          sys.error("Cannot build a HashedRelation with more than 1/3 billion unique keys")
+        }
       }
     } else {
       // there are some values for this key, put the address in the front of them.
-      val pointer = (address >>> 32) + (address & 0xffffffffL)
+      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }
@@ -550,7 +566,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     var old_array = array
     val n = array.length
     numKeys = 0
-    acquireMemory(n * 2 * 8)
+    acquireMemory(n * 2 * 8L)
     array = new Array[Long](n * 2)
     mask = n * 2 - 2
     var i = 0
@@ -599,7 +615,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length)
+      freeMemory(page.length * 8)
       page = null
     }
     if (array != null) {
@@ -608,52 +624,58 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }
   }

+  private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+    val buffer = new Array[Byte](4 << 10)
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = len * 8L + Platform.LONG_ARRAY_OFFSET
+    while (offset < end) {
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+      out.write(buffer, 0, size)
+      offset += size
+    }
+  }
+
   override def writeExternal(out: ObjectOutput): Unit = {
     out.writeBoolean(isDense)
     out.writeLong(minKey)
     out.writeLong(maxKey)
-    out.writeInt(numKeys)
-    out.writeInt(numValues)
+    out.writeLong(numKeys)
+    out.writeLong(numValues)
+
+    out.writeLong(array.length)
+    writeLongArray(out, array, array.length)
+    val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+    out.writeLong(used)
+    writeLongArray(out, page, used)
+  }

-    out.writeInt(array.length)
+  private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+    val array = new Array[Long](length)
     val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = length * 8L + Platform.LONG_ARRAY_OFFSET
     while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
-      out.write(buffer, 0, size)
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      in.readFully(buffer, 0, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
       offset += size
     }
-
-    val used = cursor - Platform.BYTE_ARRAY_OFFSET
-    out.writeInt(used)
-    out.write(page, 0, used)
+    array
   }

   override def readExternal(in: ObjectInput): Unit = {
     isDense = in.readBoolean()
     minKey = in.readLong()
     maxKey = in.readLong()
-    numKeys = in.readInt()
-    numValues = in.readInt()
+    numKeys = in.readLong()
+    numValues = in.readLong()

-    val length = in.readInt()
-    array = new Array[Long](length)
+    val length = in.readLong().toInt
     mask = length - 2
-    val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = length * 8 + Platform.LONG_ARRAY_OFFSET
-    while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      in.readFully(buffer, 0, size)
-      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
-      offset += size
-    }
-
-    val numBytes = in.readInt()
-    page = new Array[Byte](numBytes)
-    in.readFully(page)
+    array = readLongArray(in, length)
+    val pageLength = in.readLong().toInt
+    page = readLongArray(in, pageLength)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 371a9ed617d65c0c90811764893289a869fad0a9..3ee25c0996035e34ad656b53afce7a48ac108bd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer

 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -149,4 +150,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
       assert(rows(1).getInt(1) === i + 1)
     }
   }
+
+  // This test requires a 4G heap to run, so it should be run manually
+  ignore("build HashedRelation that is larger than 1G") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+
+    val unsafeRelation = UnsafeHashedRelation(rows, key, 1000, mm)
+    assert(unsafeRelation.estimatedSize > (2L << 30))
+    unsafeRelation.close()
+
+    val rows2 = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val longRelation = LongHashedRelation(rows2, key, 1000, mm)
+    assert(longRelation.estimatedSize > (2L << 30))
+    longRelation.close()
+  }
 }
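
Reviewer note: the heart of this change is the new value-address layout in LongToUnsafeRowMap. An address is still a single Long, but the low SIZE_BITS = 28 bits now hold the row size in bytes and the high 36 bits hold the row's byte offset inside the page, where the old layout split the Long 32/32 over an Array[Byte] page. A minimal standalone sketch of the encoding follows; the object and method names are illustrative, not part of the patch:

object AddressEncoding {
  val SIZE_BITS = 28
  val SIZE_MASK = 0xfffffff

  // Pack a page offset and a row size into one Long, as append() passes to updateIndex().
  def pack(offset: Long, sizeInBytes: Int): Long = {
    // A row must fit in 28 bits; this is the "256M" guard in append().
    require(sizeInBytes < (1 << SIZE_BITS), "row larger than 256M")
    (offset << SIZE_BITS) | sizeInBytes
  }

  // Unpack, as getRow() and valueIter() do.
  def offsetOf(address: Long): Long = address >>> SIZE_BITS
  def sizeOf(address: Long): Int = (address & SIZE_MASK).toInt

  def main(args: Array[String]): Unit = {
    val addr = pack(offset = (1L << 33) + 16, sizeInBytes = 100) // offsets beyond 4G now fit
    assert(offsetOf(addr) == (1L << 33) + 16)
    assert(sizeOf(addr) == 100)
  }
}

Shrinking the size field from 32 to 28 bits is what lets the offset field address the new page ceiling of 1 << 30 longs (8G); the old 32-bit offset could not reach past 4G, hence the previous 2G page limit.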
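Reviewer note on the key-count guard added to updateIndex(): growArray() may only run while array.length < 1 << 30, each key occupies two Long slots, and insertion fails once numKeys exceeds array.length / 2 * 0.75. Worked out at the cap:

// The bound behind the "1/3 billion unique keys" error message.
val maxArrayLen = 1 << 30              // growArray() refuses to double past this
val maxKeys = (maxArrayLen / 2) * 0.75 // two Long slots per key, 0.75 max fill ratio
assert(maxKeys.toLong == 402653184L)   // about 0.4 billion keys

So the error message is conservative: the hard limit is roughly 402 million unique keys, slightly above a third of a billion.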
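Reviewer note: the new writeLongArray/readLongArray helpers stream the Long array through a reusable 4 KB byte buffer rather than materializing one large byte[]. Below is a self-contained sketch of the same chunked round trip, rewritten against java.nio instead of Platform.copyMemory so it runs anywhere; the object name is illustrative:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer

object ChunkedLongArrayIO {
  def writeLongArray(out: DataOutputStream, arr: Array[Long], len: Int): Unit = {
    val buffer = new Array[Byte](4 << 10)            // 4 KB staging buffer, as in the patch
    var i = 0
    while (i < len) {
      val n = math.min(buffer.length / 8, len - i)   // how many longs fit in one chunk
      ByteBuffer.wrap(buffer).asLongBuffer().put(arr, i, n)
      out.write(buffer, 0, n * 8)
      i += n
    }
  }

  def readLongArray(in: DataInputStream, length: Int): Array[Long] = {
    val arr = new Array[Long](length)
    val buffer = new Array[Byte](4 << 10)
    var i = 0
    while (i < length) {
      val n = math.min(buffer.length / 8, length - i)
      in.readFully(buffer, 0, n * 8)
      ByteBuffer.wrap(buffer).asLongBuffer().get(arr, i, n)
      i += n
    }
    arr
  }

  def main(args: Array[String]): Unit = {
    val original = Array.tabulate(100000)(i => i.toLong * 31)
    val bytes = new ByteArrayOutputStream()
    writeLongArray(new DataOutputStream(bytes), original, original.length)
    val copy = readLongArray(
      new DataInputStream(new ByteArrayInputStream(bytes.toByteArray)), original.length)
    assert(original.sameElements(copy))
  }
}

Caveat: ByteBuffer defaults to big-endian while Platform.copyMemory copies in native byte order, so this sketch demonstrates the chunking pattern only, not byte-for-byte wire compatibility with the patch.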