Skip to content
Snippets Groups Projects
Commit 3f6d28a5 authored by Wenchen Fan's avatar Wenchen Fan Committed by Yin Huai
Browse files

[SPARK-9102] [SQL] Improve project collapse with nondeterministic expressions

Currently we stop project collapsing when the lower projection contains nondeterministic expressions. However, this is sometimes overly conservative: we should be able to optimize `df.select(Rand(10)).select('a)` to `df.select('a)`, since the nondeterministic expression is never referenced by the upper projection.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7445 from cloud-fan/non-deterministic and squashes the following commits:

0deaef6 [Wenchen Fan] Improve project collapse with nondeterministic expressions
parent 111c0553
No related branches found
No related tags found
No related merge requests found
......@@ -206,31 +206,33 @@ object ColumnPruning extends Rule[LogicalPlan] {
*/
/**
 * Collapses two adjacent [[Project]] operators into one, merging the
 * expressions of the upper projection into the lower one.
 *
 * The collapse is skipped only when the upper projection actually references a
 * nondeterministic expression produced by the lower projection — inlining such
 * an expression could duplicate its evaluation (or drop it), changing results.
 * Nondeterministic expressions that the upper projection never references do
 * not block the rewrite.
 */
object ProjectCollapsing extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case p @ Project(projectList1, Project(projectList2, child)) =>
      // Create a map of Aliases to their values from the child projection.
      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
      val aliasMap = AttributeMap(projectList2.collect {
        case a: Alias => (a.toAttribute, a)
      })

      // We only collapse these two Projects if their overlapped expressions are all
      // deterministic, i.e. every child-produced expression that the upper project
      // actually references contains no nondeterministic sub-expression.
      val hasNondeterministic = projectList1.flatMap(_.collect {
        case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
      }).exists(_.find(!_.deterministic).isDefined)

      if (hasNondeterministic) {
        // Referencing a nondeterministic child expression: keep both projects.
        p
      } else {
        // Substitute any attributes that are produced by the child projection, so that we safely
        // eliminate it.
        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
        // TODO: Fix TransformBase to avoid the cast below.
        val substitutedProjection = projectList1.map(_.transform {
          case a: Attribute => aliasMap.getOrElse(a, a)
        }).asInstanceOf[Seq[NamedExpression]]
        Project(substitutedProjection, child)
      }
  }
}
......
......@@ -70,4 +70,30 @@ class ProjectCollapsingSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
test("collapse two nondeterministic, independent projects into one") {
  // The lower Rand(10) is nondeterministic but never referenced above,
  // so the two projects should still collapse into one.
  val input = testRelation
    .select(Rand(10).as('rand))
    .select(Rand(20).as('rand2))

  val expected = testRelation
    .select(Rand(20).as('rand2))
    .analyze

  comparePlans(Optimize.execute(input.analyze), expected)
}
test("collapse one nondeterministic, one deterministic, independent projects into one") {
  // Only the deterministic column 'a is referenced by the upper project;
  // the unreferenced Rand(10) must not block the collapse.
  val input = testRelation
    .select(Rand(10).as('rand), 'a)
    .select(('a + 1).as('a_plus_1))

  val expected = testRelation
    .select(('a + 1).as('a_plus_1))
    .analyze

  comparePlans(Optimize.execute(input.analyze), expected)
}
}
......@@ -745,8 +745,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
test("SPARK-8072: Better Exception for Duplicate Columns") {
// only one duplicate column present
val e = intercept[org.apache.spark.sql.AnalysisException] {
val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
.write.format("parquet").save("temp")
Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
.write.format("parquet").save("temp")
}
assert(e.getMessage.contains("Duplicate column(s)"))
assert(e.getMessage.contains("parquet"))
......@@ -755,9 +755,9 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
// multiple duplicate columns present
val f = intercept[org.apache.spark.sql.AnalysisException] {
val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7))
.toDF("column1", "column2", "column3", "column1", "column3")
.write.format("json").save("temp")
Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7))
.toDF("column1", "column2", "column3", "column1", "column3")
.write.format("json").save("temp")
}
assert(f.getMessage.contains("Duplicate column(s)"))
assert(f.getMessage.contains("JSON"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment