diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 0011096d4a47bcc006a2cd0e34a5ecb047c9bcbf..f7c22dddb8cc0dabb99fa260fd31a730a7577321 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,12 +17,13 @@ package org.apache.spark.util.sketch; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; /** - * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for frequency estimation in * sub-linear space. Currently, supported data types include: * <ul> * <li>{@link Byte}</li> @@ -30,10 +31,6 @@ import java.io.OutputStream; * <li>{@link Integer}</li> * <li>{@link Long}</li> * <li>{@link String}</li> - * <li>{@link Float}</li> - * <li>{@link Double}</li> - * <li>{@link java.math.BigDecimal}</li> - * <li>{@link Boolean}</li> * </ul> * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: * <ol> @@ -177,6 +174,11 @@ public abstract class CountMinSketch { */ public abstract void writeTo(OutputStream out) throws IOException; + /** + * Serializes this {@link CountMinSketch} and returns the serialized form. + */ + public abstract byte[] toByteArray() throws IOException; + /** * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to * close the stream. @@ -185,6 +187,16 @@ public abstract class CountMinSketch { return CountMinSketchImpl.readFrom(in); } + /** + * Reads in a {@link CountMinSketch} from a byte array. + */ + public static CountMinSketch readFrom(byte[] bytes) throws IOException { + InputStream in = new ByteArrayInputStream(bytes); + CountMinSketch cms = readFrom(in); + in.close(); + return cms; + } + /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}.
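The new `toByteArray` / `readFrom(byte[])` pair above removes the stream plumbing that callers previously had to write by hand. A minimal round-trip sketch of the new API, assuming the common/sketch module is on the classpath (the eps/confidence/seed values and items below are purely illustrative):

    import org.apache.spark.util.sketch.CountMinSketch

    // Build a small sketch: eps = 0.001, confidence = 0.99, seed = 42.
    val sketch = CountMinSketch.create(0.001, 0.99, 42)
    (1 to 100).foreach(i => sketch.add(i % 10))

    // New in this patch: serialize directly to a byte array...
    val bytes: Array[Byte] = sketch.toByteArray
    // ...and read it back without constructing streams explicitly.
    val restored = CountMinSketch.readFrom(bytes)

    assert(restored == sketch)
    assert(restored.estimateCount(3) == sketch.estimateCount(3))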
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 94ab3a98cb65af942039fe74d18322d582f71432..045fec33a282a18d401b55b4348d23f5fa6c599f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,15 +17,7 @@ package org.apache.spark.util.sketch; -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.OutputStream; -import java.io.Serializable; -import java.math.BigDecimal; +import java.io.*; import java.util.Arrays; import java.util.Random; @@ -153,16 +145,8 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); - } else if (item instanceof BigDecimal) { - addString(((BigDecimal) item).toString(), count); } else if (item instanceof byte[]) { addBinary((byte[]) item, count); - } else if (item instanceof Float) { - addLong(Float.floatToIntBits((Float) item), count); - } else if (item instanceof Double) { - addLong(Double.doubleToLongBits((Double) item), count); - } else if (item instanceof Boolean) { - addLong(((Boolean) item) ? 1L : 0L, count); } else { addLong(Utils.integralToLong(item), count); } @@ -227,6 +211,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return ((int) hash) % width; } + private static int[] getHashBuckets(String key, int hashCount, int max) { + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); + } + private static int[] getHashBuckets(byte[] b, int hashCount, int max) { int[] result = new int[hashCount]; int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); @@ -240,18 +228,9 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { @Override public long estimateCount(Object item) { if (item instanceof String) { - return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item)); - } else if (item instanceof BigDecimal) { - return estimateCountForBinaryItem( - Utils.getBytesFromUTF8String(((BigDecimal) item).toString())); + return estimateCountForStringItem((String) item); } else if (item instanceof byte[]) { return estimateCountForBinaryItem((byte[]) item); - } else if (item instanceof Float) { - return estimateCountForLongItem(Float.floatToIntBits((Float) item)); - } else if (item instanceof Double) { - return estimateCountForLongItem(Double.doubleToLongBits((Double) item)); - } else if (item instanceof Boolean) { - return estimateCountForLongItem(((Boolean) item) ? 
1L : 0L); } else { return estimateCountForLongItem(Utils.integralToLong(item)); } @@ -265,6 +244,15 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return res; } + private long estimateCountForStringItem(String item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + private long estimateCountForBinaryItem(byte[] item) { long res = Long.MAX_VALUE; int[] buckets = getHashBuckets(item, depth, width); @@ -332,6 +320,14 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { } } + @Override + public byte[] toByteArray() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeTo(out); + out.close(); + return out.toByteArray(); + } + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { CountMinSketchImpl sketch = new CountMinSketchImpl(); sketch.readFrom0(in); diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index 2c358fcee46ce5a7e284683a914b4746a2fb0cdc..174eb01986c4f1e0562c0b7af13b0315128c5178 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.util.sketch import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.Random @@ -26,9 +25,9 @@ import scala.util.Random import org.scalatest.FunSuite // scalastyle:ignore funsuite class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite - private val epsOfTotalCount = 0.0001 + private val epsOfTotalCount = 0.01 - private val confidence = 0.99 + private val confidence = 0.9 private val seed = 42 @@ -45,12 +44,6 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { - def getProbeItem(item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } - test(s"accuracy - $typeName") { // Uses fixed seed to ensure reproducible test execution val r = new Random(31) @@ -63,7 +56,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val exactFreq = { val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) + sampledItems.groupBy(identity).mapValues(_.length.toLong) } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -74,12 +67,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val probCorrect = { val numErrors = allItems.map { item => - val count = exactFreq.getOrElse(getProbeItem(item), 0L) + val count = exactFreq.getOrElse(item, 0L) val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems if (ratio > epsOfTotalCount) 1 else 0 }.sum - 1D - numErrors.toDouble / numAllItems + 1.0 - (numErrors.toDouble / numAllItems) } assert( @@ -96,9 +89,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val numToMerge = 5 val numItemsPerSketch = 100000 - val perSketchItems = Array.fill(numToMerge, 
numItemsPerSketch) { - itemGenerator(r) - } + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { itemGenerator(r) } val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -113,11 +104,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val mergedSketch = sketches.reduce(_ mergeInPlace _) checkSerDe(mergedSketch) - val expectedSketch = { - val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) - perSketchItems.foreach(_.foreach(sketch.add)) - sketch - } + val expectedSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(expectedSketch.add)) perSketchItems.foreach { _.foreach { item => @@ -142,17 +130,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } - testItemType[Float]("Float") { _.nextFloat() } - - testItemType[Double]("Double") { _.nextDouble() } - - testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) } - - testItemType[Boolean]("Boolean") { _.nextBoolean() } - - testItemType[Array[Byte]]("Binary") { r => - Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20))) - } + testItemType[Array[Byte]]("Byte array") { r => r.nextString(r.nextInt(60)).getBytes } test("incompatible merge") { intercept[IncompatibleMergeException] { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4995af034f654e0faa6c9ed4710ecc4ad5b61314..b113bbf803d950cea8327a89dee27fab002b5ffb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,11 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + lazy val v22excludes = v21excludes ++ Seq( + // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray") + ) + // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { Seq( @@ -912,7 +917,8 @@ object MimaExcludes { } def excludes(version: String) = version match { - case v if v.startsWith("2.1") => v21excludes + case v if v.startsWith("2.2") => v22excludes + case v if v.startsWith("2.1") => v22excludes // TODO: Update this when we bump version to 2.2 case v if v.startsWith("2.0") => v20excludes case _ => Seq() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index f5f185f2c54287180adf24e7a1aae3ef103cbc64..612c19831f0b2bb22c71e222f96cce8dfdbd5382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -42,9 +40,9 @@ import org.apache.spark.util.sketch.CountMinSketch @ExpressionDescription( usage = """ _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp, - confidence and seed. 
The result is an array of bytes, which should be deserialized to a - `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join - size estimation. + confidence and seed. The result is an array of bytes, which can be deserialized to a + `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for + frequency estimation in sub-linear space. """) case class CountMinSketchAgg( child: Expression, @@ -75,13 +73,13 @@ case class CountMinSketchAgg( } else if (!epsExpression.foldable || !confidenceExpression.foldable || !seedExpression.foldable) { TypeCheckFailure( - "The eps, confidence or seed provided must be a literal or constant foldable") + "The eps, confidence or seed provided must be a literal or foldable") } else if (epsExpression.eval() == null || confidenceExpression.eval() == null || seedExpression.eval() == null) { TypeCheckFailure("The eps, confidence or seed provided should not be null") - } else if (eps <= 0D) { + } else if (eps <= 0.0) { TypeCheckFailure(s"Relative error must be positive (current value = $eps)") - } else if (confidence <= 0D || confidence >= 1D) { + } else if (confidence <= 0.0 || confidence >= 1.0) { TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)") } else { TypeCheckSuccess @@ -97,9 +95,6 @@ case class CountMinSketchAgg( // Ignore empty rows if (value != null) { child.dataType match { - // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them - // into acceptable types for `CountMinSketch`. - case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal) // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` // instead of `addString` to avoid unnecessary conversion. 
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes) @@ -115,14 +110,11 @@ case class CountMinSketchAgg( override def eval(buffer: CountMinSketch): Any = serialize(buffer) override def serialize(buffer: CountMinSketch): Array[Byte] = { - val out = new ByteArrayOutputStream() - buffer.writeTo(out) - out.toByteArray + buffer.toByteArray } override def deserialize(storageFormat: Array[Byte]): CountMinSketch = { - val in = new ByteArrayInputStream(storageFormat) - CountMinSketch.readFrom(in) + CountMinSketch.readFrom(storageFormat) } override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg = @@ -132,8 +124,7 @@ case class CountMinSketchAgg( copy(inputAggBufferOffset = newInputAggBufferOffset) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType), - DoubleType, DoubleType, IntegerType) + Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType) } override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 8456e244609bc988e49f1801883f82d9f0f923f2..fcb370ae8460f3fb2a103f8e135676b61b9f5719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -86,7 +86,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { (headBufferSize + bufferSize) * 2 } - val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { count => + Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count => val buffer = new PercentileDigest(relativeError) // Worst case, data is linear sorted (0 until count).foreach(buffer.add(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 6e08e29c0449f6280971e5e3fc83cf6935966766..10479630f3f9950e0e49e7a3ed623374f7314b26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -17,199 +17,114 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.io.ByteArrayInputStream -import java.nio.charset.StandardCharsets +import java.{lang => jl} -import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal} -import org.apache.spark.sql.types.{DecimalType, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch +/** + * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]]. 
+ */ class CountMinSketchAggSuite extends SparkFunSuite { private val childExpression = BoundReference(0, IntegerType, nullable = true) private val epsOfTotalCount = 0.0001 private val confidence = 0.99 private val seed = 42 - - test("serialize and de-serialize") { - // Check empty serialize and de-serialize - val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence), - Literal(seed)) - val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed) - assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) - - // Check non-empty serialize and de-serialize - val random = new Random(31) - (0 until 10000).map(_ => random.nextInt(100)).foreach { value => - buffer.add(value) - } - assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) + private val rand = new Random(seed) + + /** Creates a count-min sketch aggregate expression, using the child expression defined above. */ + private def cms(eps: jl.Double, confidence: jl.Double, seed: jl.Integer): CountMinSketchAgg = { + new CountMinSketchAgg( + child = childExpression, + epsExpression = Literal(eps, DoubleType), + confidenceExpression = Literal(confidence, DoubleType), + seedExpression = Literal(seed, IntegerType)) } - def testHighLevelInterface[T: ClassTag]( - dataType: DataType, - sampledItemIndices: Array[Int], - allItems: Array[T], - exactFreq: Map[Any, Long]): Any = { - test(s"high level interface, update, merge, eval... - $dataType") { + /** + * Creates a new test case that compares our aggregate function with a reference implementation + * (using the underlying [[CountMinSketch]]). + * + * This works by splitting the items into two separate groups, aggregating each group, merging + * the two partial results back (to emulate partial aggregation), and then comparing the result + * with that generated by [[CountMinSketch]] directly. This assumes insertion order does not + * impact the result in count-min sketch. + */ + private def testDataType[T](dataType: DataType, items: Seq[T]): Unit = { + test("test data type " + dataType) { val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true), Literal(epsOfTotalCount), Literal(confidence), Literal(seed)) assert(!agg.nullable) - val group1 = 0 until sampledItemIndices.length / 2 - val group1Buffer = agg.createAggregationBuffer() - group1.foreach { index => - val input = InternalRow(allItems(sampledItemIndices(index))) - agg.update(group1Buffer, input) + val (seq1, seq2) = items.splitAt(items.size / 2) + val buf1 = addToAggregateBuffer(agg, seq1) + val buf2 = addToAggregateBuffer(agg, seq2) + + val sketch = agg.createAggregationBuffer() + agg.merge(sketch, buf1) + agg.merge(sketch, buf2) + + // Validate frequency estimation against reference implementation. 
+ val referenceSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + items.foreach { item => + referenceSketch.add(item match { + case u: UTF8String => u.getBytes + case _ => item + }) } - val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length - val group2Buffer = agg.createAggregationBuffer() - group2.foreach { index => - val input = InternalRow(allItems(sampledItemIndices(index))) - agg.update(group2Buffer, input) + items.foreach { item => + withClue(s"For item $item") { + val itemToTest = item match { + case u: UTF8String => u.getBytes + case _ => item + } + assert(referenceSketch.estimateCount(itemToTest) == sketch.estimateCount(itemToTest)) + } } - - var mergeBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, group1Buffer) - agg.merge(mergeBuffer, group2Buffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) - - // Merge in a different order - mergeBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, group2Buffer) - agg.merge(mergeBuffer, group1Buffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) - - // Merge with an empty partition - val emptyBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, emptyBuffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) } - } - - def testLowLevelInterface[T: ClassTag]( - dataType: DataType, - sampledItemIndices: Array[Int], - allItems: Array[T], - exactFreq: Map[Any, Long]): Any = { - test(s"low level interface, update, merge, eval... - ${dataType.typeName}") { - val inputAggregationBufferOffset = 1 - val mutableAggregationBufferOffset = 2 - // Phase one, partial mode aggregation - val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true), - Literal(epsOfTotalCount), Literal(confidence), Literal(seed)) - .withNewInputAggBufferOffset(inputAggregationBufferOffset) - .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) - - val mutableAggBuffer = new GenericInternalRow( - new Array[Any](mutableAggregationBufferOffset + 1)) - agg.initialize(mutableAggBuffer) - - sampledItemIndices.foreach { i => - agg.update(mutableAggBuffer, InternalRow(allItems(i))) - } - agg.serializeAggregateBufferInPlace(mutableAggBuffer) - - // Serialize the aggregation buffer - val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) - val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) - - // Phase 2: final mode aggregation - // Re-initialize the aggregation buffer - agg.initialize(mutableAggBuffer) - agg.merge(mutableAggBuffer, inputAggBuffer) - checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq) + def addToAggregateBuffer[T](agg: CountMinSketchAgg, items: Seq[T]): CountMinSketch = { + val buf = agg.createAggregationBuffer() + items.foreach { item => agg.update(buf, InternalRow(item)) } + buf } } - private def checkResult[T: ClassTag]( - result: Any, - data: Array[T], - exactFreq: Map[Any, Long]): Unit = { - result match { - case bytesData: Array[Byte] => - val in = new ByteArrayInputStream(bytesData) - val cms = CountMinSketch.readFrom(in) - val probCorrect = { - val numErrors = data.map { i => - val count = exactFreq.getOrElse(getProbeItem(i), 0L) - val item = i match { - case dec: Decimal => dec.toJavaBigDecimal - case str: UTF8String => str.getBytes - case _ => i - } - val ratio = (cms.estimateCount(item) - count).toDouble / data.length - if (ratio > epsOfTotalCount) 1 else 0 - }.sum + testDataType[Byte](ByteType, Seq.fill(100) { rand.nextInt(10).toByte }) - 1D - numErrors.toDouble / data.length - 
} + testDataType[Short](ShortType, Seq.fill(100) { rand.nextInt(10).toShort }) - assert( - probCorrect > confidence, - s"Confidence not reached: required $confidence, reached $probCorrect" - ) - case _ => fail("unexpected return type") - } - } + testDataType[Int](IntegerType, Seq.fill(100) { rand.nextInt(10) }) - private def getProbeItem[T: ClassTag](item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } + testDataType[Long](LongType, Seq.fill(100) { rand.nextInt(10) }) - def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = { - // Uses fixed seed to ensure reproducible test execution - val r = new Random(31) + testDataType[UTF8String](StringType, Seq.fill(100) { UTF8String.fromString(rand.nextString(1)) }) - val numAllItems = 1000000 - val allItems = Array.fill(numAllItems)(itemGenerator(r)) + testDataType[Array[Byte]](BinaryType, Seq.fill(100) { rand.nextString(1).getBytes() }) - val numSamples = numAllItems / 10 - val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + test("serialize and de-serialize") { + // Check empty serialize and de-serialize + val agg = cms(epsOfTotalCount, confidence, seed) + val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed) + assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) - val exactFreq = { - val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) + // Check non-empty serialize and de-serialize + val random = new Random(31) + for (i <- 0 until 10) { + buffer.add(random.nextInt(100)) } - - testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq) - testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq) - } - - testItemType[Byte](ByteType) { _.nextInt().toByte } - - testItemType[Short](ShortType) { _.nextInt().toShort } - - testItemType[Int](IntegerType) { _.nextInt() } - - testItemType[Long](LongType) { _.nextLong() } - - testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) } - - testItemType[Float](FloatType) { _.nextFloat() } - - testItemType[Double](DoubleType) { _.nextDouble() } - - testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) } - - testItemType[Boolean](BooleanType) { _.nextBoolean() } - - testItemType[Array[Byte]](BinaryType) { r => - r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8) + assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) } - - test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") { + test("fails analysis if eps, confidence or seed provided is not foldable") { val wrongEps = new CountMinSketchAgg( childExpression, epsExpression = AttributeReference("a", DoubleType)(), @@ -227,88 +142,55 @@ class CountMinSketchAggSuite extends SparkFunSuite { seedExpression = AttributeReference("c", IntegerType)()) Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg => - assertEqual( - wrongAgg.checkInputDataTypes(), - TypeCheckFailure( - "The eps, confidence or seed provided must be a literal or constant foldable") - ) + assertResult( + TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) { + wrongAgg.checkInputDataTypes() + } } } test("fails analysis if parameters are invalid") { // parameters are null - val wrongEps = new CountMinSketchAgg( - childExpression, - 
epsExpression = Cast(Literal(null), DoubleType), - confidenceExpression = Literal(confidence), - seedExpression = Literal(seed)) - val wrongConfidence = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Cast(Literal(null), DoubleType), - seedExpression = Literal(seed)) - val wrongSeed = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Literal(confidence), - seedExpression = Cast(Literal(null), IntegerType)) + val wrongEps = cms(null, confidence, seed) + val wrongConfidence = cms(epsOfTotalCount, null, seed) + val wrongSeed = cms(epsOfTotalCount, confidence, null) Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg => - assertEqual( - wrongAgg.checkInputDataTypes(), - TypeCheckFailure("The eps, confidence or seed provided should not be null") - ) + assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) { + wrongAgg.checkInputDataTypes() + } } // parameters are out of the valid range Seq(0.0, -1000.0).foreach { invalidEps => - val invalidAgg = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(invalidEps), - confidenceExpression = Literal(confidence), - seedExpression = Literal(seed)) - assertEqual( - invalidAgg.checkInputDataTypes(), - TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)") - ) + val invalidAgg = cms(invalidEps, confidence, seed) + assertResult( + TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) { + invalidAgg.checkInputDataTypes() + } } Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence => - val invalidAgg = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Literal(invalidConfidence), - seedExpression = Literal(seed)) - assertEqual( - invalidAgg.checkInputDataTypes(), - TypeCheckFailure( - s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)") - ) + val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed) + assertResult(TypeCheckFailure( + s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) { + invalidAgg.checkInputDataTypes() + } } } - private def assertEqual[T](left: T, right: T): Unit = { - assert(left == right) - } - test("null handling") { def isEqual(result: Any, other: CountMinSketch): Boolean = { - result match { - case bytesData: Array[Byte] => - val in = new ByteArrayInputStream(bytesData) - val cms = CountMinSketch.readFrom(in) - cms.equals(other) - case _ => fail("unexpected return type") - } + other.equals(CountMinSketch.readFrom(result.asInstanceOf[Array[Byte]])) } - val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence), - Literal(seed)) + val agg = cms(epsOfTotalCount, confidence, seed) val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed) val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer assert(isEqual(agg.eval(buffer), emptyCms)) + // Empty input row agg.update(buffer, InternalRow(null)) assert(isEqual(agg.eval(buffer), emptyCms)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index e98092df49518b691bdd2eeb993c6ad75e04d14a..62a75343a094608637771bf1bdf3cd2d328c5b07 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -21,6 +21,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.test.SharedSQLContext +/** + * End-to-end tests for approximate percentile aggregate function. + */ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala index 3e715a393e530f63c686344bbd87201fb221dde5..dea0d4c0c6d405224ef58556b4515ccc7f856ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala @@ -17,175 +17,29 @@ package org.apache.spark.sql -import java.io.ByteArrayInputStream -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} - -import scala.reflect.ClassTag -import scala.util.Random - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{Decimal, StringType, _} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch +/** + * End-to-end test suite for count_min_sketch. + */ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext { - private val table = "count_min_sketch_table" - - /** Uses fixed seed to ensure reproducible test execution */ - private val r = new Random(42) - private val numAllItems = 1000 - private val numSamples = numAllItems / 10 - - private val eps = 0.1D - private val confidence = 0.95D - private val seed = 11 - - val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01")) - val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01")) - val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00")) - val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00")) - - test(s"compute count-min sketch for multiple columns of different types") { - val (allBytes, sampledByteIndices, exactByteFreq) = - generateTestData[Byte] { _.nextInt().toByte } - val (allShorts, sampledShortIndices, exactShortFreq) = - generateTestData[Short] { _.nextInt().toShort } - val (allInts, sampledIntIndices, exactIntFreq) = - generateTestData[Int] { _.nextInt() } - val (allLongs, sampledLongIndices, exactLongFreq) = - generateTestData[Long] { _.nextLong() } - val (allStrings, sampledStringIndices, exactStringFreq) = - generateTestData[String] { r => r.nextString(r.nextInt(20)) } - val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r => - DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate) - } - val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r => - DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS) - } - val (allFloats, sampledFloatIndices, exactFloatFreq) = - generateTestData[Float] { _.nextFloat() } - val (allDoubles, sampledDoubleIndices, exactDoubleFreq) = - generateTestData[Double] { _.nextDouble() } - val (allDeciamls, sampledDecimalIndices, exactDecimalFreq) = - 
generateTestData[Decimal] { r => Decimal(r.nextDouble()) } - val (allBooleans, sampledBooleanIndices, exactBooleanFreq) = - generateTestData[Boolean] { _.nextBoolean() } - val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r => - r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8) - } - - val data = (0 until numSamples).map { i => - Row(allBytes(sampledByteIndices(i)), - allShorts(sampledShortIndices(i)), - allInts(sampledIntIndices(i)), - allLongs(sampledLongIndices(i)), - allStrings(sampledStringIndices(i)), - allDates(sampledDateIndices(i)), - allTimestamps(sampledTSIndices(i)), - allFloats(sampledFloatIndices(i)), - allDoubles(sampledDoubleIndices(i)), - allDeciamls(sampledDecimalIndices(i)), - allBooleans(sampledBooleanIndices(i)), - allBinaries(sampledBinaryIndices(i))) - } + test("count-min sketch") { + import testImplicits._ - val schema = StructType(Seq( - StructField("c1", ByteType), - StructField("c2", ShortType), - StructField("c3", IntegerType), - StructField("c4", LongType), - StructField("c5", StringType), - StructField("c6", DateType), - StructField("c7", TimestampType), - StructField("c8", FloatType), - StructField("c9", DoubleType), - StructField("c10", new DecimalType()), - StructField("c11", BooleanType), - StructField("c12", BinaryType))) + val eps = 0.1 + val confidence = 0.95 + val seed = 11 - withTempView(table) { - val rdd: RDD[Row] = spark.sparkContext.parallelize(data) - spark.createDataFrame(rdd, schema).createOrReplaceTempView(table) + val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5) + val sketch = CountMinSketch.readFrom(items.toDF("id") + .selectExpr(s"count_min_sketch(id, ${eps}d, ${confidence}d, $seed)") + .head().get(0).asInstanceOf[Array[Byte]]) - val cmsSql = schema.fieldNames.map { col => - s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)" - } - val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head() - schema.indices.foreach { i => - val binaryData = result.getAs[Array[Byte]](i) - val in = new ByteArrayInputStream(binaryData) - val cms = CountMinSketch.readFrom(in) - schema.fields(i).dataType match { - case ByteType => checkResult(cms, allBytes, exactByteFreq) - case ShortType => checkResult(cms, allShorts, exactShortFreq) - case IntegerType => checkResult(cms, allInts, exactIntFreq) - case LongType => checkResult(cms, allLongs, exactLongFreq) - case StringType => checkResult(cms, allStrings, exactStringFreq) - case DateType => - checkResult(cms, - allDates.map(DateTimeUtils.fromJavaDate), - exactDateFreq.map { e => - (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2) - }) - case TimestampType => - checkResult(cms, - allTimestamps.map(DateTimeUtils.fromJavaTimestamp), - exactTSFreq.map { e => - (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2) - }) - case FloatType => checkResult(cms, allFloats, exactFloatFreq) - case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq) - case DecimalType() => checkResult(cms, allDeciamls, exactDecimalFreq) - case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq) - case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq) - } - } - } - } - - private def checkResult[T: ClassTag]( - cms: CountMinSketch, - data: Array[T], - exactFreq: Map[Any, Long]): Unit = { - val probCorrect = { - val numErrors = data.map { i => - val count = exactFreq.getOrElse(getProbeItem(i), 0L) - val item = i match { - case dec: Decimal => dec.toJavaBigDecimal - case str: UTF8String => str.getBytes - case _ => i - } - val 
ratio = (cms.estimateCount(item) - count).toDouble / data.length - if (ratio > eps) 1 else 0 - }.sum - - 1D - numErrors.toDouble / data.length - } - - assert( - probCorrect > confidence, - s"Confidence not reached: required $confidence, reached $probCorrect" - ) - } - - private def getProbeItem[T: ClassTag](item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } + val reference = CountMinSketch.create(eps, confidence, seed) + items.foreach(reference.add) - private def generateTestData[T: ClassTag]( - itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = { - val allItems = Array.fill(numAllItems)(itemGenerator(r)) - val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) - val exactFreq = { - val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) - } - (allItems, sampledItemIndices, exactFreq) + assert(sketch == reference) } }
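Putting the pieces together, the rewritten query suite above boils down to the following end-to-end flow. This is a minimal sketch assuming an active SparkSession in scope as `spark`; the column name `id` and the eps/confidence/seed values mirror the test itself:

    import org.apache.spark.util.sketch.CountMinSketch
    import spark.implicits._

    val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5)

    // count_min_sketch(col, eps, confidence, seed) evaluates to the serialized
    // sketch as a binary column; the "d" suffix marks SQL double literals.
    val bytes = items.toDF("id")
      .selectExpr("count_min_sketch(id, 0.1d, 0.95d, 11)")
      .head().getAs[Array[Byte]](0)

    val sketch = CountMinSketch.readFrom(bytes)
    // A count-min sketch may over-count but never under-counts:
    // the exact frequency of 2 in the input is 4.
    assert(sketch.estimateCount(2) >= 4L)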