Commit 0f80990b authored by Yin Huai, committed by Reynold Xin

[SPARK-8023][SQL] Add "deterministic" attribute to Expression to avoid collapsing nondeterministic projects.

This closes #6570.

Author: Yin Huai <yhuai@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6573 from rxin/deterministic and squashes the following commits:

356cd22 [Reynold Xin] Added unit test for the optimizer.
da3fde1 [Reynold Xin] Merge pull request #6570 from yhuai/SPARK-8023
da56200 [Yin Huai] Comments.
e38f264 [Yin Huai] Comment.
f9d6a73 [Yin Huai] Add a deterministic method to Expression.
parent 7b7f7b6c
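Before the diffs, a feel for the semantics being protected. The snippet below is an illustrative sketch against the Spark 1.4-era DataFrame API, not code from the patch (the data and column names are invented). If the optimizer collapsed the two Projects, every reference to r would be rewritten into its own rand(5L) call, and the difference would no longer be a per-row constant:

import org.apache.spark.sql.functions.rand
import sqlContext.implicits._  // assumes an in-scope SQLContext named sqlContext

val df = Seq(1, 2, 3).map(Tuple1.apply).toDF("key")
  .select(rand(5L).as("r"))                          // draw r once per row
  .select(($"r" + 1).as("r1"), ($"r" - 2).as("r2"))  // both columns reuse that one draw

// With this fix, the optimizer keeps the nondeterministic Project separate, so
// r1 - r2 is always (r + 1) - (r - 2) = 3 for the single draw of r per row.
df.select(($"r1" - $"r2").as("diff")).collect().foreach { row =>
  assert(math.abs(row.getDouble(0) - 3.0) < 1e-9)
}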
@@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] {
    * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
    */
   def foldable: Boolean = false
 
+  /**
+   * Returns true when the current expression always returns the same result for
+   * fixed input values.
+   */
+  // TODO: Need to define explicit input values vs implicit input values.
+  def deterministic: Boolean = true
+
   def nullable: Boolean
 
   def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
 
   /** Returns the result of evaluating this expression on a given input Row */
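Two things about this contract are worth noting; the snippet is an illustrative sketch using Catalyst expression classes from this era, not part of the patch. The default is true, only expressions such as Rand opt out, and the flag does not recurse, so a parent node over a nondeterministic child still reports true on its own:

import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Rand}

assert(Literal(1).deterministic)               // default: true
assert(!Rand(5).deterministic)                 // RDG subclasses override to false
assert(Add(Rand(5), Literal(1)).deterministic) // flat flag: the Add node itself says true
// ...which is why callers must scan the subtree, e.g. with TreeNode.find
// (see the ProjectCollapsing hunk below).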
@@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
    */
   @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
 
+  override def deterministic: Boolean = false
+
   override def nullable: Boolean = false
 
   override def dataType: DataType = DoubleType
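For readers unfamiliar with RDG: the generator is seeded per partition, so a given query seed yields a reproducible stream within each partition, but every evaluation consumes a fresh draw, which is what makes the expression nondeterministic under re-evaluation. A simplified stand-in for the seeding scheme (java.util.Random here instead of Spark's XORShiftRandom):

import java.util.Random

// Mix the user-supplied seed with the partition id, mirroring
// `new XORShiftRandom(seed + TaskContext.get().partitionId())` above.
def rngFor(seed: Long, partitionId: Int): Random = new Random(seed + partitionId)

val p0 = rngFor(5L, 0)
val p0again = rngFor(5L, 0)
assert(p0.nextDouble() == p0again.nextDouble()) // same partition: same stream
// Different partitions get independent streams from the same query seed.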
@@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
    * expressions into one single expression.
    */
 object ProjectCollapsing extends Rule[LogicalPlan] {
 
+  /** Returns true if any expression in projectList is non-deterministic. */
+  private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = {
+    projectList.exists(expr => expr.find(!_.deterministic).isDefined)
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case Project(projectList1, Project(projectList2, child)) =>
+    // We only collapse these two Projects if the child Project's expressions are all
+    // deterministic.
+    case Project(projectList1, Project(projectList2, child))
+        if !hasNondeterministic(projectList2) =>
       // Create a map of Aliases to their values from the child projection.
       // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
       val aliasMap = AttributeMap(projectList2.collect {
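Because deterministic is a flat flag (see the Expression hunk above), hasNondeterministic cannot just test the top-level expression: a named column like rand(5) + 1 AS r reports true at the Alias while hiding a Rand leaf. A small sketch of what the find-based subtree scan catches (illustrative construction, not from the patch); the new ProjectCollapsingSuite below exercises the rule itself:

import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal, Rand}

val r = Alias(Add(Rand(5), Literal(1)), "r")() // `rand(5) + 1 AS r`

assert(r.deterministic)                    // the Alias node alone looks safe
assert(r.find(!_.deterministic).isDefined) // the subtree scan finds the Rand leaf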
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ProjectCollapsingSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation.select(
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
      'b).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)
  }
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.scalatest.Matchers._
 
+import org.apache.spark.sql.execution.Project
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -452,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest {
   }
 
   test("rand") {
-    val randCol = testData.select('key, rand(5L).as("rand"))
+    val randCol = testData.select($"key", rand(5L).as("rand"))
     randCol.columns.length should be (2)
     val rows = randCol.collect()
     rows.foreach { row =>
       assert(row.getDouble(1) <= 1.0)
       assert(row.getDouble(1) >= 0.0)
     }
+
+    def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
+      val projects = df.queryExecution.executedPlan.collect {
+        case project: Project => project
+      }
+      assert(projects.size === expectedNumProjects)
+    }
+
+    // We first create a plan with two Projects.
+    //   Project [rand + 1 AS rand1, rand - 1 AS rand2]
+    //     Project [key, (Rand 5 + 1) AS rand]
+    //       LogicalRDD [key, value]
+    // Because the Rand function is not deterministic, the column rand is not deterministic.
+    // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2]
+    // and Project [key, Rand 5 AS rand]. The final plan still has two Projects.
+    val dfWithTwoProjects =
+      testData
+        .select($"key", (rand(5L) + 1).as("rand"))
+        .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2"))
+    checkNumProjects(dfWithTwoProjects, 2)
+
+    // Now, we add one more Project of rand1 - rand2 on top of the query plan.
+    // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated
+    // rand value), we can collapse rand1 - rand2 into the Project generating rand1 and rand2.
+    // So, the plan will be optimized from
+    //   Project [(rand1 - rand2) AS (rand1 - rand2)]
+    //     Project [rand + 1 AS rand1, rand - 1 AS rand2]
+    //       Project [key, (Rand 5 + 1) AS rand]
+    //         LogicalRDD [key, value]
+    // to
+    //   Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)]
+    //     Project [key, Rand 5 AS rand]
+    //       LogicalRDD [key, value]
+    val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2")
+    checkNumProjects(dfWithThreeProjects, 2)
+
+    dfWithThreeProjects.collect().foreach { row =>
+      assert(row.getDouble(0) === 2.0 +- 0.0001)
+    }
   }
 
   test("randn") {
@@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
 
   type UDFType = UDF
 
+  override def deterministic: Boolean = isUDFDeterministic
+
   override def nullable: Boolean = true
 
   @transient
@@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
   extends Expression with HiveInspectors with Logging {
 
   type UDFType = GenericUDF
 
+  override def deterministic: Boolean = isUDFDeterministic
+
   override def nullable: Boolean = true
 
   @transient
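Both wrappers delegate to an isUDFDeterministic helper whose definition is not shown in these hunks. Hive UDFs declare determinism with the @UDFType annotation on the UDF class, so a plausible reading of the helper, written here as a standalone sketch rather than the verbatim source, is:

import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}

// Sketch: decide determinism from Hive's @UDFType annotation on the UDF class.
// A UDF counts as deterministic only if it explicitly declares itself so;
// an unannotated UDF is conservatively treated as nondeterministic.
def isUDFDeterministic(function: AnyRef): Boolean = {
  val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
  udfType != null && udfType.deterministic()
}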