diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 18e514681e8119adc76576f706f155230ddb40c3..f6653d384fe1dd9f705029b3466c1a3d3cfb6705 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -73,7 +73,7 @@ class SessionCatalog(
       functionRegistry,
       conf,
       new Configuration(),
-      CatalystSqlParser,
+      new CatalystSqlParser(conf),
       DummyFunctionResourceLoader)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 3fa84589e3c6886365d30e09266436b36c70f235..aa5a1b5448c6dfee612171baeb0dde517bb5586f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression
         escape character, the following character is matched literally. It is invalid to escape
         any other character.
 
+      Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order
+      to match "\abc", the pattern should be "\\abc".
+
+      When the SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it falls back
+      to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+      enabled, the pattern to match "\abc" should be "\abc".
+
     Examples:
       > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'
       true
@@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
 }
 
 @ExpressionDescription(
-  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
+  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
+  extended = """
+    Arguments:
+      str - a string expression
+      regexp - a string expression. The pattern string should be a Java regular expression.
+
+        Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
+        parser. For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$".
+
+        There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to fall
+        back to the Spark 1.6 behavior regarding string literal parsing. For example, if the
+        config is enabled, the `regexp` that can match "\abc" is "^\abc$".
+
+    Examples:
+      When spark.sql.parser.escapedStringLiterals is disabled (default):
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*'
+      true
+
+      When spark.sql.parser.escapedStringLiterals is enabled:
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*'
+      true
+
+    See also:
+      Use LIKE to match with a simple string pattern.
+""") case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 046ea65d454a1ceecde6964b0aa34b0013c44cfd..4b11b6f8d2cf0e9de7ec65db5b9cff414c7c54d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ + def this() = this(new SQLConf()) + protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -1423,7 +1426,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - ctx.STRING().asScala.map(string).mkString + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index dcccbd0ed8d6b5efb00fd27e961bb7b2e4932368..8e2e973485e1c467bd86ebb1f5903e2a144b3dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new AstBuilder(conf) +} + +/** For test-only. 
 object CatalystSqlParser extends AbstractSqlParser {
-  val astBuilder = new AstBuilder
+  val astBuilder = new AstBuilder(new SQLConf())
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 6fbc33fad735c935430666586e5184267c12a1d2..77fdaa8255aa650bfb2199fd03e88af6d5d8c9ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -68,6 +68,12 @@ object ParserUtils {
   /** Convert a string node into a string. */
   def string(node: TerminalNode): String = unescapeSQLString(node.getText)
 
+  /** Convert a string node into a string without unescaping. */
+  def stringWithoutUnescape(node: TerminalNode): String = {
+    // The STRING parser rule guarantees that the input always has quotes at the start and end.
+    node.getText.slice(1, node.getText.size - 1)
+  }
+
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
     val opt = Option(token)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b24419a41edb0e903fe5cb0ea2ddb93a9ae82e01..b97adf7221d1865df1e1a4cac2069c25e070176b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -196,6 +196,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
+    .internal()
+    .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
+      "parser. The default has been false since Spark 2.0. Setting it to true restores the " +
+      "string literal parsing behavior of versions prior to Spark 2.0.")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
     .doc("When true, the Parquet data source merges schemas collected from all data files, " +
       "otherwise the schema is picked from the summary file or a random data file " +
@@ -917,6 +925,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
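Taken together, the changes above (the config key, the `escapedStringLiterals` accessor, and the conf-aware `CatalystSqlParser`) let a caller opt in to 1.6-style literal handling per parser instance. Below is a minimal sketch of the intended behavior, assuming the patched classes are on the classpath; the object name is illustrative and not part of the patch.

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf

// Hypothetical demo object, not part of this patch.
object EscapedLiteralsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SQLConf()
    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")

    // With the flag on, createString uses stringWithoutUnescape, so the
    // backslash in the SQL literal '\x20' survives parsing.
    val parser = new CatalystSqlParser(conf)
    println(parser.parseExpression("'\\x20'"))

    // The companion object keeps the default post-2.0 behavior, where
    // unescapeSQLString processes the literal and consumes the backslash.
    println(CatalystSqlParser.parseExpression("'\\x20'"))
  }
}
```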
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index eb68eb9851b8537c2cd2dd7a046448921cc44674..8bc2010cabeced24607e1121b3a12a24e1ce6391 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest {
   import org.apache.spark.sql.catalyst.dsl.expressions._
   import org.apache.spark.sql.catalyst.dsl.plans._
 
-  def assertEqual(sqlCommand: String, e: Expression): Unit = {
-    compareExpressions(parseExpression(sqlCommand), e)
+  val defaultParser = CatalystSqlParser
+
+  def assertEqual(
+      sqlCommand: String,
+      e: Expression,
+      parser: ParserInterface = defaultParser): Unit = {
+    compareExpressions(parser.parseExpression(sqlCommand), e)
   }
 
   def intercept(sqlCommand: String, messages: String*): Unit = {
-    val e = intercept[ParseException](parseExpression(sqlCommand))
+    val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand))
     messages.foreach { message =>
       assert(e.message.contains(message))
     }
@@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest {
   test("long binary logical expressions") {
     def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
       val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
-      val e = parseExpression(sql)
+      val e = defaultParser.parseExpression(sql)
       assert(e.collect { case _: EqualTo => true }.size === 1000)
       assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
     }
@@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest {
     assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
   }
 
+  test("like expressions with ESCAPED_STRING_LITERALS = true") {
+    val conf = new SQLConf()
+    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
+    val parser = new CatalystSqlParser(conf)
+    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+  }
+
   test("is null expressions") {
     assertEqual("a is null", 'a.isNull)
     assertEqual("a is not null", 'a.isNotNull)
@@ -418,38 +433,79 @@ class ExpressionParserSuite extends PlanTest {
   }
 
   test("strings") {
-    // Single Strings.
-    assertEqual("\"hello\"", "hello")
-    assertEqual("'hello'", "hello")
-
-    // Multi-Strings.
-    assertEqual("\"hello\" 'world'", "helloworld")
-    assertEqual("'hello' \" \" 'world'", "hello world")
-
-    // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
-    // regular '%'; to get the correct result you need to add another escaped '\'.
-    // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
- assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") + Seq(true, false).foreach { escape => + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) + val parser = new CatalystSqlParser(conf) + + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00') + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + assertEqual("'\"'", "\"", parser) // Double quote + assertEqual("'\b'", "\b", parser) // Backspace + assertEqual("'\n'", "\n", parser) // Newline + assertEqual("'\r'", "\r", parser) // Carriage return + assertEqual("'\t'", "\t", parser) // Tab character + + // Octals + assertEqual("'\110\145\154\154\157\041'", "Hello!", parser) + // Unicode + assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. 
+        // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+        assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+        assertEqual("'\\''", "\'", parser) // Single quote
+        assertEqual("'\\\"'", "\"", parser) // Double quote
+        assertEqual("'\\b'", "\b", parser) // Backspace
+        assertEqual("'\\n'", "\n", parser) // Newline
+        assertEqual("'\\r'", "\r", parser) // Carriage return
+        assertEqual("'\\t'", "\t", parser) // Tab character
+        assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+        // Octals
+        assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+        // Unicode
+        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+          parser)
+      }
+
+    }
   }
 
   test("intervals") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 20dacf88504f19bfb2b3acc5b0b38a4581d49993..c2c52894860b544d93e02c3a9edf298e32e7a4f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
 /**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  */
-class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
+class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 5b5cd28ad0c99af2e58e0800f9209cf1355d8a2d..8eb381b91f46df8192bd65c54ea927f41716e90f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -1168,6 +1169,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
     checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
   }
+
+  test("SPARK-20399: do not unescape regex pattern when ESCAPED_STRING_LITERALS is enabled") {
+    withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
+      val data = Seq("\u0020\u0021\u0023", "abc")
+      val df = data.toDF()
+      val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
+      val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
+      val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
+      checkAnswer(rlike1, rlike2)
+      assert(rlike3.count() == 0)
+    }
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
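At the SQL surface, the DatasetSuite test above boils down to the following usage pattern. This is a hedged sketch, assuming a Spark build with this patch applied; the session setup and object name are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object, not part of this patch.
object RlikeFallbackDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("escaped-string-literals-demo")
      .getOrCreate()
    import spark.implicits._

    // Opt in to Spark 1.6-style string literal parsing.
    spark.conf.set("spark.sql.parser.escapedStringLiterals", "true")

    val df = Seq("\u0020\u0021\u0023", "abc").toDF("value")

    // The pattern reaches the regex engine as ^\x20[\x20-\x23]+$ with its
    // backslashes intact, so the first row (" !#") matches.
    df.filter("value rlike '^\\x20[\\x20-\\x23]+$'").show()

    spark.stop()
  }
}
```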