From d642b273544bb77ef7f584326aa2d214649ac61b Mon Sep 17 00:00:00 2001 From: Daoyuan Wang <daoyuan.wang@intel.com> Date: Mon, 23 May 2016 23:29:15 -0700 Subject: [PATCH] [SPARK-15397][SQL] fix string udf locate as hive ## What changes were proposed in this pull request? in hive, `locate("aa", "aaa", 0)` would yield 0, `locate("aa", "aaa", 1)` would yield 1 and `locate("aa", "aaa", 2)` would yield 2, while in Spark, `locate("aa", "aaa", 0)` would yield 1, `locate("aa", "aaa", 1)` would yield 2 and `locate("aa", "aaa", 2)` would yield 0. This results from the different understanding of the third parameter in udf `locate`. It means the starting index and starts from 1, so when we use 0, the return would always be 0. ## How was this patch tested? tested with modified `StringExpressionsSuite` and `StringFunctionsSuite` Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #13186 from adrian-wang/locate. --- R/pkg/R/functions.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- python/pyspark/sql/functions.py | 2 +- .../expressions/stringExpressions.scala | 19 +++++++++++++------ .../expressions/StringExpressionsSuite.scala | 16 +++++++++------- .../spark/sql/StringFunctionsSuite.scala | 10 +++++----- 6 files changed, 30 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4a0bdf3315..2665d1d477 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2226,7 +2226,7 @@ setMethod("window", signature(x = "Column"), #' @export #' @examples \dontrun{locate('b', df$c, 1)} setMethod("locate", signature(substr = "character", str = "Column"), - function(substr, str, pos = 0) { + function(substr, str, pos = 1) { jc <- callJStatic("org.apache.spark.sql.functions", "locate", substr, str@jc, as.integer(pos)) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a99b43e5a..b2d769f2ac 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1152,7 +1152,7 @@ test_that("string operators", { l2 <- list(list(a = "aaads")) df2 <- createDataFrame(sqlContext, l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) - expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, locate("aa", df2$a, 2)))[1, 1], 2) expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1f15eec645..64b8bc442d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1359,7 +1359,7 @@ def levenshtein(left, right): @since(1.5) -def locate(substr, str, pos=0): +def locate(substr, str, pos=1): """ Locate the position of the first occurrence of substr in a string column, after position pos. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 78e846d3f5..44ff7fda8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -494,7 +494,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { - this(substr, str, Literal(0)) + this(substr, str, Literal(1)) } override def children: Seq[Expression] = substr :: str :: start :: Nil @@ -516,9 +516,14 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) if (l == null) { null } else { - l.asInstanceOf[UTF8String].indexOf( - r.asInstanceOf[UTF8String], - s.asInstanceOf[Int]) + 1 + val sVal = s.asInstanceOf[Int] + if (sVal < 1) { + 0 + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int] - 1) + 1 + } } } } @@ -537,8 +542,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) if (!${substrGen.isNull}) { ${strGen.code} if (!${strGen.isNull}) { - ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, - ${startGen.value}) + 1; + if (${startGen.value} > 0) { + ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, + ${startGen.value} - 1) + 1; + } } else { ${ev.isNull} = true; } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index c09c64fd6b..29bf15bf52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -508,16 +508,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s2 = 'b.string.at(1) val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) - val row1 = create_row("aaads", "aa", "zz", 1) - val row2 = create_row(null, "aa", "zz", 0) - val row3 = create_row("aaads", null, "zz", 0) - val row4 = create_row(null, null, null, 0) + val row1 = create_row("aaads", "aa", "zz", 2) + val row2 = create_row(null, "aa", "zz", 1) + val row3 = create_row("aaads", null, "zz", 1) + val row4 = create_row(null, null, null, 1) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) - checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) - checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(0)), 0, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(3)), 0, row1) checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) - checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 2), 0, row1) checkEvaluation(new StringLocate(s2, s1), 1, row1) checkEvaluation(StringLocate(s2, s1, s4), 2, row1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index c7b95c2683..1de2d9b5ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -189,15 +189,15 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string locate function") { - val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + val df = Seq(("aaads", "aa", "zz", 2)).toDF("a", "b", "c", "d") checkAnswer( - df.select(locate("aa", $"a"), locate("aa", $"a", 1)), - Row(1, 2)) + df.select(locate("aa", $"a"), locate("aa", $"a", 2), locate("aa", $"a", 0)), + Row(1, 2, 0)) checkAnswer( - df.selectExpr("locate(b, a)", "locate(b, a, d)"), - Row(1, 2)) + df.selectExpr("locate(b, a)", "locate(b, a, d)", "locate(b, a, 3)"), + Row(1, 2, 0)) } test("string padding functions") { -- GitLab