Commit 609ba5f2, authored by Liang-Chi Hsieh, committed by Wenchen Fan

[SPARK-20399][SQL] Add a config to fall back to string literal parsing consistent with the old SQL parser behavior

## What changes were proposed in this pull request?

The new SQL parser was introduced in Spark 2.0; it unescapes all string literals, which turns out to cause an issue with regex pattern strings.

The following code reproduces it:

    val data = Seq("\u0020\u0021\u0023", "abc")
    val df = data.toDF()

    // 1st usage: works in 1.6
    // Let parser parse pattern string
    val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
    // 2nd usage: works in 1.6, 2.x
    // Call Column.rlike so the pattern string is a literal which doesn't go through parser
    val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))

    // In 2.x, we need to add backslashes to make the regex pattern parse correctly
    val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")

Following the discussion in #17736, this patch adds a config that falls back to the 1.6 string literal parsing behavior, mitigating the migration issue.
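
With the config enabled, the first usage parses as in 1.6 again. A minimal sketch of the intended effect (assuming a `SparkSession` named `spark` and the `df` defined above):

    spark.conf.set("spark.sql.parser.escapedStringLiterals", "true")
    // The 1.6-style pattern now reaches the regex engine with its backslashes intact
    val rlike4 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")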

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #17887 from viirya/add-config-fallback-string-parsing.
parent 04901dd0
Showing 171 additions and 42 deletions
@@ -73,7 +73,7 @@ class SessionCatalog(
       functionRegistry,
       conf,
       new Configuration(),
-      CatalystSqlParser,
+      new CatalystSqlParser(conf),
       DummyFunctionResourceLoader)
   }
...
@@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression
       escape character, the following character is matched literally. It is invalid to escape
       any other character.
+
+      Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order
+      to match "\abc", the pattern should be "\\abc".
+
+      When SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it falls back
+      to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+      enabled, the pattern to match "\abc" should be "\abc".
+
     Examples:
       > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'
       true
@@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpression
 }
 
 @ExpressionDescription(
-  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
+  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
+  extended = """
+    Arguments:
+      str - a string expression
+      regexp - a string expression. The pattern string should be a Java regular expression.
+
+        Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
+        parser. For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$".
+
+        There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
+        fall back to the Spark 1.6 behavior regarding string literal parsing. For example, if
+        the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
+
+    Examples:
+      When spark.sql.parser.escapedStringLiterals is disabled (default):
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*'
+      true
+
+      When spark.sql.parser.escapedStringLiterals is enabled:
+      > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*'
+      true
+
+    See also:
+      Use LIKE to match with a simple string pattern.
+  """)
 case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = v
...
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.random.RandomSampler
@@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler
  * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
  * TableIdentifier.
  */
-class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
   import ParserUtils._
 
+  def this() = this(new SQLConf())
+
   protected def typedVisit[T](ctx: ParseTree): T = {
     ctx.accept(this).asInstanceOf[T]
   }
@@ -1423,7 +1426,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    * Special characters can be escaped by using Hive/C-style escaping.
    */
   private def createString(ctx: StringLiteralContext): String = {
-    ctx.STRING().asScala.map(string).mkString
+    if (conf.escapedStringLiterals) {
+      ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+    } else {
+      ctx.STRING().asScala.map(string).mkString
+    }
   }
 
   /**
...
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
@@ -121,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
 /**
  * Concrete SQL parser for Catalyst-only SQL statements.
  */
+class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
+  val astBuilder = new AstBuilder(conf)
+}
+
+/** For test-only. */
 object CatalystSqlParser extends AbstractSqlParser {
-  val astBuilder = new AstBuilder
+  val astBuilder = new AstBuilder(new SQLConf())
 }
 
 /**
...
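
For illustration, here is how the new conf-aware parser can be constructed, mirroring the test code later in this diff (a sketch using only names introduced by this patch):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf()
    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
    val parser = new CatalystSqlParser(conf)
    // With the flag on, the backslashes in the literal survive parsing
    parser.parseExpression("'^\\x20[\\x20-\\x23]+$'")

The singleton `object CatalystSqlParser` retains a default `SQLConf`; the diff marks it as test-only.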
@@ -68,6 +68,12 @@ object ParserUtils {
   /** Convert a string node into a string. */
   def string(node: TerminalNode): String = unescapeSQLString(node.getText)
 
+  /** Convert a string node into a string without unescaping. */
+  def stringWithoutUnescape(node: TerminalNode): String = {
+    // STRING parser rule forces that the input always has quotes at the starting and ending.
+    node.getText.slice(1, node.getText.size - 1)
+  }
+
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
     val opt = Option(token)
...
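
The difference between `string` and the new `stringWithoutUnescape` is only whether escape sequences are interpreted after the surrounding quotes are stripped. A self-contained sketch (the `unescape` helper below is a hypothetical, much-simplified stand-in for Spark's `unescapeSQLString`):

    // Both paths strip the quotes that the STRING grammar rule guarantees.
    def stringWithoutUnescape(raw: String): String = raw.slice(1, raw.length - 1)

    // Hypothetical stand-in handling just one escape sequence for the demo.
    def unescape(s: String): String = s.replace("\\t", "\t")

    val token = "'pattern\\t'"               // raw token text, quotes included
    stringWithoutUnescape(token)             // => pattern\t   (backslash kept)
    unescape(stringWithoutUnescape(token))   // => pattern<TAB> (escape interpreted)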
@@ -196,6 +196,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
+    .internal()
+    .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
+      "parser. The default is false since Spark 2.0. Setting it to true can restore the " +
+      "behavior prior to Spark 2.0.")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
     .doc("When true, the Parquet data source merges schemas collected from all data files, " +
       "otherwise the schema is picked from the summary file or a random data file " +
@@ -917,6 +925,8 @@ class SQLConf extends Serializable with Logging {
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
...
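
For reference, the new entry and accessor fit together like this (a sketch grounded in the definitions above):

    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf()
    assert(!conf.escapedStringLiterals)  // default is false since Spark 2.0
    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
    assert(conf.escapedStringLiterals)   // the accessor reads the boolean entry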
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest {
   import org.apache.spark.sql.catalyst.dsl.expressions._
   import org.apache.spark.sql.catalyst.dsl.plans._
 
-  def assertEqual(sqlCommand: String, e: Expression): Unit = {
-    compareExpressions(parseExpression(sqlCommand), e)
+  val defaultParser = CatalystSqlParser
+
+  def assertEqual(
+      sqlCommand: String,
+      e: Expression,
+      parser: ParserInterface = defaultParser): Unit = {
+    compareExpressions(parser.parseExpression(sqlCommand), e)
   }
 
   def intercept(sqlCommand: String, messages: String*): Unit = {
-    val e = intercept[ParseException](parseExpression(sqlCommand))
+    val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand))
     messages.foreach { message =>
       assert(e.message.contains(message))
     }
@@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest {
   test("long binary logical expressions") {
     def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
       val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
-      val e = parseExpression(sql)
+      val e = defaultParser.parseExpression(sql)
       assert(e.collect { case _: EqualTo => true }.size === 1000)
       assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
     }
@@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest {
     assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
   }
 
+  test("like expressions with ESCAPED_STRING_LITERALS = true") {
+    val conf = new SQLConf()
+    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
+    val parser = new CatalystSqlParser(conf)
+    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+  }
+
   test("is null expressions") {
     assertEqual("a is null", 'a.isNull)
     assertEqual("a is not null", 'a.isNotNull)
@@ -418,38 +433,79 @@
   }
 
   test("strings") {
-    // Single Strings.
-    assertEqual("\"hello\"", "hello")
-    assertEqual("'hello'", "hello")
-
-    // Multi-Strings.
-    assertEqual("\"hello\" 'world'", "helloworld")
-    assertEqual("'hello' \" \" 'world'", "hello world")
-
-    // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
-    // regular '%'; to get the correct result you need to add another escaped '\'.
-    // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
-    assertEqual("'pattern%'", "pattern%")
-    assertEqual("'no-pattern\\%'", "no-pattern\\%")
-    assertEqual("'pattern\\\\%'", "pattern\\%")
-    assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
-
-    // Escaped characters.
-    // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
-    assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
-    assertEqual("'\\''", "\'") // Single quote
-    assertEqual("'\\\"'", "\"") // Double quote
-    assertEqual("'\\b'", "\b") // Backspace
-    assertEqual("'\\n'", "\n") // Newline
-    assertEqual("'\\r'", "\r") // Carriage return
-    assertEqual("'\\t'", "\t") // Tab character
-    assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
-
-    // Octals
-    assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
-
-    // Unicode
-    assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
+    Seq(true, false).foreach { escape =>
+      val conf = new SQLConf()
+      conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString)
+      val parser = new CatalystSqlParser(conf)
+
+      // tests that give the same result whatever the conf is
+      // Single Strings.
+      assertEqual("\"hello\"", "hello", parser)
+      assertEqual("'hello'", "hello", parser)
+
+      // Multi-Strings.
+      assertEqual("\"hello\" 'world'", "helloworld", parser)
+      assertEqual("'hello' \" \" 'world'", "hello world", parser)
+
+      // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+      // regular '%'; to get the correct result you need to add another escaped '\'.
+      // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+      assertEqual("'pattern%'", "pattern%", parser)
+      assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
+
+      // tests whose result depends on the conf
+      if (escape) {
+        // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing falls back
+        // to Spark 1.6 behavior.
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
+
+        // Escaped characters.
+        assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00')
+
+        // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is
+        // enabled.
+        val e = intercept[ParseException](parser.parseExpression("'\''"))
+        assert(e.message.contains("extraneous input '''"))
+
+        assertEqual("'\"'", "\"", parser) // Double quote
+        assertEqual("'\b'", "\b", parser) // Backspace
+        assertEqual("'\n'", "\n", parser) // Newline
+        assertEqual("'\r'", "\r", parser) // Carriage return
+        assertEqual("'\t'", "\t", parser) // Tab character
+
+        // Octals
+        assertEqual("'\110\145\154\154\157\041'", "Hello!", parser)
+
+        // Unicode
+        assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser)
+      } else {
+        // Default behavior
+
+        // 'LIKE' string literals.
+        assertEqual("'pattern\\\\%'", "pattern\\%", parser)
+        assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
+
+        // Escaped characters.
+        // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+        assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+        assertEqual("'\\''", "\'", parser) // Single quote
+        assertEqual("'\\\"'", "\"", parser) // Double quote
+        assertEqual("'\\b'", "\b", parser) // Backspace
+        assertEqual("'\\n'", "\n", parser) // Newline
+        assertEqual("'\\r'", "\r", parser) // Carriage return
+        assertEqual("'\\t'", "\t", parser) // Tab character
+        assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+        // Octals
+        assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+        // Unicode
+        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+          parser)
+      }
+    }
   }
 
   test("intervals") {
...
@@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
 /**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  */
-class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
+class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._
 
   /**
...
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -1168,6 +1169,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
     checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
   }
+
+  test("SPARK-20399: do not unescape regex pattern when ESCAPED_STRING_LITERALS is enabled") {
+    withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
+      val data = Seq("\u0020\u0021\u0023", "abc")
+      val df = data.toDF()
+      val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
+      val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
+      val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
+      checkAnswer(rlike1, rlike2)
+      assert(rlike3.count() == 0)
+    }
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
...