Commit e209fa27 authored by Liang-Chi Hsieh, committed by Sean Owen

[SPARK-11271][SPARK-11016][CORE] Use Spark BitSet instead of RoaringBitmap to reduce memory usage

JIRA: https://issues.apache.org/jira/browse/SPARK-11271

As reported in the JIRA ticket, when there are too many tasks, the memory usage of MapStatus causes problems. Using BitSet instead of RoaringBitmap should be more memory-efficient.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #9243 from viirya/mapstatus-bitset.
parent e963070c
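
As a rough illustration of the sizing argument (a sketch, not part of this patch): the fixed-size BitSet costs one Long word per 64 bits, so its footprint is ceil(numBits / 64) * 8 bytes no matter how many bits are set.

    // Sketch only; bit2words mirrors the private helper in Spark's BitSet.
    object BitSetFootprint {
      private def bit2words(numBits: Int): Int = ((numBits - 1) >> 6) + 1

      def main(args: Array[String]): Unit = {
        val numBlocks = 200000 // hypothetical reduce-partition count
        val bytes = bit2words(numBlocks) * 8L
        println(s"BitSet for $numBlocks blocks: $bytes bytes") // prints 25000 bytes
      }
    }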
@@ -173,10 +173,6 @@
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
-    <dependency>
-      <groupId>org.roaringbitmap</groupId>
-      <artifactId>RoaringBitmap</artifactId>
-    </dependency>
     <dependency>
       <groupId>commons-net</groupId>
       <artifactId>commons-net</artifactId>
...
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler

 import java.io.{Externalizable, ObjectInput, ObjectOutput}

-import org.roaringbitmap.RoaringBitmap
-
 import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.collection.BitSet
 import org.apache.spark.util.Utils

 /**
@@ -133,7 +132,7 @@ private[spark] class CompressedMapStatus(
 private[spark] class HighlyCompressedMapStatus private (
     private[this] var loc: BlockManagerId,
     private[this] var numNonEmptyBlocks: Int,
-    private[this] var emptyBlocks: RoaringBitmap,
+    private[this] var emptyBlocks: BitSet,
     private[this] var avgSize: Long)
   extends MapStatus with Externalizable {
@@ -146,7 +145,7 @@ private[spark] class HighlyCompressedMapStatus private (
   override def location: BlockManagerId = loc

   override def getSizeForBlock(reduceId: Int): Long = {
-    if (emptyBlocks.contains(reduceId)) {
+    if (emptyBlocks.get(reduceId)) {
       0
     } else {
       avgSize
@@ -161,7 +160,7 @@ private[spark] class HighlyCompressedMapStatus private (
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     loc = BlockManagerId(in)
-    emptyBlocks = new RoaringBitmap()
+    emptyBlocks = new BitSet
     emptyBlocks.readExternal(in)
     avgSize = in.readLong()
   }
@@ -177,15 +176,15 @@ private[spark] object HighlyCompressedMapStatus {
     // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
     // blocks. From a performance standpoint, we benefit from tracking empty blocks because
     // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
-    val emptyBlocks = new RoaringBitmap()
     val totalNumBlocks = uncompressedSizes.length
+    val emptyBlocks = new BitSet(totalNumBlocks)
     while (i < totalNumBlocks) {
       var size = uncompressedSizes(i)
       if (size > 0) {
         numNonEmptyBlocks += 1
         totalSize += size
       } else {
-        emptyBlocks.add(i)
+        emptyBlocks.set(i)
       }
       i += 1
     }
...
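
The MapStatus changes above are essentially an API swap: RoaringBitmap's add/contains become BitSet's set/get. A minimal sketch of the resulting pattern (simplified; the real logic lives in HighlyCompressedMapStatus.apply, and emptyBlocksOf is a hypothetical name):

    import org.apache.spark.util.collection.BitSet

    // Mark the empty shuffle blocks once, then answer size queries in O(1).
    def emptyBlocksOf(uncompressedSizes: Array[Long]): BitSet = {
      val emptyBlocks = new BitSet(uncompressedSizes.length)
      var i = 0
      while (i < uncompressedSizes.length) {
        if (uncompressedSizes(i) <= 0) emptyBlocks.set(i) // empty block
        i += 1
      }
      emptyBlocks
    }

    // getSizeForBlock(reduceId) then reduces to:
    //   if (emptyBlocks.get(reduceId)) 0L else avgSize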
@@ -30,7 +30,6 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
 import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
 import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
 import org.apache.avro.generic.{GenericData, GenericRecord}
-import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap}

 import org.apache.spark._
 import org.apache.spark.api.python.PythonBroadcast
@@ -39,7 +38,7 @@ import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
 import org.apache.spark.storage._
 import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.collection.{BitSet, CompactBuffer}

 /**
  * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
@@ -363,12 +362,7 @@ private[serializer] object KryoSerializer {
     classOf[StorageLevel],
     classOf[CompressedMapStatus],
     classOf[HighlyCompressedMapStatus],
-    classOf[RoaringBitmap],
-    classOf[RoaringArray],
-    classOf[RoaringArray.Element],
-    classOf[Array[RoaringArray.Element]],
-    classOf[ArrayContainer],
-    classOf[BitmapContainer],
+    classOf[BitSet],
     classOf[CompactBuffer[_]],
     classOf[BlockManagerId],
     classOf[Array[Byte]],
...
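
With spark.kryo.registrationRequired=true, Kryo refuses to serialize unregistered classes; that is why the six RoaringBitmap-related registrations above collapse to a single BitSet entry. A quick way to exercise the new registration (a sketch mirroring what KryoSerializerSuite does):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    val conf = new SparkConf(false)
    conf.set("spark.kryo.registrationRequired", "true")
    val ser = new KryoSerializer(conf).newInstance()
    // A HighlyCompressedMapStatus now round-trips because BitSet is registered.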
@@ -17,14 +17,21 @@

 package org.apache.spark.util.collection

+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import org.apache.spark.util.{Utils => UUtils}
+
 /**
  * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
  * safety/bound checking.
  */
-class BitSet(numBits: Int) extends Serializable {
+class BitSet(private[this] var numBits: Int) extends Externalizable {

-  private val words = new Array[Long](bit2words(numBits))
-  private val numWords = words.length
+  private var words = new Array[Long](bit2words(numBits))
+  private def numWords = words.length
+
+  def this() = this(0)

   /**
    * Compute the capacity (number of bits) that can be represented
@@ -230,4 +237,19 @@
   /** Return the number of longs it would take to hold numBits. */
   private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+
+  override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException {
+    out.writeInt(numBits)
+    words.foreach(out.writeLong(_))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException {
+    numBits = in.readInt()
+    words = new Array[Long](bit2words(numBits))
+    var index = 0
+    while (index < words.length) {
+      words(index) = in.readLong()
+      index += 1
+    }
+  }
 }
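
The wire format introduced above is just an Int (numBits) followed by one Long per word, so a BitSet(100) writes 4 + ceil(100/64) * 8 = 4 + 16 = 20 bytes of payload, plus whatever framing the surrounding ObjectOutput adds.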
@@ -322,12 +322,6 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
     val conf = new SparkConf(false)
     conf.set("spark.kryo.registrationRequired", "true")

-    // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16
-    // values, and they use a bitmap (dense) if they have more than 4096 values, and an
-    // array (sparse) if they use less. So we just create two cases, one sparse and one dense.
-    // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly
-    // empty blocks
-
     val ser = new KryoSerializer(conf).newInstance()
     val denseBlockSizes = new Array[Long](5000)
     val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
...
@@ -17,7 +17,10 @@

 package org.apache.spark.util.collection

+import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.{Utils => UUtils}

 class BitSetSuite extends SparkFunSuite {
@@ -152,4 +155,50 @@ class BitSetSuite extends SparkFunSuite {
     assert(bitsetDiff.nextSetBit(85) === 85)
     assert(bitsetDiff.nextSetBit(86) === -1)
   }
+
+  test("read and write externally") {
+    val tempDir = UUtils.createTempDir()
+    val outputFile = File.createTempFile("bits", null, tempDir)
+
+    val fos = new FileOutputStream(outputFile)
+    val oos = new ObjectOutputStream(fos)
+
+    // Create BitSet
+    val setBits = Seq(0, 9, 1, 10, 90, 96)
+    val bitset = new BitSet(100)
+
+    for (i <- 0 until 100) {
+      assert(!bitset.get(i))
+    }
+
+    setBits.foreach(i => bitset.set(i))
+
+    for (i <- 0 until 100) {
+      if (setBits.contains(i)) {
+        assert(bitset.get(i))
+      } else {
+        assert(!bitset.get(i))
+      }
+    }
+    assert(bitset.cardinality() === setBits.size)
+
+    bitset.writeExternal(oos)
+    oos.close()
+
+    val fis = new FileInputStream(outputFile)
+    val ois = new ObjectInputStream(fis)
+
+    // Read BitSet from the file
+    val bitset2 = new BitSet(0)
+    bitset2.readExternal(ois)
+
+    for (i <- 0 until 100) {
+      if (setBits.contains(i)) {
+        assert(bitset2.get(i))
+      } else {
+        assert(!bitset2.get(i))
+      }
+    }
+    assert(bitset2.cardinality() === setBits.size)
+  }
 }
@@ -623,11 +623,6 @@
         </exclusion>
       </exclusions>
     </dependency>
-    <dependency>
-      <groupId>org.roaringbitmap</groupId>
-      <artifactId>RoaringBitmap</artifactId>
-      <version>0.4.5</version>
-    </dependency>
     <dependency>
       <groupId>commons-net</groupId>
       <artifactId>commons-net</artifactId>
...