From b6e0a5ae6f243139f11c9cbbf18cddd3f25db208 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Wed, 4 Nov 2015 16:49:25 -0800
Subject: [PATCH] [SPARK-11510][SQL] Remove SQL aggregation tests for higher
 order statistics

DataFrameAggregateSuite and SQLQuerySuite both contain tests for the
higher-order aggregate functions (stddev, variance, skewness and
kurtosis). The two sets have almost identical coverage, so this patch
removes the SQL versions and consolidates the remaining checks in
DataFrameAggregateSuite.
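
For example, the three separate stddev checks collapse into a single
agg call, and expr(...) keeps the SQL expression path exercised from
the DataFrame side. A minimal sketch of the consolidated style (it
assumes the suite's existing QueryTest/SharedSQLContext harness and
the testData2 fixture; not a verbatim excerpt of the diff below):

    checkAnswer(
      testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
      Row(math.sqrt(4 / 5.0), math.sqrt(4 / 6.0), math.sqrt(4 / 5.0)))

    checkAnswer(
      testData2.agg(
        expr("stddev(a)"), expr("stddev_pop(a)"), expr("stddev_samp(a)")),
      Row(math.sqrt(4 / 5.0), math.sqrt(4 / 6.0), math.sqrt(4 / 5.0)))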

Author: Reynold Xin <rxin@databricks.com>

Closes #9475 from rxin/SPARK-11510.
---
 .../spark/sql/DataFrameAggregateSuite.scala   | 97 ++++++-------------
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 77 ---------------
 .../spark/sql/StringFunctionsSuite.scala      |  1 -
 3 files changed, 28 insertions(+), 147 deletions(-)
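
Note on the expected constants in the hunks below: per the shared
SQLTestData fixture, testData2.a takes the values 1, 1, 2, 2, 3, 3,
so the mean is 2 and the total squared deviation is 4. Hence var_samp
(and variance) = 4/5, var_pop = 4/6, and the stddev variants are the
corresponding square roots; the sample is symmetric, so skewness = 0.0
and kurtosis = (4/6) / (4/6)^2 - 3 = -1.5.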

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b0e2ffaa60..2e679e7bc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -83,13 +83,8 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
   test("average") {
     checkAnswer(
-      testData2.agg(avg('a)),
-      Row(2.0))
-
-    // Also check mean
-    checkAnswer(
-      testData2.agg(mean('a)),
-      Row(2.0))
+      testData2.agg(avg('a), mean('a)),
+      Row(2.0, 2.0))
 
     checkAnswer(
       testData2.agg(avg('a), sumDistinct('a)), // non-partial
@@ -98,6 +93,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       decimalData.agg(avg('a)),
       Row(new java.math.BigDecimal(2.0)))
+
     checkAnswer(
       decimalData.agg(avg('a), sumDistinct('a)), // non-partial
       Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
@@ -168,44 +164,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
   test("zero count") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-    assert(emptyTableData.count() === 0)
-
     checkAnswer(
       emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
       Row(0, null))
   }
 
   test("stddev") {
-    val testData2ADev = math.sqrt(4/5.0)
-
+    val testData2ADev = math.sqrt(4 / 5.0) // sample stddev of testData2.a
     checkAnswer(
-      testData2.agg(stddev('a)),
-      Row(testData2ADev))
-
-    checkAnswer(
-      testData2.agg(stddev_pop('a)),
-      Row(math.sqrt(4/6.0)))
-
-    checkAnswer(
-      testData2.agg(stddev_samp('a)),
-      Row(testData2ADev))
+      testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
+      Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
   }
 
   test("zero stddev") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-    assert(emptyTableData.count() == 0)
-
-    checkAnswer(
-    emptyTableData.agg(stddev('a)),
-    Row(null))
-
     checkAnswer(
-    emptyTableData.agg(stddev_pop('a)),
-    Row(null))
-
-    checkAnswer(
-    emptyTableData.agg(stddev_samp('a)),
-    Row(null))
+      emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
+      Row(null, null, null))
   }
 
   test("zero sum") {
@@ -227,6 +202,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
     val sparkVariance = testData2.agg(variance('a))
     checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
+
     val sparkVariancePop = testData2.agg(var_pop('a))
     checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol)
 
@@ -241,52 +217,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("zero moments") {
-    val emptyTableData = Seq((1, 2)).toDF("a", "b")
-    assert(emptyTableData.count() === 1)
-
-    checkAnswer(
-      emptyTableData.agg(variance('a)),
-      Row(Double.NaN))
-
-    checkAnswer(
-      emptyTableData.agg(var_samp('a)),
-      Row(Double.NaN))
-
+    val input = Seq((1, 2)).toDF("a", "b")
     checkAnswer(
-      emptyTableData.agg(var_pop('a)),
-      Row(0.0))
+      input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
+      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) // 1 row: var_pop 0, others NaN
 
     checkAnswer(
-      emptyTableData.agg(skewness('a)),
-      Row(Double.NaN))
-
-    checkAnswer(
-      emptyTableData.agg(kurtosis('a)),
-      Row(Double.NaN))
+      input.agg(
+        expr("variance(a)"),
+        expr("var_samp(a)"),
+        expr("var_pop(a)"),
+        expr("skewness(a)"),
+        expr("kurtosis(a)")),
+      Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN))
   }
 
   test("null moments") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-    assert(emptyTableData.count() === 0)
-
-    checkAnswer(
-      emptyTableData.agg(variance('a)),
-      Row(Double.NaN))
-
-    checkAnswer(
-      emptyTableData.agg(var_samp('a)),
-      Row(Double.NaN))
-
-    checkAnswer(
-      emptyTableData.agg(var_pop('a)),
-      Row(Double.NaN))
 
     checkAnswer(
-      emptyTableData.agg(skewness('a)),
-      Row(Double.NaN))
+      emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
+      Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
 
     checkAnswer(
-      emptyTableData.agg(kurtosis('a)),
-      Row(Double.NaN))
+      emptyTableData.agg(
+        expr("variance(a)"),
+        expr("var_samp(a)"),
+        expr("var_pop(a)"),
+        expr("skewness(a)"),
+        expr("kurtosis(a)")),
+      Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5731a35624..3de277a79a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -726,83 +726,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("stddev") {
-    checkAnswer(
-      sql("SELECT STDDEV(a) FROM testData2"),
-      Row(math.sqrt(4.0 / 5.0))
-    )
-  }
-
-  test("stddev_pop") {
-    checkAnswer(
-      sql("SELECT STDDEV_POP(a) FROM testData2"),
-      Row(math.sqrt(4.0 / 6.0))
-    )
-  }
-
-  test("stddev_samp") {
-    checkAnswer(
-      sql("SELECT STDDEV_SAMP(a) FROM testData2"),
-      Row(math.sqrt(4/5.0))
-    )
-  }
-
-  test("var_samp") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2")
-    val expectedAnswer = Row(4.0 / 5.0)
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
-  test("variance") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
-    val expectedAnswer = Row(0.8)
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
-  test("var_pop") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2")
-    val expectedAnswer = Row(4.0 / 6.0)
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
-  test("skewness") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT skewness(a) FROM testData2")
-    val expectedAnswer = Row(0.0)
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
-  test("kurtosis") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2")
-    val expectedAnswer = Row(-1.5)
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
-  test("stddev agg") {
-    checkAnswer(
-      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
-      (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
-  }
-
-  test("variance agg") {
-    val absTol = 1e-8
-    checkAggregatesWithTol(
-      sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"),
-      (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)),
-      absTol)
-  }
-
-  test("skewness and kurtosis agg") {
-    val absTol = 1e-8
-    val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a")
-    val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0))
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
-  }
-
   test("inner join where, one match per row") {
     checkAnswer(
       sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index e12e6bea30..e2090b0a83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.Decimal
 
 
 class StringFunctionsSuite extends QueryTest with SharedSQLContext {
-- 
GitLab