diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 5efb4388f6c71e053023dfb8cf5a9d19f9c8875d..bc6d204434ad8962e947c869b9bcfc4cb3012551 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */
@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
   */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b3a3a1ef1b5eb5d13289d66b37fb09c3355aed26..f2b9b2c1a3ad5055f794f2cc82b841aab0d6df25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val AND = Keyword("AND")
   protected val AS = Keyword("AS")
   protected val ASC = Keyword("ASC")
+  protected val APPROXIMATE = Keyword("APPROXIMATE")
   protected val AVG = Keyword("AVG")
   protected val BY = Keyword("BY")
   protected val CAST = Keyword("CAST")
@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
     COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+    APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
+      case exp => ApproxCountDistinct(exp)
+    } |
+    APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
+      case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
+    } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
     MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 7777d372903e23763747b596df3b669b1d7cd7f0..5dbaaa3b0ce35b012b4f060b6b2bd575ef8dcff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
 
   override def eval(input: Row): Any = currentMax
 }
-
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
-  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
   override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
+case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
+  extends PartialAggregate with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialCount =
+      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
+
+    SplitEvaluation(
+      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
+      partialCount :: Nil)
+  }
+
+  override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+}
+
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
 
   override def eval(input: Row): Any = count
 }
 
+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: Row): Any = hyperLogLog
+}
+
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: Row): Any = hyperLogLog.cardinality()
+}
+
 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 1c6e29b3cdee972e11d556d7b37afe4ae0e47c73..94c2a249ef8f87e39a3f420beb55995ae61f28e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
 
@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
+      new HyperLogLogSerializer)
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
   }
 }
 
+private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
+  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
+    val bytes = hyperLogLog.getBytes()
+    output.writeInt(bytes.length)
+    output.writeBytes(bytes)
+  }
+
+  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
+    val length = input.readInt()
+    val bytes = input.readBytes(length)
+    HyperLogLog.Builder.build(bytes)
+  }
+}
+
 /**
  * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
  * them as `Array[(k,v)]`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e966d89c30cf50435145fe4f8f029d101718422f..524549eb544fc67cd34571e1c6da04ca4e312519 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest {
   test("count") {
     checkAnswer(
       sql("SELECT COUNT(*) FROM testData2"),
-      testData2.count()
-    )
+      testData2.count())
+  }
+
+  test("count distinct") {
+    checkAnswer(
+      sql("SELECT COUNT(DISTINCT b) FROM testData2"),
+      2)
+  }
+
+  test("approximate count distinct") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+      3)
+  }
+
+  test("approximate count distinct with user provided standard deviation") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+      3)
   }
 
   // No support for primitive nulls yet.
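
Note on the aggregation flow in this patch: ApproxCountDistinct splits into a per-partition step that folds values into a local HyperLogLog sketch and a merge step that combines the sketches and reads off the estimate; the sketches travel between stages as bytes via the HyperLogLogSerializer registered with Kryo above. The snippet below is a minimal standalone sketch of that flow, not part of the patch; it uses only the stream-lib calls already seen in the diff (new HyperLogLog, offer, getBytes, Builder.build, addAll, cardinality), and the sample values and partition split are made up for illustration.

```scala
import com.clearspring.analytics.stream.cardinality.HyperLogLog

object ApproxCountDistinctFlow {
  def main(args: Array[String]): Unit = {
    val relativeSD = 0.05 // same default as ApproxCountDistinct

    // Map side (ApproxCountDistinctPartitionFunction): each partition offers its
    // non-null values into its own sketch.
    val partition1 = new HyperLogLog(relativeSD)
    val partition2 = new HyperLogLog(relativeSD)
    Seq("a", "b", "c").foreach(v => partition1.offer(v))
    Seq("b", "c", "d").foreach(v => partition2.offer(v))

    // Shuffle: sketches are shipped as bytes, which is what HyperLogLogSerializer
    // does with getBytes()/Builder.build().
    val bytes1 = partition1.getBytes()
    val bytes2 = partition2.getBytes()

    // Reduce side (ApproxCountDistinctMergeFunction): merge the partial sketches
    // and report the cardinality estimate.
    val merged = new HyperLogLog(relativeSD)
    merged.addAll(HyperLogLog.Builder.build(bytes1))
    merged.addAll(HyperLogLog.Builder.build(bytes2))
    println(merged.cardinality()) // approximately 4 distinct values
  }
}
```

At the SQL level this corresponds to the new tests: `SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2` uses the default 0.05 relative standard deviation, while `SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2` supplies one explicitly.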