From fe767395ff46ee6236cf53aece85fcd61c0b49d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B2=91=E7=8E=89=E6=B5=B7?= <261810726@qq.com> Date: Thu, 15 Sep 2016 20:45:00 +0200 Subject: [PATCH] [SPARK-17429][SQL] use ImplicitCastInputTypes with function Length MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The queries `select length(11);` and `select length(2.0);` currently fail with an analysis error in Spark SQL, although Hive accepts them. This PR makes the function `length` cast its input type implicitly, so that non-string arguments are accepted. The expected results are: `select length(11)` returns 2; `select length(2.0)` returns 3. Author: 岑玉海 <261810726@qq.com> Author: cenyuhai <cenyuhai@didichuxing.com> Closes #15014 from cenyuhai/SPARK-17429. --- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../org/apache/spark/sql/StringFunctionsSuite.scala | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index a8c23a8b0c..1bcbb6cfc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1057,7 +1057,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @ExpressionDescription( usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", extended = "> SELECT _FUNC_('Spark SQL');\n 9") -case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 1cc77464b9..bcc2351049 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -330,7 +330,8 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string / binary length function") { - val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015)) + .toDF("a", "b", "c", "d", "e") checkAnswer( df.select(length($"a"), length($"b")), Row(3, 4)) @@ -339,9 +340,10 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("length(a)", "length(b)"), Row(3, 4)) - intercept[AnalysisException] { - df.selectExpr("length(c)") // int type of the argument is unacceptable - } + checkAnswer( + df.selectExpr("length(c)", "length(d)", "length(e)"), + Row(3, 3, 5) + ) } test("initcap function") { -- GitLab