From 5ecdc7c5c019acc6b1f9c2e6c5b7d35957eadb88 Mon Sep 17 00:00:00 2001
From: Zhenhua Wang <wzh_zju@163.com>
Date: Fri, 25 Nov 2016 05:02:48 -0800
Subject: [PATCH] [SPARK-18559][SQL] Fix HLL++ with small relative error

## What changes were proposed in this pull request?

In `HyperLogLogPlusPlus`, a relative error small enough that `p >= 19` causes an `ArrayIndexOutOfBoundsException` in `THRESHOLDS(p - 4)`. We should check `p`, and when `p >= 19`, fall back to the original HLL estimate and use the small range correction HLL uses.
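
For example (a minimal sketch, not part of the patch), plugging `relativeSD = 0.001` into the formula for `p` shows how the index overflows, assuming `THRESHOLDS` covers `p = 4` through `18` as in the HLL++ paper:
```
val relativeSD = 0.001d
val p = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt
// p == 21, so THRESHOLDS(p - 4) == THRESHOLDS(17) indexes past the end of the
// 15-element THRESHOLDS array and throws ArrayIndexOutOfBoundsException
```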

This PR also fixes the upper bound reported in the error message of the `require()` check.
The upper bound is computed by:
```
val relativeSD = 1.106d / Math.pow(Math.E, p * Math.log(2.0d) / 2.0d)
```
which is derived from the equation for computing `p`:
```
val p = 2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)
```
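
Since `p` is the ceiling of this expression, the `require(p >= 4)` check passes for any `relativeSD` strictly below the value at `p = 3`, which is where the 39% comes from. A quick sanity check (a sketch, not part of the patch):
```
val bound = 1.106d / Math.pow(Math.E, 3 * Math.log(2.0d) / 2.0d)
// bound ≈ 0.3911, hence "at most 39%"
val pAtBound = Math.ceil(2.0d * Math.log(1.106d / 0.39d) / Math.log(2.0d)).toInt
// pAtBound == 4, while relativeSD = 0.4 gives p == 3 and fails the require
```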

## How was this patch tested?

Added test cases for:
1. checking the validity of the parameter `relativeSD`
2. estimation with a relative error small enough that `p >= 19` (see the usage sketch after this list)
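
For reference, a hedged end-to-end illustration of how the bug surfaced to users (assumes a `SparkSession` named `spark`; `approx_count_distinct` is the SQL function backed by this aggregate):
```
import org.apache.spark.sql.functions.approx_count_distinct

val df = spark.range(1000000).toDF("id")
// Before this fix, rsd = 0.001 implies p = 21 and the query failed with
// ArrayIndexOutOfBoundsException; with the fix it falls back to plain HLL.
df.select(approx_count_distinct(df("id"), 0.001)).show()
```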

Author: Zhenhua Wang <wzh_zju@163.com>
Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #15990 from wzhfy/hllppRsd.
---
 .../expressions/aggregate/HyperLogLogPlusPlus.scala      | 9 ++++++---
 .../expressions/aggregate/HyperLogLogPlusPlusSuite.scala | 9 ++++++++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index b9862aa04f..77b7eb228e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -93,7 +93,7 @@ case class HyperLogLogPlusPlus(
   private[this] val p = Math.ceil(2.0d * Math.log(1.106d / relativeSD) / Math.log(2.0d)).toInt
 
   require(p >= 4, "HLL++ requires at least 4 bits for addressing. " +
-    "Use a lower error, at most 27%.")
+    "Use a lower error, at most 39%.")
 
   /**
    * Shift used to extract the index of the register from the hashed value.
@@ -296,8 +296,9 @@ case class HyperLogLogPlusPlus(
     // We integrate two steps from the paper:
     // val Z = 1.0d / zInverse
     // val E = alphaM2 * Z
+    val E = alphaM2 / zInverse
     @inline
-    def EBiasCorrected = alphaM2 / zInverse match {
+    def EBiasCorrected = E match {
       case e if p < 19 && e < 5.0d * m => e - estimateBias(e)
       case e => e
     }
@@ -306,7 +307,9 @@ case class HyperLogLogPlusPlus(
     val estimate = if (V > 0) {
       // Use linear counting for small cardinality estimates.
       val H = m * Math.log(m / V)
-      if (H <= THRESHOLDS(p - 4)) {
+      // HLL++ is defined only when p < 19; otherwise we need to fall back to HLL.
+      // The threshold `2.5 * m` is from the original HLL algorithm.
+      if ((p < 19 && H <= THRESHOLDS(p - 4)) || E <= 2.5 * m) {
         H
       } else {
         EBiasCorrected
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
index 17f6b71bb2..cc53880af5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
@@ -50,6 +50,13 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
     assert(error < hll.trueRsd * 3.0d, "Error should be within 3 std. errors.")
   }
 
+  test("test invalid parameter relativeSD") {
+    // `relativeSD` should be at most 39%.
+    intercept[IllegalArgumentException] {
+      new HyperLogLogPlusPlus(new BoundReference(0, IntegerType, true), relativeSD = 0.4)
+    }
+  }
+
   test("add nulls") {
     val (hll, input, buffer) = createEstimator(0.05)
     input.setNullAt(0)
@@ -83,7 +90,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
   test("deterministic cardinality estimation") {
     val repeats = 10
     testCardinalityEstimates(
-      Seq(0.1, 0.05, 0.025, 0.01),
+      Seq(0.1, 0.05, 0.025, 0.01, 0.001),
       Seq(100, 500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000).map(_ * repeats),
       i => i / repeats,
       i => i / repeats)
-- 
GitLab