From 5ecdc7c5c019acc6b1f9c2e6c5b7d35957eadb88 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang <wzh_zju@163.com> Date: Fri, 25 Nov 2016 05:02:48 -0800 Subject: [PATCH] [SPARK-18559][SQL] Fix HLL++ with small relative error ## What changes were proposed in this pull request? In `HyperLogLogPlusPlus`, if the relative error is so small that p >= 19, it will cause an ArrayIndexOutOfBoundsException in `THRESHOLDS(p-4)`. We should check `p` and, when p >= 19, fall back to the original HLL result and use the small range correction it uses. This PR also fixes the upper bound in the error message of `require()`. The upper bound is computed by: ``` val relativeSD = 1.106d / Math.pow(Math.E, p * Math.log(2.0d) / 2.0d) ``` which is derived from the equation for computing `p`: ``` val p = 2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d) ``` Because `p` is rounded up with `Math.ceil`, `p >= 4` holds exactly when the un-rounded value is greater than 3, so the bound is obtained at p = 3: 1.106 / 2^1.5 is approximately 0.39. ## How was this patch tested? Added test cases for: 1. checking validity of parameter relativeSD 2. estimation with smaller relative error so that p >= 19 Author: Zhenhua Wang <wzh_zju@163.com> Author: wangzhenhua <wangzhenhua@huawei.com> Closes #15990 from wzhfy/hllppRsd. --- .../expressions/aggregate/HyperLogLogPlusPlus.scala | 9 ++++++--- .../expressions/aggregate/HyperLogLogPlusPlusSuite.scala | 9 ++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index b9862aa04f..77b7eb228e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -93,7 +93,7 @@ case class HyperLogLogPlusPlus( private[this] val p = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt require(p >= 4, "HLL++ requires at least 4 bits for addressing. 
" + - "Use a lower error, at most 27%.") + "Use a lower error, at most 39%.") /** * Shift used to extract the index of the register from the hashed value. @@ -296,8 +296,9 @@ case class HyperLogLogPlusPlus( // We integrate two steps from the paper: // val Z = 1.0d / zInverse // val E = alphaM2 * Z + val E = alphaM2 / zInverse @inline - def EBiasCorrected = alphaM2 / zInverse match { + def EBiasCorrected = E match { case e if p < 19 && e < 5.0d * m => e - estimateBias(e) case e => e } @@ -306,7 +307,9 @@ case class HyperLogLogPlusPlus( val estimate = if (V > 0) { // Use linear counting for small cardinality estimates. val H = m * Math.log(m / V) - if (H <= THRESHOLDS(p - 4)) { + // HLL++ is defined only when p < 19, otherwise we need to fallback to HLL. + // The threshold `2.5 * m` is from the original HLL algorithm. + if ((p < 19 && H <= THRESHOLDS(p - 4)) || E <= 2.5 * m) { H } else { EBiasCorrected diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index 17f6b71bb2..cc53880af5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -50,6 +50,13 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. errors.") } + test("test invalid parameter relativeSD") { + // `relativeSD` should be at most 39%. 
+ intercept[IllegalArgumentException] { + new HyperLogLogPlusPlus(new BoundReference(0, IntegerType, true), relativeSD = 0.4) + } + } + test("add nulls") { val (hll, input, buffer) = createEstimator(0.05) input.setNullAt(0) @@ -83,7 +90,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { test("deterministic cardinality estimation") { val repeats = 10 testCardinalityEstimates( - Seq(0.1, 0.05, 0.025, 0.01), + Seq(0.1, 0.05, 0.025, 0.01, 0.001), Seq(100, 500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000).map(_ * repeats), i => i / repeats, i => i / repeats) -- GitLab