diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b0e2ffaa60687adaec107bd8d2630b5625eb3c8a..2e679e7bc4e0ad7100f9885843792b96e52c507b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -83,13 +83,8 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("average") { checkAnswer( - testData2.agg(avg('a)), - Row(2.0)) - - // Also check mean - checkAnswer( - testData2.agg(mean('a)), - Row(2.0)) + testData2.agg(avg('a), mean('a)), + Row(2.0, 2.0)) checkAnswer( testData2.agg(avg('a), sumDistinct('a)), // non-partial @@ -98,6 +93,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) + checkAnswer( decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) @@ -168,44 +164,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - checkAnswer( emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("stddev") { - val testData2ADev = math.sqrt(4/5.0) - + val testData2ADev = math.sqrt(4 / 5.0) checkAnswer( - testData2.agg(stddev('a)), - Row(testData2ADev)) - - checkAnswer( - testData2.agg(stddev_pop('a)), - Row(math.sqrt(4/6.0))) - - checkAnswer( - testData2.agg(stddev_samp('a)), - Row(testData2ADev)) + testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() == 0) - - checkAnswer( - emptyTableData.agg(stddev('a)), - Row(null)) - checkAnswer( - emptyTableData.agg(stddev_pop('a)), - Row(null)) - - checkAnswer( - emptyTableData.agg(stddev_samp('a)), - Row(null)) + emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(null, null, null)) } test("zero sum") { @@ -227,6 +202,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) + val sparkVariancePop = testData2.agg(var_pop('a)) checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) @@ -241,52 +217,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("zero moments") { - val emptyTableData = Seq((1, 2)).toDF("a", "b") - assert(emptyTableData.count() === 1) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - + val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(0.0)) + input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + input.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) } test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) + emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + emptyTableData.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5731a356243e5efefe229a9a5555225fcdd741cd..3de277a79a52cd76dcc690ec7a731e8fb4587ae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -726,83 +726,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("stddev") { - checkAnswer( - sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4.0 / 5.0)) - ) - } - - test("stddev_pop") { - checkAnswer( - sql("SELECT STDDEV_POP(a) FROM testData2"), - Row(math.sqrt(4.0 / 6.0)) - ) - } - - test("stddev_samp") { - checkAnswer( - sql("SELECT STDDEV_SAMP(a) FROM testData2"), - Row(math.sqrt(4/5.0)) - ) - } - - test("var_samp") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 5.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("variance") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(0.8) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("var_pop") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("skewness") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT skewness(a) FROM testData2") - val expectedAnswer = Row(0.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("kurtosis") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") - val expectedAnswer = Row(-1.5) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("stddev agg") { - checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) - } - - test("variance agg") { - val absTol = 1e-8 - checkAggregatesWithTol( - sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)), - absTol) - } - - test("skewness and kurtosis agg") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0)) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e12e6bea302605c019a225068051c2c13236ed26..e2090b0a83ce7df777d034ee5d60fd863e7e4aea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.Decimal class StringFunctionsSuite extends QueryTest with SharedSQLContext {