From 2cbfc975ba937a4eb761de7a6473b7747941f386 Mon Sep 17 00:00:00 2001 From: Jane Wang <janewang@fb.com> Date: Tue, 11 Jul 2017 22:00:36 -0700 Subject: [PATCH] [SPARK-12139][SQL] REGEX Column Specification ## What changes were proposed in this pull request? Hive interprets regular expression, e.g., `(a)?+.+` in query specification. This PR enables spark to support this feature when hive.support.quoted.identifiers is set to true. ## How was this patch tested? - Add unittests in SQLQuerySuite.scala - Run spark-shell tested the original failed query: scala> hc.sql("SELECT `(a|b)?+.+` from test1").collect.foreach(println) Author: Jane Wang <janewang@fb.com> Closes #18023 from janewangfb/support_select_regex. --- .../sql/catalyst/analysis/unresolved.scala | 29 +- .../sql/catalyst/parser/AstBuilder.scala | 43 ++- .../sql/catalyst/parser/ParserUtils.scala | 6 + .../apache/spark/sql/internal/SQLConf.scala | 8 + .../scala/org/apache/spark/sql/Dataset.scala | 27 +- .../sql-tests/inputs/query_regex_column.sql | 52 +++ .../results/query_regex_column.sql.out | 313 ++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 10 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 225 +++++++------ .../apache/spark/sql/DataFrameStatSuite.scala | 87 ++--- .../org/apache/spark/sql/DataFrameSuite.scala | 55 +-- .../org/apache/spark/sql/DatasetSuite.scala | 81 ++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 41 ++- .../datasources/json/JsonSuite.scala | 38 ++- .../parquet/ParquetFilterSuite.scala | 3 +- .../spark/sql/sources/DataSourceTest.scala | 7 +- .../spark/sql/sources/TableScanSuite.scala | 55 +-- 17 files changed, 825 insertions(+), 255 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 42b9641bef..fb322697c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -123,7 +124,10 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def toString: String = s"'$name" - override def sql: String = quoteIdentifier(name) + override def sql: String = name match { + case ParserUtils.escapedIdentifier(_) | ParserUtils.qualifiedEscapedIdentifier(_, _) => name + case _ => quoteIdentifier(name) + } } object UnresolvedAttribute { @@ -306,6 +310,29 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu override def toString: String = target.map(_ + ".").getOrElse("") + "*" } +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT `(id)?+.+` FROM ...". 
+ * + * @param table an optional table that should be the target of the expansion. If omitted all + * tables' columns are produced. + */ +case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSensitive: Boolean) + extends Star with Unevaluable { + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { + val pattern = if (caseSensitive) regexPattern else s"(?i)$regexPattern" + table match { + // If there is no table specified, use all input attributes that match expr + case None => input.output.filter(_.name.matches(pattern)) + // If there is a table, pick out attributes that are part of this table that match expr + case Some(t) => input.output.filter(_.qualifier.exists(resolver(_, t))) + .filter(_.name.matches(pattern)) + } + } + + override def toString: String = table.map(_ + "." + regexPattern).getOrElse(regexPattern) +} + /** * Used to assign new names to Generator's output, such as hive udtf. * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a616b0f773..ad359e714b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1261,25 +1261,54 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create a dereference expression. The return type depends on the type of the parent, this can - * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an - * [[UnresolvedExtractValue]] if the parent is some expression. + * Currently only regex in expressions of SELECT statements are supported; in other + * places, e.g., where `(a)?+.+` = 2, regex are not meaningful. + */ + private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) { + var parent = ctx.getParent + while (parent != null) { + if (parent.isInstanceOf[NamedExpressionContext]) return true + parent = parent.getParent + } + return false + } + + /** + * Create a dereference expression. The return type depends on the type of the parent. + * If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or + * a [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression, + * it can be [[UnresolvedExtractValue]]. */ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { val attr = ctx.fieldName.getText expression(ctx.base) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ attr) + case unresolved_attr @ UnresolvedAttribute(nameParts) => + ctx.fieldName.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name), + conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute(nameParts :+ attr) + } case e => UnresolvedExtractValue(e, Literal(attr)) } } /** - * Create an [[UnresolvedAttribute]] expression. 
+ * Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex + * quoted in `` */ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { - UnresolvedAttribute.quoted(ctx.getText) + ctx.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute.quoted(ctx.getText) + } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 77fdaa8255..9c1031e803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -177,6 +177,12 @@ object ParserUtils { sb.toString() } + /** the column name pattern in quoted regex without qualifier */ + val escapedIdentifier = "`(.+)`".r + + /** the column name pattern in quoted regex with qualifier */ + val qualifiedEscapedIdentifier = ("(.+)" + """.""" + "`(.+)`").r + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 643587a6eb..55558ca9f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,6 +855,12 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") + .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + + " as regular expressions.") + .booleanConf + .createWithDefault(false) + val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enable") .internal() @@ -1133,6 +1139,8 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7f3ae05411..b825b6cd61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions -import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} @@ -1178,8 +1178,29 @@ class Dataset[T] private[sql]( case "*" => 
Column(ResolvedStar(queryExecution.analyzed.output)) case _ => - val expr = resolve(colName) - Column(expr) + if (sqlContext.conf.supportQuotedRegexColumnName) { + colRegex(colName) + } else { + val expr = resolve(colName) + Column(expr) + } + } + + /** + * Selects column based on the column name specified as a regex and return it as [[Column]]. + * @group untypedrel + * @since 2.3.0 + */ + def colRegex(colName: String): Column = { + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + colName match { + case ParserUtils.escapedIdentifier(columnNameRegex) => + Column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) + case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => + Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) + case _ => + Column(resolve(colName)) + } } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql b/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql new file mode 100644 index 0000000000..ad96754826 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/query_regex_column.sql @@ -0,0 +1,52 @@ +set spark.sql.parser.quotedRegexColumnNames=false; + +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, "1", "11"), (2, "2", "22"), (3, "3", "33"), (4, "4", "44"), (5, "5", "55"), (6, "6", "66") +AS testData(key, value1, value2); + +CREATE OR REPLACE TEMPORARY VIEW testData2 AS SELECT * FROM VALUES +(1, 1, 1, 2), (1, 2, 1, 2), (2, 1, 2, 3), (2, 2, 2, 3), (3, 1, 3, 4), (3, 2, 3, 4) +AS testData2(A, B, c, d); + +-- AnalysisException +SELECT `(a)?+.+` FROM testData2 WHERE a = 1; +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1; +SELECT `(a|b)` FROM testData2 WHERE a = 2; +SELECT `(a|b)?+.+` FROM testData2 WHERE a = 2; +SELECT SUM(`(a|b)?+.+`) FROM testData2; +SELECT SUM(`(a)`) FROM testData2; + +set spark.sql.parser.quotedRegexColumnNames=true; + +-- Regex columns +SELECT `(a)?+.+` FROM testData2 WHERE a = 1; +SELECT `(A)?+.+` FROM testData2 WHERE a = 1; +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1; +SELECT t.`(A)?+.+` FROM testData2 t WHERE a = 1; +SELECT `(a|B)` FROM testData2 WHERE a = 2; +SELECT `(A|b)` FROM testData2 WHERE a = 2; +SELECT `(a|B)?+.+` FROM testData2 WHERE a = 2; +SELECT `(A|b)?+.+` FROM testData2 WHERE a = 2; +SELECT `(e|f)` FROM testData2; +SELECT t.`(e|f)` FROM testData2 t; +SELECT p.`(KEY)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3; +SELECT p.`(key)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3; + +set spark.sql.caseSensitive=true; + +CREATE OR REPLACE TEMPORARY VIEW testdata3 AS SELECT * FROM VALUES +(0, 1), (1, 2), (2, 3), (3, 4) +AS testdata3(a, b); + +-- Regex columns +SELECT `(A)?+.+` FROM testdata3; +SELECT `(a)?+.+` FROM testdata3; +SELECT `(A)?+.+` FROM testdata3 WHERE a > 1; +SELECT `(a)?+.+` FROM testdata3 where `a` > 1; +SELECT SUM(`a`) FROM testdata3; +SELECT SUM(`(a)`) FROM testdata3; +SELECT SUM(`(a)?+.+`) FROM testdata3; +SELECT SUM(a) FROM testdata3 GROUP BY `a`; +-- AnalysisException +SELECT SUM(a) FROM testdata3 GROUP BY `(a)`; +SELECT SUM(a) FROM testdata3 GROUP BY `(a)?+.+`; diff --git a/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out b/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out new file mode 100644 index 0000000000..2dade86f35 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/query_regex_column.sql.out @@ -0,0 +1,313 
@@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 34 + + +-- !query 0 +set spark.sql.parser.quotedRegexColumnNames=false +-- !query 0 schema +struct<key:string,value:string> +-- !query 0 output +spark.sql.parser.quotedRegexColumnNames false + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, "1", "11"), (2, "2", "22"), (3, "3", "33"), (4, "4", "44"), (5, "5", "55"), (6, "6", "66") +AS testData(key, value1, value2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE OR REPLACE TEMPORARY VIEW testData2 AS SELECT * FROM VALUES +(1, 1, 1, 2), (1, 2, 1, 2), (2, 1, 2, 3), (2, 2, 2, 3), (3, 1, 3, 4), (3, 2, 3, 4) +AS testData2(A, B, c, d) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT `(a)?+.+` FROM testData2 WHERE a = 1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 4 +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 't.`(a)?+.+`' given input columns: [t.A, t.B, t.c, t.d]; line 1 pos 7 + + +-- !query 5 +SELECT `(a|b)` FROM testData2 WHERE a = 2 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 6 +SELECT `(a|b)?+.+` FROM testData2 WHERE a = 2 +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 7 + + +-- !query 7 +SELECT SUM(`(a|b)?+.+`) FROM testData2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a|b)?+.+`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 11 + + +-- !query 8 +SELECT SUM(`(a)`) FROM testData2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)`' given input columns: [testdata2.A, testdata2.B, testdata2.c, testdata2.d]; line 1 pos 11 + + +-- !query 9 +set spark.sql.parser.quotedRegexColumnNames=true +-- !query 9 schema +struct<key:string,value:string> +-- !query 9 output +spark.sql.parser.quotedRegexColumnNames true + + +-- !query 10 +SELECT `(a)?+.+` FROM testData2 WHERE a = 1 +-- !query 10 schema +struct<B:int,c:int,d:int> +-- !query 10 output +1 1 2 +2 1 2 + + +-- !query 11 +SELECT `(A)?+.+` FROM testData2 WHERE a = 1 +-- !query 11 schema +struct<B:int,c:int,d:int> +-- !query 11 output +1 1 2 +2 1 2 + + +-- !query 12 +SELECT t.`(a)?+.+` FROM testData2 t WHERE a = 1 +-- !query 12 schema +struct<B:int,c:int,d:int> +-- !query 12 output +1 1 2 +2 1 2 + + +-- !query 13 +SELECT t.`(A)?+.+` FROM testData2 t WHERE a = 1 +-- !query 13 schema +struct<B:int,c:int,d:int> +-- !query 13 output +1 1 2 +2 1 2 + + +-- !query 14 +SELECT `(a|B)` FROM testData2 WHERE a = 2 +-- !query 14 schema +struct<A:int,B:int> +-- !query 14 output +2 1 +2 2 + + +-- !query 15 +SELECT `(A|b)` FROM testData2 WHERE a = 2 +-- !query 15 schema +struct<A:int,B:int> +-- !query 15 output +2 1 +2 2 + + +-- !query 16 +SELECT `(a|B)?+.+` FROM testData2 WHERE a = 2 +-- !query 16 schema +struct<c:int,d:int> +-- !query 16 output +2 3 +2 3 + + +-- !query 17 +SELECT 
`(A|b)?+.+` FROM testData2 WHERE a = 2 +-- !query 17 schema +struct<c:int,d:int> +-- !query 17 output +2 3 +2 3 + + +-- !query 18 +SELECT `(e|f)` FROM testData2 +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT t.`(e|f)` FROM testData2 t +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT p.`(KEY)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3 +-- !query 20 schema +struct<value1:string,value2:string,b:int,A:int,c:int,d:int> +-- !query 20 output +1 11 1 1 1 2 +1 11 2 1 1 2 +2 22 1 2 2 3 +2 22 2 2 2 3 + + +-- !query 21 +SELECT p.`(key)?+.+`, b, testdata2.`(b)?+.+` FROM testData p join testData2 ON p.key = testData2.a WHERE key < 3 +-- !query 21 schema +struct<value1:string,value2:string,b:int,A:int,c:int,d:int> +-- !query 21 output +1 11 1 1 1 2 +1 11 2 1 1 2 +2 22 1 2 2 3 +2 22 2 2 2 3 + + +-- !query 22 +set spark.sql.caseSensitive=true +-- !query 22 schema +struct<key:string,value:string> +-- !query 22 output +spark.sql.caseSensitive true + + +-- !query 23 +CREATE OR REPLACE TEMPORARY VIEW testdata3 AS SELECT * FROM VALUES +(0, 1), (1, 2), (2, 3), (3, 4) +AS testdata3(a, b) +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT `(A)?+.+` FROM testdata3 +-- !query 24 schema +struct<a:int,b:int> +-- !query 24 output +0 1 +1 2 +2 3 +3 4 + + +-- !query 25 +SELECT `(a)?+.+` FROM testdata3 +-- !query 25 schema +struct<b:int> +-- !query 25 output +1 +2 +3 +4 + + +-- !query 26 +SELECT `(A)?+.+` FROM testdata3 WHERE a > 1 +-- !query 26 schema +struct<a:int,b:int> +-- !query 26 output +2 3 +3 4 + + +-- !query 27 +SELECT `(a)?+.+` FROM testdata3 where `a` > 1 +-- !query 27 schema +struct<b:int> +-- !query 27 output +3 +4 + + +-- !query 28 +SELECT SUM(`a`) FROM testdata3 +-- !query 28 schema +struct<sum(a):bigint> +-- !query 28 output +6 + + +-- !query 29 +SELECT SUM(`(a)`) FROM testdata3 +-- !query 29 schema +struct<sum(a):bigint> +-- !query 29 output +6 + + +-- !query 30 +SELECT SUM(`(a)?+.+`) FROM testdata3 +-- !query 30 schema +struct<sum(b):bigint> +-- !query 30 output +10 + + +-- !query 31 +SELECT SUM(a) FROM testdata3 GROUP BY `a` +-- !query 31 schema +struct<sum(a):bigint> +-- !query 31 output +0 +1 +2 +3 + + +-- !query 32 +SELECT SUM(a) FROM testdata3 GROUP BY `(a)` +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)`' given input columns: [testdata3.a, testdata3.b]; line 1 pos 38 + + +-- !query 33 +SELECT SUM(a) FROM testdata3 GROUP BY `(a)?+.+` +-- !query 33 schema +struct<> +-- !query 33 output +org.apache.spark.sql.AnalysisException +cannot resolve '`(a)?+.+`' given input columns: [testdata3.a, testdata3.b]; line 1 pos 38 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b52d50b195..4568b67024 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -543,10 +543,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-17237 remove backticks in a pivot result schema") { val df = Seq((2, 3, 4), (3, 4, 5)).toDF("a", "x", "y") - checkAnswer( - df.groupBy("a").pivot("x").agg(count("y"), avg("y")).na.fill(0), - Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + 
checkAnswer( + df.groupBy("a").pivot("x").agg(count("y"), avg("y")).na.fill(0), + Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) + ) + } } test("aggregate function in GROUP BY") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e63c5cb194..47c9ba5847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -111,119 +111,124 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { (null, null) ).toDF("name", "spy") - val fillNumeric = input.na.fill(50.6) - checkAnswer( - fillNumeric, - Row("Bob", 16, 176.5) :: - Row("Alice", 50, 164.3) :: - Row("David", 60, 50.6) :: - Row("Nina", 25, 50.6) :: - Row("Amy", 50, 50.6) :: - Row(null, 50, 50.6) :: Nil) - - // Make sure the columns are properly named. - assert(fillNumeric.columns.toSeq === input.columns.toSeq) - - // string - checkAnswer( - input.na.fill("unknown").select("name"), - Row("Bob") :: Row("Alice") :: Row("David") :: - Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) - assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) - - // boolean - checkAnswer( - boolInput.na.fill(true).select("spy"), - Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil) - assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq) - - // fill double with subset columns - checkAnswer( - input.na.fill(50.6, "age" :: Nil).select("name", "age"), - Row("Bob", 16) :: - Row("Alice", 50) :: - Row("David", 60) :: - Row("Nina", 25) :: - Row("Amy", 50) :: - Row(null, 50) :: Nil) - - // fill boolean with subset columns - checkAnswer( - boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"), - Row("Bob", false) :: - Row("Alice", true) :: - Row("Mallory", true) :: - Row(null, true) :: Nil) - - // fill string with subset columns - checkAnswer( - Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), - Row("test", null)) - - checkAnswer( - Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) - .toDF("a", "b").na.fill(0), - Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), - (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), - Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) - :: Row(0, 0.2) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), - (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), - Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) - :: Row(0, 0.2f) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) - .toDF("a", "b").na.fill(2.34), - Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil - ) - - checkAnswer( - Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) - .toDF("a", "b").na.fill(5), - Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 
3.45) :: Nil - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. + assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // boolean + checkAnswer( + boolInput.na.fill(true).select("spy"), + Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil) + assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) + + // fill boolean with subset columns + checkAnswer( + boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"), + Row("Bob", false) :: + Row("Alice", true) :: + Row("Mallory", true) :: + Row(null, true) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + + checkAnswer( + Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) + .toDF("a", "b").na.fill(0), + Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(2.34), + Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(5), + Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil + ) + } } test("fill with map") { - val df = Seq[(String, String, java.lang.Integer, java.lang.Long, + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df = Seq[(String, String, java.lang.Integer, java.lang.Long, java.lang.Float, java.lang.Double, java.lang.Boolean)]( - (null, null, null, null, null, null, null)) - .toDF("stringFieldA", "stringFieldB", "integerField", "longField", - "floatField", "doubleField", "booleanField") - - val fillMap = Map( - "stringFieldA" -> "test", - "integerField" -> 1, - "longField" -> 2L, - "floatField" -> 3.3f, - "doubleField" -> 4.4d, - "booleanField" -> false) - - val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) - - checkAnswer(df.na.fill(fillMap), expectedRow) - checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version - - // Ensure replacement values are cast to the 
column data type. - checkAnswer(df.na.fill(Map( - "integerField" -> 1d, - "longField" -> 2d, - "floatField" -> 3d, - "doubleField" -> 4d)), - Row(null, null, 1, 2L, 3f, 4d, null)) - - // Ensure column types do not change. Columns that have null values replaced - // will no longer be flagged as nullable, so do not compare schemas directly. - assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) + (null, null, null, null, null, null, null)) + .toDF("stringFieldA", "stringFieldB", "integerField", "longField", + "floatField", "doubleField", "booleanField") + + val fillMap = Map( + "stringFieldA" -> "test", + "integerField" -> 1, + "longField" -> 2L, + "floatField" -> 3.3f, + "doubleField" -> 4.4d, + "booleanField" -> false) + + val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) + + + checkAnswer(df.na.fill(fillMap), expectedRow) + checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version + + // Ensure replacement values are cast to the column data type. + checkAnswer(df.na.fill(Map( + "integerField" -> 1d, + "longField" -> 2d, + "floatField" -> 3d, + "doubleField" -> 4d)), + Row(null, null, 1, 2L, 3f, 4d, null)) + + // Ensure column types do not change. Columns that have null values replaced + // will no longer be flagged as nullable, so do not compare schemas directly. + assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) + } } test("replace") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index dd118f88e3..09502d05f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.Matchers._ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -263,52 +264,56 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("crosstab") { - val rng = new Random() - val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) - val df = data.toDF("a", "b") - val crosstab = df.stat.crosstab("a", "b") - val columnNames = crosstab.schema.fieldNames - assert(columnNames(0) === "a_b") - // reduce by key - val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) - val rows = crosstab.collect() - rows.foreach { row => - val i = row.getString(0).toInt - for (col <- 1 until columnNames.length) { - val j = columnNames(col).toInt - assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") + val crosstab = df.stat.crosstab("a", "b") + val columnNames = crosstab.schema.fieldNames + assert(columnNames(0) === "a_b") + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } } } } test("special crosstab 
elements (., '', null, ``)") { - val data = Seq( - ("a", Double.NaN, "ho"), - (null, 2.0, "ho"), - ("a.b", Double.NegativeInfinity, ""), - ("b", Double.PositiveInfinity, "`ha`"), - ("a", 1.0, null) - ) - val df = data.toDF("1", "2", "3") - val ct1 = df.stat.crosstab("1", "2") - // column fields should be 1 + distinct elements of second column - assert(ct1.schema.fields.length === 6) - assert(ct1.collect().length === 4) - val ct2 = df.stat.crosstab("1", "3") - assert(ct2.schema.fields.length === 5) - assert(ct2.schema.fieldNames.contains("ha")) - assert(ct2.collect().length === 4) - val ct3 = df.stat.crosstab("3", "2") - assert(ct3.schema.fields.length === 6) - assert(ct3.schema.fieldNames.contains("NaN")) - assert(ct3.schema.fieldNames.contains("Infinity")) - assert(ct3.schema.fieldNames.contains("-Infinity")) - assert(ct3.collect().length === 4) - val ct4 = df.stat.crosstab("3", "1") - assert(ct4.schema.fields.length === 5) - assert(ct4.schema.fieldNames.contains("null")) - assert(ct4.schema.fieldNames.contains("a.b")) - assert(ct4.collect().length === 4) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) + } } test("Frequent Items") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5ae27032e0..3f3a6221d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1179,28 +1179,31 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) - checkAnswer( - df.select(df("`a.b`.c.`d..e`.`f`")), - Row(1) - ) - - val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) - checkAnswer( - df2.select(df2("`a b`.c.d e.f")), - Row(1) - ) - - def checkError(testFun: => Unit): Unit = { - val e = intercept[org.apache.spark.sql.AnalysisException] { - testFun + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS()) + checkAnswer( + df.select(df("`a.b`.c.`d..e`.`f`")), + Row(1) + ) + + val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS()) + checkAnswer( + df2.select(df2("`a b`.c.d e.f")), + Row(1) + ) + + def checkError(testFun: => Unit): Unit = { + val e = 
intercept[org.apache.spark.sql.AnalysisException] { + testFun + } + assert(e.getMessage.contains("syntax error in attribute name:")) } - assert(e.getMessage.contains("syntax error in attribute name:")) + + checkError(df("`abc.`c`")) + checkError(df("`abc`..d")) + checkError(df("`a`.b.")) + checkError(df("`a.b`.c.`d")) } - checkError(df("`abc.`c`")) - checkError(df("`abc`..d")) - checkError(df("`a`.b.")) - checkError(df("`a.b`.c.`d")) } test("SPARK-7324 dropDuplicates") { @@ -1928,11 +1931,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-17957: outer join + na.fill") { - val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") - val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") - val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) - val df3 = Seq((3, 1)).toDF("a", "d") - checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1)) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") + val df2 = Seq((2, 5), (3, 4)).toDF("a", "c") + val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0) + val df3 = Seq((3, 1)).toDF("a", "d") + checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1)) + } } test("SPARK-17123: Performing set operations that combine non-scala native types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 87b7b090de..825840707d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} +import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ @@ -244,6 +244,85 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) } + test("REGEX column specification") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + var e = intercept[AnalysisException] { + ds.select(expr("`(_1)?+.+`").as[Int]) + }.getMessage + assert(e.contains("cannot resolve '`(_1)?+.+`'")) + + e = intercept[AnalysisException] { + ds.select(expr("`(_1|_2)`").as[Int]) + }.getMessage + assert(e.contains("cannot resolve '`(_1|_2)`'")) + + e = intercept[AnalysisException] { + ds.select(ds("`(_1)?+.+`")) + }.getMessage + assert(e.contains("Cannot resolve column name \"`(_1)?+.+`\"")) + + e = intercept[AnalysisException] { + ds.select(ds("`(_1|_2)`")) + }.getMessage + assert(e.contains("Cannot resolve column name \"`(_1|_2)`\"")) + } + + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "true") { + checkDataset( + ds.select(ds.col("_2")).as[Int], + 1, 2, 3) + + checkDataset( + ds.select(ds.colRegex("`(_1)?+.+`")).as[Int], + 1, 2, 3) + + checkDataset( + ds.select(ds("`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.alias("g") + .select(ds("g.`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), 
ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.select(ds("`(_1)?+.+`")) + .select(expr("_2").as[Int]), + 1, 2, 3) + + checkDataset( + ds.alias("g") + .select(ds("g.`(_1)?+.+`")) + .select(expr("_2").as[Int]), + 1, 2, 3) + + checkDataset( + ds.select(expr("`(_1)?+.+`").as[Int]), + 1, 2, 3) + val m = ds.select(expr("`(_1|_2)`")) + + checkDataset( + ds.select(expr("`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + + checkDataset( + ds.alias("g") + .select(expr("g.`(_1)?+.+`").as[Int]), + 1, 2, 3) + + checkDataset( + ds.alias("g") + .select(expr("g.`(_1|_2)`")) + .select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + } + } + test("filter") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 472ff7385b..c78ec6d9a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1219,7 +1219,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3483 Special chars in column names") { val data = Seq("""{"key?number1": "value1", "key.number2": "value2"}""").toDS() spark.read.json(data).createOrReplaceTempView("records") - sql("SELECT `key?number1`, `key.number2` FROM records") + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + sql("SELECT `key?number1`, `key.number2` FROM records") + } } test("SPARK-3814 Support Bitwise & operator") { @@ -1339,7 +1341,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .json(Seq("""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""").toDS()) .createOrReplaceTempView("t") - checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } test("SPARK-6583 order by aggregated function") { @@ -1835,25 +1839,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } // Create paths with unusual characters - val specialCharacterPath = sql( - """ + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + val specialCharacterPath = sql( + """ | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) - withTempView("specialCharacterTable") { - specialCharacterPath.createOrReplaceTempView("specialCharacterTable") - checkAnswer( - specialCharacterPath.select($"`r&&b.c`.*"), - nestedStructData.select($"record.*")) - checkAnswer( - sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + withTempView("specialCharacterTable") { + specialCharacterPath.createOrReplaceTempView("specialCharacterTable") + checkAnswer( + specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer( + sql( + "SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), nestedStructData.select($"record.r1")) - checkAnswer( - sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), - nestedStructData.select($"record.r2")) - checkAnswer( - sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), - nestedStructData.select($"record.r1.*")) + 
checkAnswer( + sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + } } // Try star expanding a scalar. This should fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 704823ad51..1cde137edb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -937,14 +937,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Map("e" -> null)) :: Nil ) - checkAnswer( - sql("select `map`['c'] from jsonWithSimpleMap"), - Row(null) :: - Row(null) :: - Row(3) :: - Row(1) :: - Row(null) :: Nil - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql("select `map`['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + } val innerStruct = StructType( StructField("field1", ArrayType(IntegerType, true), true) :: @@ -966,15 +968,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Map("f" -> Row(null, null))) :: Nil ) - checkAnswer( - sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), - Row(Seq(1, 2, 3, null), null) :: - Row(null, null) :: - Row(null, 4) :: - Row(null, 3) :: - Row(null, null) :: - Row(null, null) :: Nil - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } } test("SPARK-2096 Correctly parse dot notations") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 98427cfe30..c43c1ec8b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -544,7 +544,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex Seq(true, false).foreach { vectorized => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, - SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString) { + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString, + SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { withTempPath { path => Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath) val readBack = spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 80868fff89..70338670c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import 
org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { - protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row], enableRegex: Boolean = false) { test(sqlString) { - checkAnswer(spark.sql(sqlString), expectedAnswer) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> enableRegex.toString) { + checkAnswer(spark.sql(sqlString), expectedAnswer) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index b01d15eb91..65474a52dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -247,32 +248,34 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { assert(expectedSchema == spark.table("tableWithSchema").schema) - checkAnswer( - sql( - """SELECT - | `string$%Field`, - | cast(binaryField as string), - | booleanField, - | byteField, - | shortField, - | int_Field, - | `longField_:,<>=+/~^`, - | floatField, - | doubleField, - | decimalField1, - | decimalField2, - | dateField, - | timestampField, - | varcharField, - | charField, - | arrayFieldSimple, - | arrayFieldComplex, - | mapFieldSimple, - | mapFieldComplex, - | structFieldSimple, - | structFieldComplex FROM tableWithSchema""".stripMargin), - tableWithSchemaExpected - ) + withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { + checkAnswer( + sql( + """SELECT + | `string$%Field`, + | cast(binaryField as string), + | booleanField, + | byteField, + | shortField, + | int_Field, + | `longField_:,<>=+/~^`, + | floatField, + | doubleField, + | decimalField1, + | decimalField2, + | dateField, + | timestampField, + | varcharField, + | charField, + | arrayFieldSimple, + | arrayFieldComplex, + | mapFieldSimple, + | mapFieldComplex, + | structFieldSimple, + | structFieldComplex FROM tableWithSchema""".stripMargin), + tableWithSchemaExpected + ) + } } sqlTest( -- GitLab
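
For reference, a minimal spark-shell session exercising the new behavior might look like the sketch below. The table name `test1` and its columns (a, b, c) are illustrative assumptions, not part of this patch; the configuration key and the new Dataset.colRegex API are the ones introduced above.

    // Assumes the usual spark-shell SparkSession `spark` and an existing table `test1`
    scala> spark.conf.set("spark.sql.parser.quotedRegexColumnNames", "true")

    // With the flag on, a backtick-quoted identifier in a SELECT list is parsed as a
    // Java regex over column names; `(a)?+.+` matches every column except `a`.
    scala> spark.sql("SELECT `(a)?+.+` FROM test1").show()

    // DataFrame API equivalent via the new colRegex method (@since 2.3.0):
    scala> val df = spark.table("test1")
    scala> df.select(df.colRegex("`(a|b)`")).show()   // selects columns a and b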