diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 754832b8a4ca7ea221cf595a472bf4b09615a811..b7bc087855b9f360a443ca1c4ce05047ade5a5d2 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } + + /** + * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied + * registrator explicitly turns auto-reset off. + */ + def getAutoReset(): Boolean = { + val field = classOf[Kryo].getDeclaredField("autoReset") + field.setAccessible(true) + field.get(kryo).asInstanceOf[Boolean] + } } /** diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ca6e971d227fb793b8c75c43ffde2813888d280e..c381672a4f588f73488d7f2c116a0f064b707fb1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -101,7 +101,12 @@ abstract class SerializerInstance { */ @DeveloperApi abstract class SerializationStream { + /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream + /** Writes the object representing the key of a key-value pair. */ + def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key) + /** Writes the object representing the value of a key-value pair. */ + def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit def close(): Unit @@ -120,7 +125,12 @@ abstract class SerializationStream { */ @DeveloperApi abstract class DeserializationStream { + /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T + /** Reads the object representing the key of a key-value pair. */ + def readKey[T: ClassTag](): T = readObject[T]() + /** Reads the object representing the value of a key-value pair. */ + def readValue[T: ClassTag](): T = readObject[T]() def close(): Unit /** @@ -141,4 +151,25 @@ abstract class DeserializationStream { DeserializationStream.this.close() } } + + /** + * Read the elements of this stream through an iterator over key-value pairs. This can only be + * called once, as reading each element will consume data from the input source. 
+ */ + def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] { + override protected def getNext() = { + try { + (readKey[Any](), readValue[Any]()) + } catch { + case eof: EOFException => { + finished = true + null + } + } + } + + override protected def close() { + DeserializationStream.this.close() + } + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 755f17d6aa15a0ff35759e3926a39146f7c60b03..cd27c9e07a3cd122a47d2beab30f995696c2fa08 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V]( for (elem <- iter) { val bucketId = dep.partitioner.getPartition(elem._1) - shuffle.writers(bucketId).write(elem) + shuffle.writers(bucketId).write(elem._1, elem._2) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 14833791f7a4d0aab0884fbe31473cb7993ef07d..499dd97c0656a06a07cb6b226066d13368c9d90f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.Utils * This interface does not support concurrent writes. Also, once the writer has * been opened, it cannot be reopened again. */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { +private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { def open(): BlockObjectWriter @@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def revertPartialWritesAndClose() /** - * Writes an object. + * Writes a key-value pair. */ - def write(value: Any) + def write(key: Any, value: Any) + + /** + * Notify the writer that a record worth of bytes has been written with writeBytes. + */ + def recordWritten() /** * Returns the file segment of committed data that this Writer has written. 
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(value: Any) { + override def write(key: Any, value: Any) { + if (!initialized) { + open() + } + + objOut.writeKey(key) + objOut.writeValue(value) + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) + + if (numRecordsWritten % 32 == 0) { + updateBytesWritten() + } + } + + override def write(b: Int): Unit = throw new UnsupportedOperationException() + + override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { open() } - objOut.writeObject(value) + bs.write(kvBytes, offs, len) + } + + override def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter( } // For testing - private[spark] def flush() { + private[spark] override def flush() { objOut.flush() bs.flush() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f3379521d55e24c08bd5f7bcaacbd0418ac14dc0..d0faab62c9e9eaa2748d1b30c50f0f7c539355b4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,14 +17,12 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.{SerializerInstance, Serializer} @@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator( // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { is0 => val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asIterator + val iter = serializerInstance.deserializeStream(is).asKeyValueIterator CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null // so we don't release it again in cleanup. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala new file mode 100644 index 0000000000000000000000000000000000000000..a60bffe611f14d3d7a2b7cb9ddad9a2f73001530 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.OutputStream + +import scala.collection.mutable.ArrayBuffer + +/** + * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The + * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts + * of memory and needing to copy the full contents. The disadvantage is that the contents don't + * occupy a contiguous segment of memory. + */ +private[spark] class ChainedBuffer(chunkSize: Int) { + private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt + assert(math.pow(2, chunkSizeLog2).toInt == chunkSize, + s"ChainedBuffer chunk size $chunkSize must be a power of two") + private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() + private var _size: Int = _ + + /** + * Feed bytes from this buffer into a BlockObjectWriter. + * + * @param pos Offset in the buffer to read from. + * @param os OutputStream to read into. + * @param len Number of bytes to read. + */ + def read(pos: Int, os: OutputStream, len: Int): Unit = { + if (pos + len > _size) { + throw new IndexOutOfBoundsException( + s"Read of $len bytes at position $pos would go past size ${_size} of buffer") + } + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toRead = math.min(len - written, chunkSize - posInChunk) + os.write(chunks(chunkIndex), posInChunk, toRead) + written += toRead + chunkIndex += 1 + posInChunk = 0 + } + } + + /** + * Read bytes from this buffer into a byte array. + * + * @param pos Offset in the buffer to read from. + * @param bytes Byte array to read into. + * @param offs Offset in the byte array to read to. + * @param len Number of bytes to read. + */ + def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + if (pos + len > _size) { + throw new IndexOutOfBoundsException( + s"Read of $len bytes at position $pos would go past size of buffer") + } + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toRead = math.min(len - written, chunkSize - posInChunk) + System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) + written += toRead + chunkIndex += 1 + posInChunk = 0 + } + } + + /** + * Write bytes from a byte array into this buffer. + * + * @param pos Offset in the buffer to write to. + * @param bytes Byte array to write from. + * @param offs Offset in the byte array to write from. + * @param len Number of bytes to write. 
+ */ + def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + if (pos > _size) { + throw new IndexOutOfBoundsException( + s"Write at position $pos starts after end of buffer ${_size}") + } + // Grow if needed + val endChunkIndex = (pos + len - 1) >> chunkSizeLog2 + while (endChunkIndex >= chunks.length) { + chunks += new Array[Byte](chunkSize) + } + + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toWrite = math.min(len - written, chunkSize - posInChunk) + System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) + written += toWrite + chunkIndex += 1 + posInChunk = 0 + } + + _size = math.max(_size, pos + len) + } + + /** + * Total size of buffer that can be written to without allocating additional memory. + */ + def capacity: Int = chunks.size * chunkSize + + /** + * Size of the logical buffer. + */ + def size: Int = _size +} + +/** + * Output stream that writes to a ChainedBuffer. + */ +private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { + private var pos = 0 + + override def write(b: Int): Unit = { + throw new UnsupportedOperationException() + } + + override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { + chainedBuffer.write(pos, bytes, offs, len) + pos += len + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f91204956390604cd53d8418e4c489a2065f266d..b8509731450777c931b6f2bd58593e436a38ae42 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C]( val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { val kv = it.next() - writer.write(kv) + writer.write(kv._1, kv._2) objectsWritten += 1 if (objectsWritten == serializerBatchSize) { @@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C]( */ private def readNextItem(): (K, C) = { try { - val item = deserializeStream.readObject().asInstanceOf[(K, C)] + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] + val item = (k, c) objectsRead += 1 if (objectsRead == serializerBatchSize) { objectsRead = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4ed8a740f99db04ff63eb214cbae6a7bc5a683ad..b7306cd5519187c3bc6de64d88ee598504b5168a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,7 +26,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockObjectWriter, BlockId} @@ -66,10 +66,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if - * we want to combine by key, or an simple SizeTrackingBuffer if we 
don't. Inside these buffers, - * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to - * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner). + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we + * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first * by partition ID and possibly second by key or by hash code of the key, if we want to do @@ -96,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C]( partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] { + extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 @@ -126,11 +127,22 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } + private val metaInitialRecords = 256 + private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB + private val useSerializedPairBuffer = + !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && + ser.isInstanceOf[KryoSerializer] && + serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset + // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. - private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] - private var buffer = new SizeTrackingPairBuffer[(Int, K), C] + private var map = new PartitionedAppendOnlyMap[K, C] + private var buffer = if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } // Total spilling statistics private var _diskBytesSpilled = 0L @@ -163,33 +175,6 @@ private[spark] class ExternalSorter[K, V, C]( } }) - // A comparator for (Int, K) pairs that orders them by only their partition ID - private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } - - // A comparator that orders (Int, K) pairs by partition ID and then possibly by key - private val partitionKeyComparator: Comparator[(Int, K)] = { - if (ordering.isDefined || aggregator.isDefined) { - // Sort by partition ID then key comparator - new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - val partitionDiff = a._1 - b._1 - if (partitionDiff != 0) { - partitionDiff - } else { - keyComparator.compare(a._2, b._2) - } - } - } - } else { - // Just sort it by partition ID - partitionComparator - } - } - // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. 
@@ -221,16 +206,18 @@ private[spark] class ExternalSorter[K, V, C]( } else if (bypassMergeSort) { // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies if (records.hasNext) { - spillToPartitionFiles(records.map { kv => - ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - }) + spillToPartitionFiles( + WritablePartitionedIterator.fromIterator(records.map { kv => + ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + }) + ) } } else { // Stick values into our buffer while (records.hasNext) { addElementsRead() val kv = records.next() - buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) } } @@ -248,11 +235,15 @@ private[spark] class ExternalSorter[K, V, C]( if (usingMap) { if (maybeSpill(map, map.estimateSize())) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] + map = new PartitionedAppendOnlyMap[K, C] } } else { if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = new SizeTrackingPairBuffer[(Int, K), C] + buffer = if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } } } } @@ -260,7 +251,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. */ - override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { if (bypassMergeSort) { spillToPartitionFiles(collection) } else { @@ -277,7 +268,7 @@ private[spark] class ExternalSorter[K, V, C]( * * @param collection whichever collection we're using (map or buffer) */ - private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = { assert(!bypassMergeSort) // Because these files may be read during shuffle, their compression must be controlled by @@ -308,14 +299,10 @@ private[spark] class ExternalSorter[K, V, C]( var success = false try { - val it = collection.destructiveSortedIterator(partitionKeyComparator) + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { - val elem = it.next() - val partitionId = elem._1._1 - val key = elem._1._2 - val value = elem._2 - writer.write(key) - writer.write(value) + val partitionId = it.nextPartition() + it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -357,11 +344,11 @@ private[spark] class ExternalSorter[K, V, C]( * * @param collection whichever collection we're using (map or buffer) */ - private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { - spillToPartitionFiles(collection.iterator) + private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = { + spillToPartitionFiles(collection.writablePartitionedIterator()) } - private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = { + private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = { assert(bypassMergeSort) // Create our file writers if we haven't done so yet @@ -385,11 +372,8 @@ private[spark] class ExternalSorter[K, V, C]( // No need to sort stuff, just write each element out while (iterator.hasNext) { - val elem = 
iterator.next() - val partitionId = elem._1._1 - val key = elem._1._2 - val value = elem._2 - partitionWriters(partitionId).write((key, value)) + val partitionId = iterator.nextPartition() + iterator.writeNext(partitionWriters(partitionId)) } } @@ -618,8 +602,8 @@ private[spark] class ExternalSorter[K, V, C]( if (finished || deserializeStream == null) { return null } - val k = deserializeStream.readObject().asInstanceOf[K] - val c = deserializeStream.readObject().asInstanceOf[C] + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] lastPartitionId = partitionId // Start reading the next batch if we're done with this one indexInBatch += 1 @@ -695,27 +679,27 @@ private[spark] class ExternalSorter[K, V, C]( */ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { // The user hasn't requested sorted keys, so only sort by partition ID, not key - groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + groupByPartition(collection.partitionedDestructiveSortedIterator(None)) } else { // We do need to sort by both partition ID and key - groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) + groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) } } else if (bypassMergeSort) { // Read data from each partition file and merge it together with the data in memory; // note that there's no ordering or aggregator in this case -- we just partition objects - val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None)) collIter.map { case (partitionId, values) => (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) } } else { // Merge spilled and in-memory data - merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) + merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) } } @@ -762,15 +746,29 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } + } else if (spills.isEmpty && partitionWriters == null) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext) { + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics.shuffleWriteMetrics.get) + val partitionId = it.nextPartition() + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(writer) + } + writer.commitAndClose() + val segment = writer.fileSegment() + lengths(partitionId) = segment.length + } } else { - // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by - // partition and just write everything directly. + // Not bypassing merge-sort; get an iterator by partition and just write everything directly. 
for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) for (elem <- elements) { - writer.write(elem) + writer.write(elem._1, elem._2) } writer.commitAndClose() val segment = writer.fileSegment() @@ -799,7 +797,7 @@ private[spark] class ExternalSorter[K, V, C]( if (writer.isOpen) { writer.commitAndClose() } - blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get) } def stop(): Unit = { @@ -829,6 +827,14 @@ private[spark] class ExternalSorter[K, V, C]( (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) } + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + /** * An iterator that reads only the elements for a given partition ID from an underlying buffered * stream, assuming this partition is the next one to be read. Used to make it easier to return diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala similarity index 55% rename from core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala rename to core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala index faa4e2b12ddb60355f8f1e8a7f20a1aa860cfceb..d75959f480756f8a057640299d7b4eb158639fdc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala @@ -17,18 +17,8 @@ package org.apache.spark.util.collection -import java.util.Comparator +private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { + def hasNext: Boolean = iter.hasNext -/** - * A common interface for our size-tracking collections of key-value pairs, which are used in - * external operations. These all support estimating the size and obtaining a memory-efficient - * sorted iterator. - */ -// TODO: should extend Iterable[Product2[K, V]] instead of (K, V) -private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] { - /** Estimate the collection's current memory usage in bytes. */ - def estimateSize(): Long - - /** Iterate through the data in a given key order. This may destroy the underlying collection. */ - def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] + def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..e2e2f1faae9d1a54240128ad5d24eca5b362fc19 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.util.collection.WritablePartitionedPairCollection._ + +/** + * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples + * of (partition ID, K) + */ +private[spark] class PartitionedAppendOnlyMap[K, V] + extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] { + + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + destructiveSortedIterator(comparator) + } + + def writablePartitionedIterator(): WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(super.iterator) + } + + def insert(partition: Int, key: K, value: V): Unit = { + update((partition, key), value) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala similarity index 66% rename from core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala rename to core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 9e9c16c5a29626d4098215c228d36613bbd5b3ea..e8332e1a87eacf16625970e0ab71ac75df8d9154 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,11 +19,15 @@ package org.apache.spark.util.collection import java.util.Comparator +import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.util.collection.WritablePartitionedPairCollection._ + /** - * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes. + * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track + * of its estimated size in bytes. 
*/ -private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) - extends SizeTracker with SizeTrackingPairCollection[K, V] +private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) + extends WritablePartitionedPairCollection[K, V] with SizeTracker { require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") require(initialCapacity >= 1, "Invalid initial capacity") @@ -35,35 +39,16 @@ private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) private var data = new Array[AnyRef](2 * initialCapacity) /** Add an element into the buffer */ - def insert(key: K, value: V): Unit = { + def insert(partition: Int, key: K, value: V): Unit = { if (curSize == capacity) { growArray() } - data(2 * curSize) = key.asInstanceOf[AnyRef] + data(2 * curSize) = (partition, key.asInstanceOf[AnyRef]) data(2 * curSize + 1) = value.asInstanceOf[AnyRef] curSize += 1 afterUpdate() } - /** Total number of elements in buffer */ - override def size: Int = curSize - - /** Iterate over the elements of the buffer */ - override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { - var pos = 0 - - override def hasNext: Boolean = pos < curSize - - override def next(): (K, V) = { - if (!hasNext) { - throw new NoSuchElementException - } - val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) - pos += 1 - pair - } - } - /** Double the size of the array because we've reached capacity */ private def growArray(): Unit = { if (capacity == (1 << 29)) { @@ -79,8 +64,29 @@ private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) } /** Iterate through the data in a given order. For this class this is not really destructive. */ - override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = { - new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator) + override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator) iterator } + + override def writablePartitionedIterator(): WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(iterator) + } + + private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { + var pos = 0 + + override def hasNext: Boolean = pos < curSize + + override def next(): ((Int, K), V) = { + if (!hasNext) { + throw new NoSuchElementException + } + val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V]) + pos += 1 + pair + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5ca0c62a04f22b1fa412956c375a28949be85bd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.InputStream +import java.nio.IntBuffer +import java.util.Comparator + +import org.apache.spark.SparkEnv +import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} +import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ + +/** + * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes + * its records upon insert and stores them as raw bytes. + * + * We use two data-structures to store the contents. The serialized records are stored in a + * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a + * metadata buffer that stores pointers into the data buffer as well as the partition ID of each + * record. Each entry in the metadata buffer takes up a fixed amount of space. + * + * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not + * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can + * happen without following any pointers, which should minimize cache misses. + * + * Currently, only sorting by partition is supported. + * + * @param metaInitialRecords The initial number of entries in the metadata buffer. + * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. + * @param serializerInstance the serializer used for serializing inserted records. 
+ */ +private[spark] class PartitionedSerializedPairBuffer[K, V]( + metaInitialRecords: Int, + kvBlockSize: Int, + serializerInstance: SerializerInstance) + extends WritablePartitionedPairCollection[K, V] with SizeTracker { + + if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { + throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + + " Java-serialized objects.") + } + + private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) + + private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) + private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) + private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) + + def insert(partition: Int, key: K, value: V): Unit = { + if (metaBuffer.position == metaBuffer.capacity) { + growMetaBuffer() + } + + val keyStart = kvBuffer.size + if (keyStart < 0) { + throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes") + } + kvSerializationStream.writeObject[Any](key) + kvSerializationStream.flush() + val valueStart = kvBuffer.size + kvSerializationStream.writeObject[Any](value) + kvSerializationStream.flush() + val valueEnd = kvBuffer.size + + metaBuffer.put(keyStart) + metaBuffer.put(valueStart) + metaBuffer.put(valueEnd) + metaBuffer.put(partition) + } + + /** Double the size of the array because we've reached capacity */ + private def growMetaBuffer(): Unit = { + if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) { + // Doubling the capacity would create an array bigger than Int.MaxValue, so don't + throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes") + } + val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2) + newMetaBuffer.put(metaBuffer.array) + metaBuffer = newMetaBuffer + } + + /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ + override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + sort(keyComparator) + val is = orderedInputStream + val deserStream = serializerInstance.deserializeStream(is) + new Iterator[((Int, K), V)] { + var metaBufferPos = 0 + def hasNext: Boolean = metaBufferPos < metaBuffer.position + def next(): ((Int, K), V) = { + val key = deserStream.readKey[Any]().asInstanceOf[K] + val value = deserStream.readValue[Any]().asInstanceOf[V] + val partition = metaBuffer.get(metaBufferPos + PARTITION) + metaBufferPos += RECORD_SIZE + ((partition, key), value) + } + } + } + + override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity + + override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) + : WritablePartitionedIterator = { + sort(keyComparator) + writablePartitionedIterator + } + + override def writablePartitionedIterator(): WritablePartitionedIterator = { + new WritablePartitionedIterator { + // current position in the meta buffer in ints + var pos = 0 + + def writeNext(writer: BlockObjectWriter): Unit = { + val keyStart = metaBuffer.get(pos + KEY_START) + val valueEnd = metaBuffer.get(pos + VAL_END) + pos += RECORD_SIZE + kvBuffer.read(keyStart, writer, valueEnd - keyStart) + writer.recordWritten() + } + def nextPartition(): Int = metaBuffer.get(pos + PARTITION) + def hasNext(): Boolean = pos < metaBuffer.position + } + } + + // Visible for testing + def orderedInputStream: OrderedInputStream = { + new OrderedInputStream(metaBuffer, kvBuffer) + } + + private def sort(keyComparator: Option[Comparator[K]]): Unit = { + val comparator = if (keyComparator.isEmpty) { + new Comparator[Int]() { + def compare(partition1: Int, partition2: Int): Int = { + partition1 - partition2 + } + } + } else { + throw new UnsupportedOperationException() + } + + val sorter = new Sorter(new SerializedSortDataFormat) + sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) + } +} + +private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) + extends InputStream { + + private var metaBufferPos = 0 + private var kvBufferPos = + if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0 + + override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) + + override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { + if (metaBufferPos >= metaBuffer.position) { + return -1 + } + val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos + val toRead = math.min(bytesRemainingInRecord, len) + kvBuffer.read(kvBufferPos, bytes, offs, toRead) + if (toRead == bytesRemainingInRecord) { + metaBufferPos += RECORD_SIZE + if (metaBufferPos < metaBuffer.position) { + kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START) + } + } else { + kvBufferPos += toRead + } + toRead + } + + override def read(): Int = { + throw new UnsupportedOperationException() + } +} + +private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { + + private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) + + /** Return the sort key for the element at the given index. */ + override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { + metaBuffer.get(pos * RECORD_SIZE + PARTITION) + } + + /** Swap two elements. 
*/ + override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { + val iOff = pos0 * RECORD_SIZE + val jOff = pos1 * RECORD_SIZE + System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) + System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) + System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) + } + + /** Copy a single element from src(srcPos) to dst(dstPos). */ + override def copyElement( + src: IntBuffer, + srcPos: Int, + dst: IntBuffer, + dstPos: Int): Unit = { + val srcOff = srcPos * RECORD_SIZE + val dstOff = dstPos * RECORD_SIZE + System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) + } + + /** + * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. + * Overlapping ranges are allowed. + */ + override def copyRange( + src: IntBuffer, + srcPos: Int, + dst: IntBuffer, + dstPos: Int, + length: Int): Unit = { + val srcOff = srcPos * RECORD_SIZE + val dstOff = dstPos * RECORD_SIZE + System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) + } + + /** + * Allocates a Buffer that can hold up to 'length' elements. + * All elements of the buffer should be considered invalid until data is explicitly copied in. + */ + override def allocate(length: Int): IntBuffer = { + IntBuffer.allocate(length * RECORD_SIZE) + } +} + +private[spark] object PartitionedSerializedPairBuffer { + val KEY_START = 0 + val VAL_START = 1 + val VAL_END = 2 + val PARTITION = 3 + val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala index eb4de413867a088ed561f925a9f9c83d55dc6d8e..722f78bd159861c7fe0895e446dc23f6e5e0094f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -21,7 +21,7 @@ package org.apache.spark.util.collection * An append-only map that keeps track of its estimated size in bytes. */ private[spark] class SizeTrackingAppendOnlyMap[K, V] - extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V] + extends AppendOnlyMap[K, V] with SizeTracker { override def update(key: K, value: V): Unit = { super.update(key, value) diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala new file mode 100644 index 0000000000000000000000000000000000000000..f26d1618c9200ab7d1665c1970e60f4788e21c95 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.storage.BlockObjectWriter + +/** + * A common interface for size-tracking collections of key-value pairs that + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + */ +private[spark] trait WritablePartitionedPairCollection[K, V] { + /** + * Insert a key-value pair with a partition into the collection + */ + def insert(partition: Int, key: K, value: V): Unit + + /** + * Iterate through the data in order of partition ID and then the given comparator. This may + * destroy the underlying collection. + */ + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] + + /** + * Iterate through the data and write out the elements instead of returning them. Records are + * returned in order of their partition ID and then the given comparator. + * This may destroy the underlying collection. + */ + def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) + : WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator)) + } + + /** + * Iterate through the data and write out the elements instead of returning them. + */ + def writablePartitionedIterator(): WritablePartitionedIterator +} + +private[spark] object WritablePartitionedPairCollection { + /** + * A comparator for (Int, K) pairs that orders them by only their partition ID. + */ + def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + /** + * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering. + */ + def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = { + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.compare(a._2, b._2) + } + } + } + } +} + +/** + * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * has an associated partition. 
+ */ +private[spark] trait WritablePartitionedIterator { + def writeNext(writer: BlockObjectWriter): Unit + + def hasNext(): Boolean + + def nextPartition(): Int +} + +private[spark] object WritablePartitionedIterator { + def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = { + new WritablePartitionedIterator { + var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: BlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 1b13559e77cb8b84979ec7a3a9bfb606f409fbca..778a7eee73b2302a1c61d4ad85cdf0a4034c0b20 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -280,6 +280,15 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown = intercept[SparkException](ser.serialize(largeObject)) assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + + test("getAutoReset") { + val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] + assert(ser.getAutoReset) + val conf = new SparkConf().set("spark.kryo.registrator", + classOf[RegistratorWithoutAutoReset].getName) + val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + assert(!ser2.getAutoReset) + } } @@ -313,4 +322,10 @@ object KryoTest { k.register(classOf[java.util.HashMap[_, _]]) } } + + class RegistratorWithoutAutoReset extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.setAutoReset(false) + } + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 963264cef3a7181809f921333d3af2c5b2ab8c62..86fcf447287f7458f0435f819dfd03eabd269e12 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag /** - * A serializer implementation that always return a single element in a deserialization stream. + * A serializer implementation that always returns two elements in a deserialization stream. 
*/ class TestSerializer extends Serializer { override def newInstance(): TestSerializerInstance = new TestSerializerInstance @@ -51,7 +51,7 @@ class TestDeserializationStream extends DeserializationStream { override def readObject[T: ClassTag](): T = { count += 1 - if (count == 2) { + if (count == 3) { throw new EOFException } new Object().asInstanceOf[T] diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 7d76435cd75e7e7ffd469dae10c84529b4e7b943..84384bb48999aa32b99e99a950d0e0695103e742 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -59,8 +59,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) for (writer <- shuffle1.writers) { - writer.write("test1") - writer.write("test2") + writer.write("test1", "value") + writer.write("test2", "value") } for (writer <- shuffle1.writers) { writer.commitAndClose() @@ -73,8 +73,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { new ShuffleWriteMetrics) for (writer <- shuffle2.writers) { - writer.write("test3") - writer.write("test4") + writer.write("test3", "value") + writer.write("test4", "vlue") } for (writer <- shuffle2.writers) { writer.commitAndClose() @@ -91,8 +91,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) for (writer <- shuffle3.writers) { - writer.write("test3") - writer.write("test4") + writer.write("test3", "value") + writer.write("test4", "value") } for (writer <- shuffle3.writers) { writer.commitAndClose() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 003a728cb84a0e71a2255c94fbf8af5a20b1eb19..43ef469c1fd4806798809be2e58fc8693379bab3 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -32,7 +32,7 @@ class BlockObjectWriterSuite extends FunSuite { val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.write(Long.box(20)) + writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write @@ -40,7 +40,7 @@ class BlockObjectWriterSuite extends FunSuite { // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() - writer.write(Long.box(i)) + writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) assert(writeMetrics.shuffleRecordsWritten === 33) @@ -54,7 +54,7 @@ class BlockObjectWriterSuite extends FunSuite { val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.write(Long.box(20)) + writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write @@ -62,7 +62,7 @@ class 
BlockObjectWriterSuite extends FunSuite { // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() - writer.write(Long.box(i)) + writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) assert(writeMetrics.shuffleRecordsWritten === 33) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..c0c38cd4ac4ad8a66783a8f24aeab6a8540c53f5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.nio.ByteBuffer + +import org.scalatest.FunSuite +import org.scalatest.Matchers._ + +class ChainedBufferSuite extends FunSuite { + test("write and read at start") { + // write from start of source array + val buffer = new ChainedBuffer(8) + buffer.capacity should be (0) + verifyWriteAndRead(buffer, 0, 0, 0, 4) + buffer.capacity should be (8) + + // write from middle of source array + verifyWriteAndRead(buffer, 0, 5, 0, 4) + buffer.capacity should be (8) + + // read to middle of target array + verifyWriteAndRead(buffer, 0, 0, 5, 4) + buffer.capacity should be (8) + + // write up to border + verifyWriteAndRead(buffer, 0, 0, 0, 8) + buffer.capacity should be (8) + + // expand into second buffer + verifyWriteAndRead(buffer, 0, 0, 0, 12) + buffer.capacity should be (16) + + // expand into multiple buffers + verifyWriteAndRead(buffer, 0, 0, 0, 28) + buffer.capacity should be (32) + } + + test("write and read at middle") { + val buffer = new ChainedBuffer(8) + + // fill to a middle point + verifyWriteAndRead(buffer, 0, 0, 0, 3) + + // write from start of source array + verifyWriteAndRead(buffer, 3, 0, 0, 4) + buffer.capacity should be (8) + + // write from middle of source array + verifyWriteAndRead(buffer, 3, 5, 0, 4) + buffer.capacity should be (8) + + // read to middle of target array + verifyWriteAndRead(buffer, 3, 0, 5, 4) + buffer.capacity should be (8) + + // write up to border + verifyWriteAndRead(buffer, 3, 0, 0, 5) + buffer.capacity should be (8) + + // expand into second buffer + verifyWriteAndRead(buffer, 3, 0, 0, 12) + buffer.capacity should be (16) + + // expand into multiple buffers + verifyWriteAndRead(buffer, 3, 0, 0, 28) + buffer.capacity should be (32) + } + + test("write and read at later buffer") { + val buffer = new ChainedBuffer(8) + + // fill to a middle point + verifyWriteAndRead(buffer, 0, 0, 0, 11) + + // write from start of source array + verifyWriteAndRead(buffer, 11, 0, 0, 4) + buffer.capacity should be (16) + + // write from middle of 
source array + verifyWriteAndRead(buffer, 11, 5, 0, 4) + buffer.capacity should be (16) + + // read to middle of target array + verifyWriteAndRead(buffer, 11, 0, 5, 4) + buffer.capacity should be (16) + + // write up to border + verifyWriteAndRead(buffer, 11, 0, 0, 5) + buffer.capacity should be (16) + + // expand into second buffer + verifyWriteAndRead(buffer, 11, 0, 0, 12) + buffer.capacity should be (24) + + // expand into multiple buffers + verifyWriteAndRead(buffer, 11, 0, 0, 28) + buffer.capacity should be (40) + } + + + // Used to make sure we're writing different bytes each time + var rangeStart = 0 + + /** + * @param buffer The buffer to write to and read from. + * @param offsetInBuffer The offset to write to in the buffer. + * @param offsetInSource The offset in the array that the bytes are written from. + * @param offsetInTarget The offset in the array to read the bytes into. + * @param length The number of bytes to read and write + */ + def verifyWriteAndRead( + buffer: ChainedBuffer, + offsetInBuffer: Int, + offsetInSource: Int, + offsetInTarget: Int, + length: Int): Unit = { + val source = new Array[Byte](offsetInSource + length) + (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource) + buffer.write(offsetInBuffer, source, offsetInSource, length) + val target = new Array[Byte](offsetInTarget + length) + buffer.read(offsetInBuffer, target, offsetInTarget, length) + ByteBuffer.wrap(source, offsetInSource, length) should be ( + ByteBuffer.wrap(target, offsetInTarget, length)) + + rangeStart += 100 + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index de26aa351b0d2a720a1eb132ea97a27c73e71cd2..20fd22b78ef5d87bb6f840236dc1f164e27ffd4a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,19 +19,24 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.{PrivateMethodTester, FunSuite} - -import org.apache.spark._ +import org.scalatest.{FunSuite, PrivateMethodTester} import scala.util.Random +import org.apache.spark._ +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} + class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { - private def createSparkConf(loadDefaults: Boolean): SparkConf = { + private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) - // Make the Java serializer write a reset instruction (TC_RESET) after each object to test - // for a bug we had with bytes written past the last object in a batch (SPARK-2792) - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + if (kryo) { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + } else { + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", classOf[JavaSerializer].getName) + } // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") conf @@ -47,8 +52,15 @@ class ExternalSorterSuite extends FunSuite with
LocalSparkContext with PrivateMe assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") } - test("empty data stream") { - val conf = new SparkConf(false) + test("empty data stream with kryo ser") { + emptyDataStream(createSparkConf(false, true)) + } + + test("empty data stream with java ser") { + emptyDataStream(createSparkConf(false, false)) + } + + def emptyDataStream(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -81,8 +93,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe sorter4.stop() } - test("few elements per partition") { - val conf = createSparkConf(false) + test("few elements per partition with kryo ser") { + fewElementsPerPartition(createSparkConf(false, true)) + } + + test("few elements per partition with java ser") { + fewElementsPerPartition(createSparkConf(false, false)) + } + + def fewElementsPerPartition(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -123,8 +142,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe sorter4.stop() } - test("empty partitions with spilling") { - val conf = createSparkConf(false) + test("empty partitions with spilling with kryo ser") { + emptyPartitionsWithSpilling(createSparkConf(false, true)) + } + + test("empty partitions with spilling with java ser") { + emptyPartitionsWithSpilling(createSparkConf(false, false)) + } + + def emptyPartitionsWithSpilling(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") @@ -149,8 +175,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe sorter.stop() } - test("empty partitions with spilling, bypass merge-sort") { - val conf = createSparkConf(false) + test("empty partitions with spilling, bypass merge-sort with kryo ser") { + emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, true)) + } + + test("empty partitions with spilling, bypass merge-sort with java ser") { + emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, false)) + } + + def emptyPartitionsWithSpillingBypassMergeSort(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") @@ -174,8 +207,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe sorter.stop() } - test("spilling in local cluster") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + test("spilling in local cluster with kryo ser") { + // Load defaults, otherwise SPARK_HOME is not found + testSpillingInLocalCluster(createSparkConf(true, true)) + } + + test("spilling in local cluster with java ser") { + // Load defaults, otherwise SPARK_HOME is not found + testSpillingInLocalCluster(createSparkConf(true, false)) + } + + def testSpillingInLocalCluster(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new
SparkContext("local-cluster[1,1,512]", "test", conf) @@ -245,8 +287,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) } - test("spilling in local cluster with many reduce tasks") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + test("spilling in local cluster with many reduce tasks with kryo ser") { + spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true)) + } + + test("spilling in local cluster with many reduce tasks with java ser") { + spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false)) + } + + def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[2,1,512]", "test", conf) @@ -317,7 +366,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in sorter") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -344,7 +393,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in sorter, bypass merge-sort") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -367,7 +416,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in sorter if there are errors") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -392,7 +441,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { - val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -414,7 +463,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in shuffle") { - val conf = createSparkConf(false) + val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -429,7 +478,7 @@ class ExternalSorterSuite extends 
FunSuite with LocalSparkContext with PrivateMe } test("cleanup of intermediate files in shuffle with errors") { - val conf = createSparkConf(false) + val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -450,8 +499,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(diskBlockManager.getAllFiles().length === 2) } - test("no partial aggregation or sorting") { - val conf = createSparkConf(false) + test("no partial aggregation or sorting with kryo ser") { + noPartialAggregationOrSorting(createSparkConf(false, true)) + } + + test("no partial aggregation or sorting with java ser") { + noPartialAggregationOrSorting(createSparkConf(false, false)) + } + + def noPartialAggregationOrSorting(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -465,8 +521,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(results === expected) } - test("partial aggregation without spill") { - val conf = createSparkConf(false) + test("partial aggregation without spill with kryo ser") { + partialAggregationWithoutSpill(createSparkConf(false, true)) + } + + test("partial aggregation without spill with java ser") { + partialAggregationWithoutSpill(createSparkConf(false, false)) + } + + def partialAggregationWithoutSpill(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -481,8 +544,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(results === expected) } - test("partial aggregation with spill, no ordering") { - val conf = createSparkConf(false) + test("partial aggregation with spill, no ordering with kryo ser") { + partialAggregationWithSpillNoOrdering(createSparkConf(false, true)) + } + + test("partial aggregation with spill, no ordering with java ser") { + partialAggregationWithSpillNoOrdering(createSparkConf(false, false)) + } + + def partialAggregationWithSpillNoOrdering(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -497,8 +567,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(results === expected) } - test("partial aggregation with spill, with ordering") { - val conf = createSparkConf(false) + test("partial aggregation with spill, with ordering with kryo ser") { + partialAggregationWithSpillWithOrdering(createSparkConf(false, true)) + } + + + test("partial aggregation with spill, with ordering with java ser") { + partialAggregationWithSpillWithOrdering(createSparkConf(false, false)) + } + + def partialAggregationWithSpillWithOrdering(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -517,8 +595,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(results === expected) } - test("sorting without aggregation, no spill") { - val conf =
createSparkConf(false) + test("sorting without aggregation, no spill with kryo ser") { + sortingWithoutAggregationNoSpill(createSparkConf(false, true)) + } + + test("sorting without aggregation, no spill with java ser") { + sortingWithoutAggregationNoSpill(createSparkConf(false, false)) + } + + def sortingWithoutAggregationNoSpill(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -534,8 +619,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(results === expected) } - test("sorting without aggregation, with spill") { - val conf = createSparkConf(false) + test("sorting without aggregation, with spill with kryo ser") { + sortingWithoutAggregationWithSpill(createSparkConf(false, true)) + } + + test("sorting without aggregation, with spill with java ser") { + sortingWithoutAggregationWithSpill(createSparkConf(false, false)) + } + + def sortingWithoutAggregationWithSpill(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -552,7 +644,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("spilling with hash collisions") { - val conf = createSparkConf(true) + val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -609,7 +701,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("spilling with many hash collisions") { - val conf = createSparkConf(true) + val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -632,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = createSparkConf(true) + val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -656,7 +748,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("spilling with null keys and values") { - val conf = createSparkConf(true) + val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -685,7 +777,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } test("conditions for bypassing merge-sort") { - val conf = createSparkConf(false) + val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -718,8 +810,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assertDidNotBypassMergeSort(sorter4) } - test("sort without breaking sorting contracts") { - val conf = createSparkConf(true) + test("sort without breaking sorting contracts with kryo ser") { + sortWithoutBreakingSortingContracts(createSparkConf(true, true)) + } + + test("sort without breaking sorting contracts with java ser") { + sortWithoutBreakingSortingContracts(createSparkConf(true, false)) + } 
+ + def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5a2d9ef720c18ad15b527db59881706a673809b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} + +import com.google.common.io.ByteStreams + +import org.scalatest.FunSuite +import org.scalatest.Matchers._ + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.storage.{FileSegment, BlockObjectWriter} + +class PartitionedSerializedPairBufferSuite extends FunSuite { + test("OrderedInputStream single record") { + val serializerInstance = new KryoSerializer(new SparkConf()).newInstance + + val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) + val struct = SomeStruct("something", 5) + buffer.insert(4, 10, struct) + + val bytes = ByteStreams.toByteArray(buffer.orderedInputStream) + + val baos = new ByteArrayOutputStream() + val stream = serializerInstance.serializeStream(baos) + stream.writeObject(10) + stream.writeObject(struct) + stream.close() + + baos.toByteArray should be (bytes) + } + + test("insert single record") { + val serializerInstance = new KryoSerializer(new SparkConf()).newInstance + val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) + val struct = SomeStruct("something", 5) + buffer.insert(4, 10, struct) + val elements = buffer.partitionedDestructiveSortedIterator(None).toArray + elements.size should be (1) + elements.head should be (((4, 10), struct)) + } + + test("insert multiple records") { + val serializerInstance = new KryoSerializer(new SparkConf()).newInstance + val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) + val struct1 = SomeStruct("something1", 8) + buffer.insert(6, 1, struct1) + val struct2 = SomeStruct("something2", 9) + buffer.insert(4, 2, struct2) + val struct3 = SomeStruct("something3", 10) + buffer.insert(5, 3, struct3) + + val elements = buffer.partitionedDestructiveSortedIterator(None).toArray + elements.size should be (3) + elements(0) should be (((4, 2), struct2)) + elements(1) should be (((5, 3), 
struct3)) + elements(2) should be (((6, 1), struct1)) + } + + test("write single record") { + val serializerInstance = new KryoSerializer(new SparkConf()).newInstance + val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) + val struct = SomeStruct("something", 5) + buffer.insert(4, 10, struct) + val it = buffer.destructiveSortedWritablePartitionedIterator(None) + val writer = new SimpleBlockObjectWriter + assert(it.hasNext) + it.nextPartition should be (4) + it.writeNext(writer) + assert(!it.hasNext) + + val stream = serializerInstance.deserializeStream(writer.getInputStream) + stream.readObject[AnyRef]() should be (10) + stream.readObject[AnyRef]() should be (struct) + } + + test("write multiple records") { + val serializerInstance = new KryoSerializer(new SparkConf()).newInstance + val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) + val struct1 = SomeStruct("something1", 8) + buffer.insert(6, 1, struct1) + val struct2 = SomeStruct("something2", 9) + buffer.insert(4, 2, struct2) + val struct3 = SomeStruct("something3", 10) + buffer.insert(5, 3, struct3) + + val it = buffer.destructiveSortedWritablePartitionedIterator(None) + val writer = new SimpleBlockObjectWriter + assert(it.hasNext) + it.nextPartition should be (4) + it.writeNext(writer) + assert(it.hasNext) + it.nextPartition should be (5) + it.writeNext(writer) + assert(it.hasNext) + it.nextPartition should be (6) + it.writeNext(writer) + assert(!it.hasNext) + + val stream = serializerInstance.deserializeStream(writer.getInputStream) + val iter = stream.asIterator + iter.next() should be (2) + iter.next() should be (struct2) + iter.next() should be (3) + iter.next() should be (struct3) + iter.next() should be (1) + iter.next() should be (struct1) + assert(!iter.hasNext) + } +} + +case class SomeStruct(val str: String, val num: Int) + +class SimpleBlockObjectWriter extends BlockObjectWriter(null) { + val baos = new ByteArrayOutputStream() + + override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { + baos.write(bytes, offs, len) + } + + def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) + + override def open(): BlockObjectWriter = this + override def close(): Unit = { } + override def isOpen: Boolean = true + override def commitAndClose(): Unit = { } + override def revertPartialWritesAndClose(): Unit = { } + override def fileSegment(): FileSegment = null + override def write(key: Any, value: Any): Unit = { } + override def recordWritten(): Unit = { } + override def write(b: Int): Unit = { } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index cec97de2cd8e436fbbf134cba11b3b4472c627ee..9552f41115866f97d2ff64473a96966a465067bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -50,10 +50,10 @@ private[sql] class Serializer2SerializationStream( extends SerializationStream with Logging { val rowOut = new DataOutputStream(out) - val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) - val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValueFunc = 
SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { val kv = t.asInstanceOf[Product2[Row, Row]] writeKey(kv._1) writeValue(kv._2) @@ -61,6 +61,16 @@ private[sql] class Serializer2SerializationStream( this } + override def writeKey[T: ClassTag](t: T): SerializationStream = { + writeKeyFunc(t.asInstanceOf[Row]) + this + } + + override def writeValue[T: ClassTag](t: T): SerializationStream = { + writeValueFunc(t.asInstanceOf[Row]) + this + } + def flush(): Unit = { rowOut.flush() } @@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream( val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null - val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) - val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) - def readObject[T: ClassTag](): T = { - readKey() - readValue() + override def readObject[T: ClassTag](): T = { + readKeyFunc() + readValueFunc() (key, value).asInstanceOf[T] } - def close(): Unit = { + override def readKey[T: ClassTag](): T = { + readKeyFunc() + key.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + readValueFunc() + value.asInstanceOf[T] + } + + override def close(): Unit = { rowIn.close() } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index f2d135397ce2fde9ed00dcaa2b13fde67aec6402..baa97616eaff3da737cc7f16f5b069f1156cbb0c 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -46,7 +46,8 @@ object StoragePerfTester { val totalRecords = dataSizeMb * 1000 val recordsPerMap = totalRecords / numMaps - val writeData = "1" * recordLength + val writeKey = "1" * (recordLength / 2) + val writeValue = "1" * (recordLength / 2) val executor = Executors.newFixedThreadPool(numMaps) val conf = new SparkConf() @@ -63,7 +64,7 @@ object StoragePerfTester { new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { - writers(i % numOutputSplits).write(writeData) + writers(i % numOutputSplits).write(writeKey, writeValue) } writers.map { w => w.commitAndClose()
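
For reference, a minimal self-contained sketch of the key-then-value record layout that the updated writers and tests above rely on, using only serializer stream calls that already appear in the tests (serializeStream/writeObject and deserializeStream/readObject). The KeyValuePairRoundTrip object and the ExampleValue case class are illustrative names, not part of the patch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Illustrative fixture, analogous to the SomeStruct case class used in the suite above.
case class ExampleValue(str: String, num: Int)

object KeyValuePairRoundTrip {
  def main(args: Array[String]): Unit = {
    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance()

    // A shuffle record is written as a key followed by its value on the same stream.
    val baos = new ByteArrayOutputStream()
    val out = serializerInstance.serializeStream(baos)
    out.writeObject(10)
    out.writeObject(ExampleValue("something", 5))
    out.close()

    // Reading back in the same order recovers the pair.
    val in = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
    val key = in.readObject[Int]()
    val value = in.readObject[ExampleValue]()
    in.close()

    assert(key == 10)
    assert(value == ExampleValue("something", 5))
  }
}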