diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 96e8ceec6d4c3caf5341a1707697e3748cef4798..86e40a9713b36b1556d37078d525637633957082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -33,8 +33,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function sum") @@ -42,7 +41,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - case _ => child.dataType + case _: IntegralType => LongType + case _ => DoubleType } private lazy val sumDataType = resultType diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql new file mode 100644 index 0000000000000000000000000000000000000000..1f4780abde2d2a9e9ef30c3d9a4fac81b13aff13 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -0,0 +1,27 @@ +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2); +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2); + +-- Simple Union +SELECT * +FROM (SELECT * FROM t1 + UNION ALL + SELECT * FROM t1); + +-- Type Coerced Union +SELECT * +FROM (SELECT * FROM t1 + UNION ALL + SELECT * FROM t2 + UNION ALL + SELECT * FROM t2); + +-- Regression test for SPARK-18622 +SELECT a +FROM (SELECT 0 a, 0 b + UNION ALL + SELECT SUM(1) a, CAST(0 AS BIGINT) b + UNION ALL SELECT 0 a, 0 b) T; + +-- Clean-up +DROP VIEW IF EXISTS t1; +DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..c57028cabe933b1d9d75bcb8eb4352d5087e03d1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -0,0 +1,80 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * +FROM (SELECT * FROM t1 + UNION ALL + SELECT * FROM t1) +-- !query 2 schema +struct<c1:int,c2:string> +-- !query 2 output +1 a +1 a +2 b +2 b + + +-- !query 3 +SELECT * +FROM (SELECT * FROM t1 + UNION ALL + SELECT * FROM t2 + UNION ALL + SELECT * FROM t2) +-- !query 3 schema +struct<c1:decimal(11,1),c2:string> +-- !query 3 output +1 1 +1 1 +1 a +2 4 +2 4 +2 b + + +-- !query 4 +SELECT a +FROM (SELECT 0 a, 0 b + UNION ALL + SELECT SUM(1) a, CAST(0 AS BIGINT) b + UNION ALL SELECT 0 a, 0 b) T +-- !query 4 schema +struct<a:bigint> +-- !query 4 output +0 +0 +1 + + +-- !query 5 +DROP VIEW IF EXISTS t1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DROP VIEW IF EXISTS t2 +-- !query 6 schema +struct<> +-- !query 6 output +