Commit bd11b01e authored by Sandy Ryza, committed by Patrick Wendell

[SPARK-7896] Allow ChainedBuffer to store more than 2 GB

Author: Sandy Ryza <sandy@cloudera.com>

Closes #6440 from sryza/sandy-spark-7896 and squashes the following commits:

49d8a0d [Sandy Ryza] Fix bug introduced when reading over record boundaries
6006856 [Sandy Ryza] Fix overflow issues
006b4b2 [Sandy Ryza] Fix scalastyle by removing non ascii characters
8b000ca [Sandy Ryza] Add ascii art to describe layout of data in metaBuffer
f2053c0 [Sandy Ryza] Fix negative overflow issue
0368c78 [Sandy Ryza] Initialize size as 0
a5a4820 [Sandy Ryza] Use explicit types for all numbers in ChainedBuffer
b7e0213 [Sandy Ryza] SPARK-7896. Allow ChainedBuffer to store more than 2 GB
parent 852f4de2
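
As background for the diff that follows, here is a minimal, self-contained sketch of the addressing scheme this patch moves ChainedBuffer to: positions become Longs, storage stays a list of fixed power-of-two byte-array chunks, and a position is split into a chunk index plus an offset within that chunk using the same shift arithmetic as the patched code. The object name and chunk size below are illustrative, not part of the change.

object ChunkAddressingSketch {
  def main(args: Array[String]): Unit = {
    val chunkSize = 4 * 1024 * 1024                         // hypothetical 4 MB chunk size
    val chunkSizeLog2 = java.lang.Long.numberOfTrailingZeros(
      java.lang.Long.highestOneBit(chunkSize))              // 22 for 4 MB chunks
    assert((1 << chunkSizeLog2) == chunkSize, "chunk size must be a power of two")

    val pos: Long = 3L * 1024 * 1024 * 1024 + 123           // a byte position past 2 GB
    val chunkIndex: Int = (pos >> chunkSizeLog2).toInt      // which chunk holds the byte
    val posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt // offset inside it
    println(s"pos=$pos -> chunk $chunkIndex, offset $posInChunk")
  }
}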
ChainedBuffer.scala
@@ -28,11 +28,13 @@ import scala.collection.mutable.ArrayBuffer
  * occupy a contiguous segment of memory.
  */
 private[spark] class ChainedBuffer(chunkSize: Int) {
-  private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
-  assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
+  private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
+    java.lang.Long.highestOneBit(chunkSize))
+  assert((1 << chunkSizeLog2) == chunkSize,
     s"ChainedBuffer chunk size $chunkSize must be a power of two")
+
   private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
-  private var _size: Int = _
+  private var _size: Long = 0

   /**
    * Feed bytes from this buffer into a BlockObjectWriter.
@@ -41,16 +43,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param os OutputStream to read into.
    * @param len Number of bytes to read.
    */
-  def read(pos: Int, os: OutputStream, len: Int): Unit = {
+  def read(pos: Long, os: OutputStream, len: Int): Unit = {
     if (pos + len > _size) {
       throw new IndexOutOfBoundsException(
         s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
     }
-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toRead = math.min(len - written, chunkSize - posInChunk)
+      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
       os.write(chunks(chunkIndex), posInChunk, toRead)
       written += toRead
       chunkIndex += 1
@@ -66,16 +68,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param offs Offset in the byte array to read to.
    * @param len Number of bytes to read.
    */
-  def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+  def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (pos + len > _size) {
       throw new IndexOutOfBoundsException(
         s"Read of $len bytes at position $pos would go past size of buffer")
     }
-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toRead = math.min(len - written, chunkSize - posInChunk)
+      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
       System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
       written += toRead
       chunkIndex += 1
@@ -91,22 +93,22 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param offs Offset in the byte array to write from.
    * @param len Number of bytes to write.
    */
-  def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+  def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (pos > _size) {
       throw new IndexOutOfBoundsException(
         s"Write at position $pos starts after end of buffer ${_size}")
     }
     // Grow if needed
-    val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
+    val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
     while (endChunkIndex >= chunks.length) {
       chunks += new Array[Byte](chunkSize)
     }

-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toWrite = math.min(len - written, chunkSize - posInChunk)
+      val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
       System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
       written += toWrite
       chunkIndex += 1
@@ -119,19 +121,19 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
   /**
    * Total size of buffer that can be written to without allocating additional memory.
    */
-  def capacity: Int = chunks.size * chunkSize
+  def capacity: Long = chunks.size.toLong * chunkSize

   /**
    * Size of the logical buffer.
    */
-  def size: Int = _size
+  def size: Long = _size
 }

 /**
  * Output stream that writes to a ChainedBuffer.
  */
 private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
-  private var pos = 0
+  private var pos: Long = 0

   override def write(b: Int): Unit = {
     throw new UnsupportedOperationException()
PartitionedSerializedPairBuffer.scala
@@ -41,6 +41,13 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
  *
  * Currently, only sorting by partition is supported.
  *
+ * Each record is laid out inside the metaBuffer as follows. keyStart, a long, is split across
+ * two integers:
+ *
+ * +-------------+------------+------------+-------------+
+ * |          keyStart        |  keyValLen | partitionId |
+ * +-------------+------------+------------+-------------+
+ *
  * @param metaInitialRecords The initial number of entries in the metadata buffer.
  * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
  * @param serializerInstance the serializer used for serializing inserted records.
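
To make the four-int record layout above concrete, here is a self-contained sketch that writes one record's metadata into an IntBuffer in the same order the patched insert() uses and reads keyStart back the way getKeyStartPos does. The record values (keyStart, length, partition) are made up for illustration.

import java.nio.IntBuffer

object MetaBufferLayoutSketch {
  val KEY_START = 0          // keyStart's low and high words occupy int slots 0 and 1
  val KEY_VAL_LEN = 2
  val PARTITION = 3
  val RECORD_SIZE = PARTITION + 1

  def main(args: Array[String]): Unit = {
    val metaBuffer = IntBuffer.allocate(RECORD_SIZE)
    val keyStart: Long = 3L * 1024 * 1024 * 1024            // a record starting past 2 GB in kvBuffer
    val keyValLen = 48                                      // made-up serialized key+value length
    val partition = 7

    // Same order as the patched insert(): keyStart split across two ints, then length, then partition.
    metaBuffer.put(keyStart.toInt)
    metaBuffer.put((keyStart >> 32).toInt)
    metaBuffer.put(keyValLen)
    metaBuffer.put(partition)

    // Same reassembly as getKeyStartPos.
    val lower32 = metaBuffer.get(KEY_START)
    val upper32 = metaBuffer.get(KEY_START + 1)
    val recovered = (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
    assert(recovered == keyStart)
    println(s"keyStart=$recovered len=${metaBuffer.get(KEY_VAL_LEN)} partition=${metaBuffer.get(PARTITION)}")
  }
}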
@@ -68,19 +75,15 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     }
     val keyStart = kvBuffer.size
-    if (keyStart < 0) {
-      throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes")
-    }
     kvSerializationStream.writeKey[Any](key)
     kvSerializationStream.flush()
-    val valueStart = kvBuffer.size
     kvSerializationStream.writeValue[Any](value)
     kvSerializationStream.flush()
-    val valueEnd = kvBuffer.size
+    val keyValLen = (kvBuffer.size - keyStart).toInt

-    metaBuffer.put(keyStart)
-    metaBuffer.put(valueStart)
-    metaBuffer.put(valueEnd)
+    // keyStart, a long, gets split across two ints
+    metaBuffer.put(keyStart.toInt)
+    metaBuffer.put((keyStart >> 32).toInt)
+    metaBuffer.put(keyValLen)
     metaBuffer.put(partition)
   }
@@ -114,7 +117,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     }
   }

-  override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity
+  override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity

   override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
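
A side note on the estimateSize change above: metaBuffer.capacity is an Int, so multiplying by the Int literal 4 overflows before the product is widened to Long once the capacity exceeds Int.MaxValue / 4 (roughly 537 million ints); multiplying by 4L promotes the arithmetic to Long first. A quick REPL-style illustration, with a made-up capacity value:

val capacityInts = 600 * 1000 * 1000   // hypothetical metaBuffer.capacity, an Int
val wrong: Long = capacityInts * 4     // Int multiply overflows first: -1894967296
val right: Long = capacityInts * 4L    // Long multiply: 2400000000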
@@ -128,10 +131,10 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
       var pos = 0
       def writeNext(writer: BlockObjectWriter): Unit = {
-        val keyStart = metaBuffer.get(pos + KEY_START)
-        val valueEnd = metaBuffer.get(pos + VAL_END)
+        val keyStart = getKeyStartPos(metaBuffer, pos)
+        val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
         pos += RECORD_SIZE
-        kvBuffer.read(keyStart, writer, valueEnd - keyStart)
+        kvBuffer.read(keyStart, writer, keyValLen)
         writer.recordWritten()
       }

       def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
@@ -163,9 +166,11 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
 private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
     extends InputStream {

+  import PartitionedSerializedPairBuffer._
+
   private var metaBufferPos = 0
   private var kvBufferPos =
-    if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0
+    if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0

   override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
@@ -173,13 +178,14 @@ private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
     if (metaBufferPos >= metaBuffer.position) {
       return -1
     }
-    val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos
+    val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
+      (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
     val toRead = math.min(bytesRemainingInRecord, len)
     kvBuffer.read(kvBufferPos, bytes, offs, toRead)
     if (toRead == bytesRemainingInRecord) {
       metaBufferPos += RECORD_SIZE
       if (metaBufferPos < metaBuffer.position) {
-        kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START)
+        kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
       }
     } else {
       kvBufferPos += toRead
@@ -246,9 +252,14 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
 }

 private[spark] object PartitionedSerializedPairBuffer {
-  val KEY_START = 0
-  val VAL_START = 1
-  val VAL_END = 2
+  val KEY_START = 0 // keyStart, a long, gets split across two ints
+  val KEY_VAL_LEN = 2
   val PARTITION = 3
-  val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata
+  val RECORD_SIZE = PARTITION + 1 // num ints of metadata
+
+  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
+    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
+    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
+    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
+  }
 }
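
One subtlety in getKeyStartPos above is the & 0xFFFFFFFFL mask: once keyStart passes 2^31, its low word is stored as a negative Int, and a bare toLong would sign-extend it with ones. The mask keeps only the low 32 bits before the high word is OR-ed in. The following small check of that round trip uses illustrative helper names; split and join mirror the put and get logic in the patch.

object KeyStartSignSketch {
  def split(keyStart: Long): (Int, Int) =
    (keyStart.toInt, (keyStart >> 32).toInt)                // low word, high word, as in insert()

  def join(lower32: Int, upper32: Int): Long =
    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)        // mask stops sign extension, as in getKeyStartPos

  def main(args: Array[String]): Unit = {
    val keyStart = 6L * 1024 * 1024 * 1024 + 3              // low 32 bits read back as a negative Int
    val (lo, hi) = split(keyStart)
    assert(join(lo, hi) == keyStart)
    println(s"$keyStart -> lo=$lo, hi=$hi -> ${join(lo, hi)}")
  }
}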