diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 01b04c036d15098a223f082b09a88324b0672f3a..6662a9e974fc2dd7580b64d7fcd9455735ec130f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -528,8 +528,6 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) - - case e: RuntimeReplaceable => e.replaceForTypeCoercion() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 726a231fd814ef0017dd1e9e7b88d4febae7ef7b..221f830aa8583243a2ceedc80006d786b889d91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -186,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyName: String = nodeName.toLowerCase - protected def flatArguments = productIterator.flatMap { + protected def flatArguments: Iterator[Any] = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } @@ -229,26 +229,16 @@ trait Unevaluable extends Expression { * An expression that gets replaced at runtime (currently by the optimizer) into a different * expression for evaluation. This is mainly used to provide compatibility with other databases. * For example, we use this to support "nvl" by replacing it with "coalesce". + * + * A RuntimeReplaceable should have the original parameters along with a "child" expression in the + * case class constructor, and define a normal constructor that accepts only the original + * parameters. For an example, see [[Nvl]]. To make sure the explain plan and expression SQL + * works correctly, the implementation should also override flatArguments method and sql method. */ -trait RuntimeReplaceable extends Unevaluable { - /** - * Method for concrete implementations to override that specifies how to construct the expression - * that should replace the current one. - */ - def replaceForEvaluation(): Expression - - /** - * Method for concrete implementations to override that specifies how to coerce the input types. - */ - def replaceForTypeCoercion(): Expression - - /** The expression that should be used during evaluation. */ - lazy val replaced: Expression = replaceForEvaluation() - - override def nullable: Boolean = replaced.nullable - override def foldable: Boolean = replaced.foldable - override def dataType: DataType = replaced.dataType - override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes() +trait RuntimeReplaceable extends UnaryExpression with Unevaluable { + override def nullable: Boolean = child.nullable + override def foldable: Boolean = child.foldable + override def dataType: DataType = child.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 67c078ae5e264ecf38a3f41d397489b806cd25df..05bfa7dcfc88f9edeab8836ec755fc52751dced2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -488,8 +488,6 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { }""") } } - - override def prettyName: String = "unix_time" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 1c18265e0fed4bef387796b1338f2b2db5270d0b..70862a87ef9c6ef649dada22c2ce48db06e8019e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -89,78 +89,53 @@ case class Coalesce(children: Seq[Expression]) extends Expression { @ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") -case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable { - override def children: Seq[Expression] = Seq(left, right) - - override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) +case class IfNull(left: Expression, right: Expression, child: Expression) + extends RuntimeReplaceable { - override def replaceForTypeCoercion(): Expression = { - if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => - copy(left = Cast(left, dtype), right = Cast(right, dtype)) - }.getOrElse(this) - } else { - this - } + def this(left: Expression, right: Expression) = { + this(left, right, Coalesce(Seq(left, right))) } + + override def flatArguments: Iterator[Any] = Iterator(left, right) + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.") -case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable { - override def children: Seq[Expression] = Seq(left, right) +case class NullIf(left: Expression, right: Expression, child: Expression) + extends RuntimeReplaceable { - override def replaceForEvaluation(): Expression = { - If(EqualTo(left, right), Literal.create(null, left.dataType), left) + def this(left: Expression, right: Expression) = { + this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left)) } - override def replaceForTypeCoercion(): Expression = { - if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => - copy(left = Cast(left, dtype), right = Cast(right, dtype)) - }.getOrElse(this) - } else { - this - } - } + override def flatArguments: Iterator[Any] = Iterator(left, right) + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") -case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable { - override def children: Seq[Expression] = Seq(left, right) +case class Nvl(left: Expression, right: Expression, child: Expression) extends RuntimeReplaceable { - override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) - - override def replaceForTypeCoercion(): Expression = { - if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeToString(left.dataType, right.dataType).map { dtype => - copy(left = Cast(left, dtype), right = Cast(right, dtype)) - }.getOrElse(this) - } else { - this - } + def this(left: Expression, right: Expression) = { + this(left, right, Coalesce(Seq(left, right))) } + + override def flatArguments: Iterator[Any] = Iterator(left, right) + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.") -case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression) +case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: Expression) extends RuntimeReplaceable { - override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3) - - override def children: Seq[Expression] = Seq(expr1, expr2, expr3) - - override def replaceForTypeCoercion(): Expression = { - if (expr2.dataType != expr3.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype => - copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype)) - }.getOrElse(this) - } else { - this - } + def this(expr1: Expression, expr2: Expression, expr3: Expression) = { + this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3)) } + + override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3) + override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 7c667315870f5e562bccabb9a93db3cfcc0a39f0..f20eb958fe973400d9cdcc7921dbc6b501c91fce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ */ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e: RuntimeReplaceable => e.replaced + case e: RuntimeReplaceable => e.child } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index e7363799306192a199f494689831232130e590d1..62c9ab3b67fb6e5e2fd767674c451b2f6ca91fd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -86,18 +88,23 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-16602 Nvl should support numeric-string cases") { + def analyze(expr: Expression): Expression = { + val relation = LocalRelation() + SimpleAnalyzer.execute(Project(Seq(Alias(expr, "c")()), relation)).expressions.head + } + val intLit = Literal.create(1, IntegerType) val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - assert(Nvl(intLit, doubleLit).replaceForTypeCoercion().dataType == DoubleType) - assert(Nvl(intLit, stringLit).replaceForTypeCoercion().dataType == StringType) - assert(Nvl(stringLit, doubleLit).replaceForTypeCoercion().dataType == StringType) + assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) - assert(Nvl(nullLit, intLit).replaceForTypeCoercion().dataType == IntegerType) - assert(Nvl(doubleLit, nullLit).replaceForTypeCoercion().dataType == DoubleType) - assert(Nvl(nullLit, stringLit).replaceForTypeCoercion().dataType == StringType) + assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) + assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) + assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) } test("AtLeastNNonNulls") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql new file mode 100644 index 0000000000000000000000000000000000000000..2b5b692d29ef4f75ad3dedf89c5430d15bd45100 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -0,0 +1,25 @@ +-- A test suite for functions added for compatibility with other databases such as Oracle, MSSQL. +-- These functions are typically implemented using the trait RuntimeReplaceable. + +SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null); +SELECT nullif('x', 'x'), nullif('x', 'y'); +SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null); +SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null); + +-- type coercion +SELECT ifnull(1, 2.1d), ifnull(null, 2.1d); +SELECT nullif(1, 2.1d), nullif(1, 1.0d); +SELECT nvl(1, 2.1d), nvl(null, 2.1d); +SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d); + +-- explain for these functions; use range to avoid constant folding +explain extended +select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') +from range(2); + +-- SPARK-16730 cast alias functions for Hive compatibility +SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1); +SELECT float(1), double(1), decimal(1); +SELECT date("2014-04-04"), timestamp(date("2014-04-04")); +-- error handling: only one argument +SELECT string(1, 2); diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 499a3d5fb72f6171ad482bd5ee7fd8d0dc50d430..981b2504bcaad05c6daa5f35df9833e620aca74f 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 12 -- !query 0 @@ -124,6 +124,7 @@ struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, -- !query 8 output [true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0] + -- !query 9 select sort_array(array('b', 'd'), '1') -- !query 9 schema @@ -132,6 +133,7 @@ struct<> org.apache.spark.sql.AnalysisException cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + -- !query 10 select sort_array(array('b', 'd'), cast(NULL as boolean)) -- !query 10 schema @@ -140,6 +142,7 @@ struct<> org.apache.spark.sql.AnalysisException cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + -- !query 11 select size(boolean_array), diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..9f0b95994be53873ba4b4ac779592aba884c2a1d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -0,0 +1,124 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null) +-- !query 0 schema +struct<ifnull(NULL, 'x'):string,ifnull('y', 'x'):string,ifnull(NULL, NULL):null> +-- !query 0 output +x y NULL + + +-- !query 1 +SELECT nullif('x', 'x'), nullif('x', 'y') +-- !query 1 schema +struct<nullif('x', 'x'):string,nullif('x', 'y'):string> +-- !query 1 output +NULL x + + +-- !query 2 +SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null) +-- !query 2 schema +struct<nvl(NULL, 'x'):string,nvl('y', 'x'):string,nvl(NULL, NULL):null> +-- !query 2 output +x y NULL + + +-- !query 3 +SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null) +-- !query 3 schema +struct<nvl2(NULL, 'x', 'y'):string,nvl2('n', 'x', 'y'):string,nvl2(NULL, NULL, NULL):null> +-- !query 3 output +y x NULL + + +-- !query 4 +SELECT ifnull(1, 2.1d), ifnull(null, 2.1d) +-- !query 4 schema +struct<ifnull(1, 2.1D):double,ifnull(NULL, 2.1D):double> +-- !query 4 output +1.0 2.1 + + +-- !query 5 +SELECT nullif(1, 2.1d), nullif(1, 1.0d) +-- !query 5 schema +struct<nullif(1, 2.1D):int,nullif(1, 1.0D):int> +-- !query 5 output +1 NULL + + +-- !query 6 +SELECT nvl(1, 2.1d), nvl(null, 2.1d) +-- !query 6 schema +struct<nvl(1, 2.1D):double,nvl(NULL, 2.1D):double> +-- !query 6 output +1.0 2.1 + + +-- !query 7 +SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d) +-- !query 7 schema +struct<nvl2(NULL, 1, 2.1D):double,nvl2('n', 1, 2.1D):double> +-- !query 7 output +2.1 1.0 + + +-- !query 8 +explain extended +select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') +from range(2) +-- !query 8 schema +struct<plan:string> +-- !query 8 output +== Parsed Logical Plan == +'Project [unresolvedalias('ifnull('id, x), None), unresolvedalias('nullif('id, x), None), unresolvedalias('nvl('id, x), None), unresolvedalias('nvl2('id, x, y), None)] ++- 'UnresolvedTableValuedFunction range, [2] + +== Analyzed Logical Plan == +ifnull(`id`, 'x'): string, nullif(`id`, 'x'): bigint, nvl(`id`, 'x'): string, nvl2(`id`, 'x', 'y'): string +Project [ifnull(id#xL, x) AS ifnull(`id`, 'x')#x, nullif(id#xL, x) AS nullif(`id`, 'x')#xL, nvl(id#xL, x) AS nvl(`id`, 'x')#x, nvl2(id#xL, x, y) AS nvl2(`id`, 'x', 'y')#x] ++- Range (0, 2, step=1, splits=None) + +== Optimized Logical Plan == +Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] ++- Range (0, 2, step=1, splits=None) + +== Physical Plan == +*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] ++- *Range (0, 2, step=1, splits=None) + + +-- !query 9 +SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1) +-- !query 9 schema +struct<CAST(1 AS BOOLEAN):boolean,CAST(1 AS TINYINT):tinyint,CAST(1 AS SMALLINT):smallint,CAST(1 AS INT):int,CAST(1 AS BIGINT):bigint> +-- !query 9 output +true 1 1 1 1 + + +-- !query 10 +SELECT float(1), double(1), decimal(1) +-- !query 10 schema +struct<CAST(1 AS FLOAT):float,CAST(1 AS DOUBLE):double,CAST(1 AS DECIMAL(10,0)):decimal(10,0)> +-- !query 10 output +1.0 1.0 1 + + +-- !query 11 +SELECT date("2014-04-04"), timestamp(date("2014-04-04")) +-- !query 11 schema +struct<CAST(2014-04-04 AS DATE):date,CAST(CAST(2014-04-04 AS DATE) AS TIMESTAMP):timestamp> +-- !query 11 output +2014-04-04 2014-04-04 00:00:00 + + +-- !query 12 +SELECT string(1, 2) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Function string accepts only one argument; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala deleted file mode 100644 index 27b60e0d9def8a28bfe91b2e91c898c0fe040804..0000000000000000000000000000000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.math.BigDecimal -import java.sql.Timestamp - -import org.apache.spark.sql.test.SharedSQLContext - -/** - * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL. - * - * These functions are typically implemented using the trait - * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]]. - */ -class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext { - - test("ifnull") { - checkAnswer( - sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"), - Row("x", "y", null)) - - // Type coercion - checkAnswer( - sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"), - Row(1.0, 2.1)) - } - - test("nullif") { - checkAnswer( - sql("SELECT nullif('x', 'x'), nullif('x', 'y')"), - Row(null, "x")) - - // Type coercion - checkAnswer( - sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"), - Row(1.0, null)) - } - - test("nvl") { - checkAnswer( - sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), - Row("x", "y", null)) - - // Type coercion - checkAnswer( - sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"), - Row(1.0, 2.1)) - } - - test("nvl2") { - checkAnswer( - sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"), - Row("y", "x", null)) - - // Type coercion - checkAnswer( - sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"), - Row(2.1, 1.0)) - } - - test("SPARK-16730 cast alias functions for Hive compatibility") { - checkAnswer( - sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"), - Row(true, 1.toByte, 1.toShort, 1, 1L)) - - checkAnswer( - sql("SELECT float(1), double(1), decimal(1)"), - Row(1.toFloat, 1.0, new BigDecimal(1))) - - checkAnswer( - sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"), - Row(new java.util.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0))) - - checkAnswer( - sql("SELECT string(1)"), - Row("1")) - - // Error handling: only one argument - val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage - assert(errorMsg.contains("Function string accepts only one argument")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 2d73d9f1fc802066253448b807e55bc10e0c1fc6..1a4049fb339cb64c72e7a4a0734be7f1bd9f7864 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} -import org.apache.spark.sql.execution.command.ShowColumnsCommand import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -215,7 +214,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { try { val df = session.sql(sql) val schema = df.schema - val answer = df.queryExecution.hiveResultString() + // Get answer, but also get rid of the #1234 expression ids that show up in explain plans + val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x")) // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)