Commit 6743de3a authored by Wenchen Fan, committed by Reynold Xin

[SPARK-12937][SQL] bloom filter serialization

This PR adds serialization support for BloomFilter.

A version number is added to version the serialized binary format.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10920 from cloud-fan/bloom-filter.
parent d54cfed5
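
Usage-wise, the new public surface is the BloomFilter.writeTo / BloomFilter.readFrom pair shown in the diff below. A minimal round-trip sketch, not part of this patch; it assumes the pre-existing BloomFilter.create(long) factory and String support in put, which the test suite also relies on:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterSerDeExample {
  public static void main(String[] args) throws IOException {
    // Build and populate a filter; create(expectedNumItems) is assumed to exist outside this diff.
    BloomFilter filter = BloomFilter.create(1000);
    for (int i = 0; i < 100; i++) {
      filter.put("item" + i);
    }

    // Serialize with the V1 binary format introduced here; the caller owns the stream.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    out.close();

    // Deserialize and check the round trip; equals() on the filter is added by this patch.
    BloomFilter copy = BloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray()));
    System.out.println(filter.equals(copy));         // true
    System.out.println(copy.mightContain("item42")); // true
  }
}

The same pattern is exercised in Scala by the new checkSerDe helper added to BloomFilterSuite further down.
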

BitArray.java
@@ -17,6 +17,9 @@
 package org.apache.spark.util.sketch;
 
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
 import java.util.Arrays;
 
 public final class BitArray {
@@ -24,6 +27,9 @@ public final class BitArray {
   private long bitCount;
 
   static int numWords(long numBits) {
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
+    }
     long numWords = (long) Math.ceil(numBits / 64.0);
     if (numWords > Integer.MAX_VALUE) {
       throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
@@ -32,13 +38,14 @@
   }
 
   BitArray(long numBits) {
-    if (numBits <= 0) {
-      throw new IllegalArgumentException("numBits must be positive");
-    }
-    this.data = new long[numWords(numBits)];
+    this(new long[numWords(numBits)]);
+  }
+
+  private BitArray(long[] data) {
+    this.data = data;
     long bitCount = 0;
-    for (long value : data) {
-      bitCount += Long.bitCount(value);
+    for (long word : data) {
+      bitCount += Long.bitCount(word);
     }
     this.bitCount = bitCount;
  }
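
A side note on the new private BitArray(long[] data) constructor above: it rebuilds the cached bitCount by summing Long.bitCount over the words, which is why the serialized form in the next hunk only needs to carry the raw words. A tiny standalone illustration with plain JDK calls, no Spark classes involved:

public class PopCountExample {
  public static void main(String[] args) {
    // Three words with 1, 2 and 3 bits set respectively.
    long[] words = {0b1L, 0b101L, 0b10101L};

    long bitCount = 0;
    for (long word : words) {
      bitCount += Long.bitCount(word); // population count per 64-bit word
    }

    System.out.println(bitCount); // prints 6
  }
}
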
@@ -78,13 +85,28 @@
     this.bitCount = bitCount;
   }
 
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) return true;
-    if (o == null || !(o instanceof BitArray)) return false;
+  void writeTo(DataOutputStream out) throws IOException {
+    out.writeInt(data.length);
+    for (long datum : data) {
+      out.writeLong(datum);
+    }
+  }
+
+  static BitArray readFrom(DataInputStream in) throws IOException {
+    int numWords = in.readInt();
+    long[] data = new long[numWords];
+    for (int i = 0; i < numWords; i++) {
+      data[i] = in.readLong();
+    }
+    return new BitArray(data);
+  }
 
-    BitArray bitArray = (BitArray) o;
-    return Arrays.equals(data, bitArray.data);
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) return true;
+    if (other == null || !(other instanceof BitArray)) return false;
+    BitArray that = (BitArray) other;
+    return Arrays.equals(data, that.data);
   }
 
   @Override
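
BitArray itself is package-private, so its writeTo/readFrom above cannot be called from user code, but the wire layout is just a 32-bit word count followed by the 64-bit words, both big-endian as DataOutputStream guarantees. A self-contained sketch of that framing, with hypothetical class and method names rather than the Spark API:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public class LongArrayFramingExample {
  // Mirrors BitArray.writeTo: a 32-bit length followed by the 64-bit words.
  static byte[] encode(long[] words) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeInt(words.length);
    for (long word : words) {
      out.writeLong(word);
    }
    return bytes.toByteArray();
  }

  // Mirrors BitArray.readFrom: read the length, then that many words.
  static long[] decode(byte[] encoded) throws IOException {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(encoded));
    long[] words = new long[in.readInt()];
    for (int i = 0; i < words.length; i++) {
      words[i] = in.readLong();
    }
    return words;
  }

  public static void main(String[] args) throws IOException {
    long[] words = {1L, -1L, 0x00FF00FF00FF00FFL};
    byte[] encoded = encode(words);
    System.out.println(encoded.length);                         // 4 + 3 * 8 = 28 bytes
    System.out.println(Arrays.equals(words, decode(encoded)));  // true
  }
}
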

BloomFilter.java
@@ -17,6 +17,10 @@
 package org.apache.spark.util.sketch;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
 /**
  * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
  * an element is a member of a set. It returns false when the element is definitely not in the
@@ -39,6 +43,28 @@ package org.apache.spark.util.sketch;
  * The implementation is largely based on the {@code BloomFilter} class from guava.
  */
 public abstract class BloomFilter {
+
+  public enum Version {
+    /**
+     * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+     *  - Version number, always 1 (32 bit)
+     *  - Total number of words of the underlying bit array (32 bit)
+     *  - The words/longs (numWords * 64 bit)
+     *  - Number of hash functions (32 bit)
+     */
+    V1(1);
+
+    private final int versionNumber;
+
+    Version(int versionNumber) {
+      this.versionNumber = versionNumber;
+    }
+
+    int getVersionNumber() {
+      return versionNumber;
+    }
+  }
+
   /**
    * Returns the false positive probability, i.e. the probability that
    * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
@@ -83,7 +109,7 @@
    * bloom filters are appropriately sized to avoid saturating them.
    *
    * @param other The bloom filter to combine this bloom filter with. It is not mutated.
-   * @throws IllegalArgumentException if {@code isCompatible(that) == false}
+   * @throws IncompatibleMergeException if {@code isCompatible(other) == false}
    */
   public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;
@@ -93,6 +119,20 @@
    */
   public abstract boolean mightContain(Object item);
 
+  /**
+   * Writes out this {@link BloomFilter} to an output stream in binary format.
+   * It is the caller's responsibility to close the stream.
+   */
+  public abstract void writeTo(OutputStream out) throws IOException;
+
+  /**
+   * Reads in a {@link BloomFilter} from an input stream.
+   * It is the caller's responsibility to close the stream.
+   */
+  public static BloomFilter readFrom(InputStream in) throws IOException {
+    return BloomFilterImpl.readFrom(in);
+  }
+
   /**
    * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
    * expected insertions and total number of bits in the Bloom filter.
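
Given the V1 layout documented in the Version enum above, the serialized size of a Bloom filter is fully determined by the number of 64-bit words backing it: 4 bytes of version, 4 bytes of word count, 8 bytes per word, and 4 bytes for the number of hash functions. A small sketch of that arithmetic; the helper is illustrative, not part of the API:

public class BloomFilterV1SizeExample {
  // Serialized size in bytes of the V1 format for a filter backed by `numWords` longs.
  static long serializedSizeInBytes(long numWords) {
    return 4L             // version number (32 bit)
         + 4L             // word count of the bit array (32 bit)
         + 8L * numWords  // the words themselves (numWords * 64 bit)
         + 4L;            // number of hash functions (32 bit)
  }

  public static void main(String[] args) {
    // A 1 MB bit array (131072 words * 64 bits) costs only 12 extra bytes of framing.
    System.out.println(serializedSizeInBytes(131072)); // prints 1048588
  }
}
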

BloomFilterImpl.java
@@ -17,7 +17,7 @@
 package org.apache.spark.util.sketch;
 
-import java.io.UnsupportedEncodingException;
+import java.io.*;
 
 public class BloomFilterImpl extends BloomFilter {
@@ -25,8 +25,32 @@ public class BloomFilterImpl extends BloomFilter {
   private final BitArray bits;
 
   BloomFilterImpl(int numHashFunctions, long numBits) {
+    this(new BitArray(numBits), numHashFunctions);
+  }
+
+  private BloomFilterImpl(BitArray bits, int numHashFunctions) {
+    this.bits = bits;
     this.numHashFunctions = numHashFunctions;
-    this.bits = new BitArray(numBits);
   }
 
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    }
+
+    if (other == null || !(other instanceof BloomFilterImpl)) {
+      return false;
+    }
+
+    BloomFilterImpl that = (BloomFilterImpl) other;
+    return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
+  }
+
+  @Override
+  public int hashCode() {
+    return bits.hashCode() * 31 + numHashFunctions;
+  }
+
   @Override
@@ -161,4 +185,24 @@
     this.bits.putAll(that.bits);
     return this;
   }
+
+  @Override
+  public void writeTo(OutputStream out) throws IOException {
+    DataOutputStream dos = new DataOutputStream(out);
+
+    dos.writeInt(Version.V1.getVersionNumber());
+    bits.writeTo(dos);
+    dos.writeInt(numHashFunctions);
+  }
+
+  public static BloomFilterImpl readFrom(InputStream in) throws IOException {
+    DataInputStream dis = new DataInputStream(in);
+
+    int version = dis.readInt();
+    if (version != Version.V1.getVersionNumber()) {
+      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+    }
+
+    return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
+  }
 }
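
readFrom above checks the leading version field and rejects anything other than V1 with an IOException, which is what lets the binary format evolve behind the version number mentioned in the commit message. A hedged sketch of that failure path; the bogus version value 999 is arbitrary:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.spark.util.sketch.BloomFilter;

public class BloomFilterVersionCheckExample {
  public static void main(String[] args) throws IOException {
    // Forge a stream whose leading version field is not V1.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    new DataOutputStream(bytes).writeInt(999);

    try {
      BloomFilter.readFrom(new ByteArrayInputStream(bytes.toByteArray()));
    } catch (IOException e) {
      // Expected: "Unexpected Bloom filter version number (999)"
      System.out.println(e.getMessage());
    }
  }
}
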

CountMinSketch.java
@@ -55,10 +55,21 @@ import java.io.OutputStream;
  * This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
  */
 abstract public class CountMinSketch {
 
-  /**
-   * Version number of the serialized binary format.
-   */
   public enum Version {
+    /**
+     * {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
+     *  - Version number, always 1 (32 bit)
+     *  - Total count of added items (64 bit)
+     *  - Depth (32 bit)
+     *  - Width (32 bit)
+     *  - Hash functions (depth * 64 bit)
+     *  - Count table
+     *    - Row 0 (width * 64 bit)
+     *    - Row 1 (width * 64 bit)
+     *    - ...
+     *    - Row depth - 1 (width * 64 bit)
+     */
     V1(1);
 
     private final int versionNumber;
@@ -67,13 +78,11 @@ abstract public class CountMinSketch {
       this.versionNumber = versionNumber;
     }
 
-    public int getVersionNumber() {
+    int getVersionNumber() {
       return versionNumber;
     }
   }
 
-  public abstract Version version();
-
   /**
    * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
    */
@@ -128,13 +137,13 @@
   /**
    * Writes out this {@link CountMinSketch} to an output stream in binary format.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
    */
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
    * Reads in a {@link CountMinSketch} from an input stream.
-   * It is the caller's responsibility to close the stream
+   * It is the caller's responsibility to close the stream.
   */
   public static CountMinSketch readFrom(InputStream in) throws IOException {
     return CountMinSketchImpl.readFrom(in);
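
With the format documentation now living on Version.V1, the size of a serialized CountMinSketch is equally easy to derive: a 20-byte header (version, total count, depth, width) followed by one 64-bit hash-function parameter per row and a depth x width table of 64-bit counters. A small sketch of the arithmetic; the helper is illustrative, not part of the API:

public class CountMinSketchV1SizeExample {
  // Serialized size in bytes of the V1 format for a sketch of the given depth and width.
  static long serializedSizeInBytes(int depth, int width) {
    return 4L                          // version number (32 bit)
         + 8L                          // total count of added items (64 bit)
         + 4L + 4L                     // depth and width (32 bit each)
         + 8L * depth                  // per-row hash function parameters (depth * 64 bit)
         + 8L * depth * (long) width;  // the depth x width table of 64-bit counters
  }

  public static void main(String[] args) {
    // A depth-10, width-2000 sketch serializes to 20 + 80 + 160000 = 160100 bytes.
    System.out.println(serializedSizeInBytes(10, 2000));
  }
}
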

CountMinSketchImpl.java
@@ -26,21 +26,6 @@ import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
-/*
- * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
- * order):
- *
- *   - Version number, always 1 (32 bit)
- *   - Total count of added items (64 bit)
- *   - Depth (32 bit)
- *   - Width (32 bit)
- *   - Hash functions (depth * 64 bit)
- *   - Count table
- *     - Row 0 (width * 64 bit)
- *     - Row 1 (width * 64 bit)
- *     - ...
- *     - Row depth - 1 (width * 64 bit)
- */
 class CountMinSketchImpl extends CountMinSketch {
   public static final long PRIME_MODULUS = (1L << 31) - 1;
@@ -112,11 +97,6 @@ class CountMinSketchImpl extends CountMinSketch {
     return hash;
   }
 
-  @Override
-  public Version version() {
-    return Version.V1;
-  }
-
   private void initTablesWith(int depth, int width, int seed) {
     this.table = new long[depth][width];
     this.hashA = new long[depth];
@@ -327,7 +307,7 @@
   public void writeTo(OutputStream out) throws IOException {
     DataOutputStream dos = new DataOutputStream(out);
 
-    dos.writeInt(version().getVersionNumber());
+    dos.writeInt(Version.V1.getVersionNumber());
 
     dos.writeLong(this.totalCount);
     dos.writeInt(this.depth);

BloomFilterSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.util.sketch
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite
 class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
   private final val EPSILON = 0.01
 
+  // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized
+  // version is equivalent to the original one.
+  private def checkSerDe(filter: BloomFilter): Unit = {
+    val out = new ByteArrayOutputStream()
+    filter.writeTo(out)
+    out.close()
+
+    val in = new ByteArrayInputStream(out.toByteArray)
+    val deserialized = BloomFilter.readFrom(in)
+    in.close()
+
+    assert(filter == deserialized)
+  }
+
   def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
     test(s"accuracy - $typeName") {
       // use a fixed seed to make the test predictable.
@@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
       // Also check the actual fpp is not significantly higher than we expected.
       val actualFpp = errorCount.toDouble / (numItems - numInsertion)
       assert(actualFpp - fpp < EPSILON)
+
+      checkSerDe(filter)
     }
   }
@@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
       items1.foreach(i => assert(filter1.mightContain(i)))
       items2.foreach(i => assert(filter1.mightContain(i)))
+
+      checkSerDe(filter1)
     }
   }