Commit c33b8dcb authored by larvaboy, committed by Reynold Xin

Implement ApproximateCountDistinct for SparkSql

Add the implementation for ApproximateCountDistinct to SparkSql. We use the HyperLogLog algorithm implemented in stream-lib, and do the count in two phases: 1) counting the number of distinct elements in each partition, and 2) merging the HyperLogLog results from the different partitions.

A simple serializer and test cases are added as well.
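For illustration only (not part of the patch): the two phases map directly onto the stream-lib HyperLogLog API used below. A minimal standalone sketch, where local collections stand in for RDD partitions and the sample data is made up:

import com.clearspring.analytics.stream.cardinality.HyperLogLog

// Phase 1: build one sketch per partition.
val partitions = Seq(Seq("a", "b"), Seq("b", "c"))
val perPartition = partitions.map { part =>
  val hll = new HyperLogLog(0.05) // relative standard deviation
  part.foreach(hll.offer)
  hll
}

// Phase 2: merge the per-partition sketches and read off the estimate.
val merged = new HyperLogLog(0.05)
perPartition.foreach(merged.addAll)
merged.cardinality() // ~3 distinct values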

Author: larvaboy <larvaboy@gmail.com>

Closes #737 from larvaboy/master and squashes the following commits:

bd8ef3f [larvaboy] Add support of user-provided standard deviation to ApproxCountDistinct.
9ba8360 [larvaboy] Fix alignment and null handling issues.
95b4067 [larvaboy] Add a test case for count distinct and approximate count distinct.
f57917d [larvaboy] Add the parser for the approximate count.
a2d5d10 [larvaboy] Add ApproximateCountDistinct aggregates and functions.
7ad273a [larvaboy] Add SparkSql serializer for HyperLogLog.
1d9aacf [larvaboy] Fix a minor typo in the toString method of the Count case class.
653542b [larvaboy] Fix a couple of minor typos.
parent 92cebada
@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */
@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
    */
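A hedged usage sketch of the RDD-level API documented above (assumes a live SparkContext named sc; the data and output are illustrative):

val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 1)))
// Lower relativeSD means more HyperLogLog registers: more memory, tighter estimates.
val approxCounts = pairs.countApproxDistinctByKey(0.01)
approxCounts.collect() // e.g. Array(("a", 2), ("b", 1)), up to approximation error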
@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val AND = Keyword("AND")
   protected val AS = Keyword("AS")
   protected val ASC = Keyword("ASC")
+  protected val APPROXIMATE = Keyword("APPROXIMATE")
   protected val AVG = Keyword("AVG")
   protected val BY = Keyword("BY")
   protected val CAST = Keyword("CAST")
@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
     COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+    APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
+      case exp => ApproxCountDistinct(exp)
+    } |
+    APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
+      case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
+    } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
     MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
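The two new productions accept either the built-in default or a user-supplied relative standard deviation. A sketch of the SQL they parse (assumes a SQLContext named sqlContext; the people table and name column are made up):

// Uses the default relativeSD of 0.05 (see ApproxCountDistinct below).
sqlContext.sql("SELECT APPROXIMATE COUNT(DISTINCT name) FROM people")

// Supplies a relative standard deviation of 1% explicitly.
sqlContext.sql("SELECT APPROXIMATE(0.01) COUNT(DISTINCT name) FROM people")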
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: Row): Any = currentMax
 }
 
-
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
-  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
   override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
+case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
+  extends PartialAggregate with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialCount =
+      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
+
+    SplitEvaluation(
+      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
+      partialCount :: Nil)
+  }
+
+  override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+}
+
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: Row): Any = count
 }
 
+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: Row): Any = hyperLogLog
+}
+
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: Row): Any = hyperLogLog.cardinality()
+}
+
 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
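To make the accuracy/memory trade-off behind relativeSD concrete: the standard HyperLogLog error bound is roughly 1.04 / sqrt(m) for m registers, so the register count grows quadratically as relativeSD shrinks. A back-of-the-envelope sketch (stream-lib additionally rounds m up to a power of two internally):

// Approximate register count implied by a target relative standard deviation.
def approxRegisters(relativeSD: Double): Long =
  math.ceil(math.pow(1.04 / relativeSD, 2)).toLong

approxRegisters(0.05) // ~433 registers
approxRegisters(0.01) // ~10816 registers, roughly 25x the memory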
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
+      new HyperLogLogSerializer)
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
   }
 }
 
+private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
+  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
+    val bytes = hyperLogLog.getBytes()
+    output.writeInt(bytes.length)
+    output.writeBytes(bytes)
+  }
+
+  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
+    val length = input.readInt()
+    val bytes = input.readBytes(length)
+    HyperLogLog.Builder.build(bytes)
+  }
+}
+
 /**
  * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
  * them as `Array[(k,v)]`.
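The serializer simply length-prefixes HyperLogLog's own byte representation. A hedged round-trip sketch using a standalone Kryo instance (illustrative wiring only; in Spark the registration happens inside SparkSqlSerializer as shown above, and HyperLogLogSerializer is private[sql], so this assumes package-visible access):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

val kryo = new Kryo()
kryo.register(classOf[HyperLogLog], new HyperLogLogSerializer)

val hll = new HyperLogLog(0.05)
Seq("a", "b", "a").foreach(hll.offer)

// Serialize, then deserialize, and check the estimate survives the trip.
val bos = new ByteArrayOutputStream()
val out = new Output(bos)
kryo.writeObject(out, hll)
out.close()

val in = new Input(new ByteArrayInputStream(bos.toByteArray))
kryo.readObject(in, classOf[HyperLogLog]).cardinality() // ~2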
@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest {
   test("count") {
     checkAnswer(
       sql("SELECT COUNT(*) FROM testData2"),
-      testData2.count()
-    )
+      testData2.count())
   }
 
+  test("count distinct") {
+    checkAnswer(
+      sql("SELECT COUNT(DISTINCT b) FROM testData2"),
+      2)
+  }
+
+  test("approximate count distinct") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+      3)
+  }
+
+  test("approximate count distinct with user provided standard deviation") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+      3)
+  }
+
   // No support for primitive nulls yet.
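The exact assertions are safe because at such tiny cardinalities stream-lib's HyperLogLog applies its small-range (linear counting) correction, so the estimate matches the true count with overwhelming probability. A quick standalone check, illustrative and not part of the suite:

val hll = new HyperLogLog(0.04)
Seq(1, 2, 3, 2, 1).foreach(i => hll.offer(Int.box(i)))
hll.cardinality() // 3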