diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 65168998c8aee227f736d377b7014b9235e27644..c5f91c15905429be118feec9d37da33cfb0eb240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap +import org.apache.spark.sql.types._ /** @@ -408,8 +409,21 @@ object FunctionRegistry { expression[BitwiseAnd]("&"), expression[BitwiseNot]("~"), expression[BitwiseOr]("|"), - expression[BitwiseXor]("^") - + expression[BitwiseXor]("^"), + + // Cast aliases (SPARK-16730) + castAlias("boolean", BooleanType), + castAlias("tinyint", ByteType), + castAlias("smallint", ShortType), + castAlias("int", IntegerType), + castAlias("bigint", LongType), + castAlias("float", FloatType), + castAlias("double", DoubleType), + castAlias("decimal", DecimalType.USER_DEFAULT), + castAlias("date", DateType), + castAlias("timestamp", TimestampType), + castAlias("binary", BinaryType), + castAlias("string", StringType) ) val builtin: SimpleFunctionRegistry = { @@ -452,14 +466,37 @@ object FunctionRegistry { } } - val clazz = tag.runtimeClass + (name, (expressionInfo[T](name), builder)) + } + + /** + * Creates a function registry lookup entry for cast aliases (SPARK-16730). + * For example, if name is "int", and dataType is IntegerType, this means int(x) would become + * an alias for cast(x as IntegerType). + * See usage above. + */ + private def castAlias( + name: String, + dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = { + val builder = (args: Seq[Expression]) => { + if (args.size != 1) { + throw new AnalysisException(s"Function $name accepts only one argument") + } + Cast(args.head, dataType) + } + (name, (expressionInfo[Cast](name), builder)) + } + + /** + * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name. + */ + private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = { + val clazz = scala.reflect.classTag[T].runtimeClass val df = clazz.getAnnotation(classOf[ExpressionDescription]) if (df != null) { - (name, - (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()), - builder)) + new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()) } else { - (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder)) + new ExpressionInfo(clazz.getCanonicalName, name) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a12fba047b3d4df4c85e3c2f78ca51f17ae53ec0..c452765af2dd9e67f1e505b89e6f0b54eda892c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -113,6 +113,9 @@ object Cast { } /** Cast the child expression to the target data type. */ +@ExpressionDescription( + usage = " - Cast value v to the target data type.", + extended = "> SELECT _FUNC_('10' as int);\n 10") case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala index 1e3239550fb819504d6d0ef840f9d6547055143e..27b60e0d9def8a28bfe91b2e91c898c0fe040804 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql +import java.math.BigDecimal +import java.sql.Timestamp + import org.apache.spark.sql.test.SharedSQLContext /** * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL. + * * These functions are typically implemented using the trait * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]]. */ @@ -69,4 +73,26 @@ class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext { sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"), Row(2.1, 1.0)) } + + test("SPARK-16730 cast alias functions for Hive compatibility") { + checkAnswer( + sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"), + Row(true, 1.toByte, 1.toShort, 1, 1L)) + + checkAnswer( + sql("SELECT float(1), double(1), decimal(1)"), + Row(1.toFloat, 1.0, new BigDecimal(1))) + + checkAnswer( + sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"), + Row(new java.util.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0))) + + checkAnswer( + sql("SELECT string(1)"), + Row("1")) + + // Error handling: only one argument + val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage + assert(errorMsg.contains("Function string accepts only one argument")) + } }