diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f2b9764b0f08890b7c7169885c2a53bae9aea549..1802cd4bb131b59bb696f5ec6d289fad2e3270c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) RemoveRedundantProject, SimplifyCreateStructOps, SimplifyCreateArrayOps, - SimplifyCreateMapOps) ++ + SimplifyCreateMapOps, + CombineConcats) ++ extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 34382bd2724069a9433305ab60124800b295858e..d3ef5ea8409193f42f033503096bf2a4318cbf89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -543,3 +544,28 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } } + +/** + * Combine nested [[Concat]] expressions. + */ +object CombineConcats extends Rule[LogicalPlan] { + + private def flattenConcats(concat: Concat): Concat = { + val stack = Stack[Expression](concat) + val flattened = ArrayBuffer.empty[Expression] + while (stack.nonEmpty) { + stack.pop() match { + case Concat(children) => + stack.pushAll(children.reverse) + case child => + flattened += child + } + } + Concat(flattened) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { + case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) => + flattenConcats(concat) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..7aa9fbba9a10afd84dfe18a7a02de3f03f773c40 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.StringType + + +class CombineConcatsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("combine nested Concat exprs") { + def str(s: String): Literal = Literal(s, StringType) + assertEquivalent( + Concat( + Concat(str("a") :: str("b") :: Nil) :: + str("c") :: + str("d") :: + Nil), + Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) + assertEquivalent( + Concat( + str("a") :: + Concat(str("b") :: str("c") :: Nil) :: + str("d") :: + Nil), + Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) + assertEquivalent( + Concat( + str("a") :: + str("b") :: + Concat(str("c") :: str("d") :: Nil) :: + Nil), + Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) + assertEquivalent( + Concat( + Concat( + str("a") :: + Concat( + str("b") :: + Concat(str("c") :: str("d") :: Nil) :: + Nil) :: + Nil) :: + Nil), + Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 7005cafe35cab4e5bd5826e1e03d517cadd28ce0..f685779cd34aff41908454980709f3c93ef1eb97 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -4,3 +4,7 @@ select format_string(); -- A pipe operator for string concatenation select 'a' || 'b' || 'c'; + +-- Check if catalyst combine nested `Concat`s +EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 8ee075118e109086a115ca37795fa444854175b7..d48d1a80c03bc6614508daad8fb89764105e9634 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 +-- Number of queries: 4 -- !query 0 @@ -26,3 +26,29 @@ select 'a' || 'b' || 'c' struct<concat(concat(a, b), c):string> -- !query 2 output abc + + +-- !query 3 +EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col +FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) +-- !query 3 schema +struct<plan:string> +-- !query 3 output +== Parsed Logical Plan == +'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] ++- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] + +- 'UnresolvedTableValuedFunction range, [10] + +== Analyzed Logical Plan == +col: string +Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] ++- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] + +- Range (0, 10, step=1, splits=None) + +== Optimized Logical Plan == +Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] ++- Range (0, 10, step=1, splits=None) + +== Physical Plan == +*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] ++- *Range (0, 10, step=1, splits=2)