From 0f80990bfac1e9969644952d1d8edaf7d26fb436 Mon Sep 17 00:00:00 2001 From: Yin Huai <yhuai@databricks.com> Date: Tue, 2 Jun 2015 00:20:52 -0700 Subject: [PATCH] [SPARK-8023][SQL] Add "deterministic" attribute to Expression to avoid collapsing nondeterministic projects. This closes #6570. Author: Yin Huai <yhuai@databricks.com> Author: Reynold Xin <rxin@databricks.com> Closes #6573 from rxin/deterministic and squashes the following commits: 356cd22 [Reynold Xin] Added unit test for the optimizer. da3fde1 [Reynold Xin] Merge pull request #6570 from yhuai/SPARK-8023 da56200 [Yin Huai] Comments. e38f264 [Yin Huai] Comment. f9d6a73 [Yin Huai] Add a deterministic method to Expression. --- .../sql/catalyst/expressions/Expression.scala | 8 ++ .../sql/catalyst/expressions/random.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 11 ++- .../optimizer/ProjectCollapsingSuite.scala | 73 +++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 41 ++++++++++- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 + 6 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d199287844..adc6505d69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 4f4f67a6e4..b2647124c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957c..b25fb48f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] { * expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 0000000000..151654bffb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b8bb1bff9e..bfba379d9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -452,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 64a49c83cb..1658bb93b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient -- GitLab