diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f17c37256c9e5f13387152980e82a0b5cb55a10c..ab9de023e2b305524cde3942a25d0d01aa50c02d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1859,28 +1859,37 @@ class Analyzer(
       case p: Project => p
       case f: Filter => f
 
+      case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
+        val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
+        val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
+        a.transformExpressions { case e =>
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+        }.copy(child = newChild)
+
       // todo: It's hard to write a general rule to pull out nondeterministic expressions
       // from LogicalPlan, currently we only do it for UnaryNode which has same output
       // schema with its child.
       case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
-        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
-          val leafNondeterministic = expr.collect {
-            case n: Nondeterministic => n
-          }
-          leafNondeterministic.map { e =>
-            val ne = e match {
-              case n: NamedExpression => n
-              case _ => Alias(e, "_nondeterministic")(isGenerated = true)
-            }
-            new TreeNodeRef(e) -> ne
-          }
-        }.toMap
+        val nondeterToAttr = getNondeterToAttr(p.expressions)
         val newPlan = p.transformExpressions { case e =>
-          nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
         }
-        val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+        val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
         Project(p.output, newPlan.withNewChildren(newChild :: Nil))
     }
+
+    private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
+      exprs.filterNot(_.deterministic).flatMap { expr =>
+        val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
+        leafNondeterministic.distinct.map { e =>
+          val ne = e match {
+            case n: NamedExpression => n
+            case _ => Alias(e, "_nondeterministic")(isGenerated = true)
+          }
+          e -> ne
+        }
+      }.toMap
+    }
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..72e10eadf79f3a21abf57850330867b7b4b88670
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * Test suite for moving non-deterministic expressions into Project.
+ */
+class PullOutNondeterministicSuite extends AnalysisTest {
+
+  private lazy val a = 'a.int
+  private lazy val b = 'b.int
+  private lazy val r = LocalRelation(a, b)
+  private lazy val rnd = Rand(10).as('_nondeterministic)
+  private lazy val rndref = rnd.toAttribute
+
+  test("no-op on filter") {
+    checkAnalysis(
+      r.where(Rand(10) > Literal(1.0)),
+      r.where(Rand(10) > Literal(1.0))
+    )
+  }
+
+  test("sort") {
+    checkAnalysis(
+      r.sortBy(SortOrder(Rand(10), Ascending)),
+      r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
+    )
+  }
+
+  test("aggregate") {
+    checkAnalysis(
+      r.groupBy(Rand(10))(Rand(10).as("rnd")),
+      r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
+    )
+  }
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index 9c3a145f3aaa77ebe7d295d6745a4dec88e17638..c64520ff93c83653e607a8a5a3a47b6a72c744c9 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and aggregate functions are not al
 -- !query 13
 select a, rand(0), sum(b) from data group by a, 2
 -- !query 13 schema
-struct<>
+struct<a:int,rand(0):double,sum(b):bigint>
 -- !query 13 output
-org.apache.spark.sql.AnalysisException
-nondeterministic expression rand(0) should not appear in grouping expression.;
+1	0.4048454303385226	2
+1	0.8446490682263027	1
+2	0.5871875724155838	1
+2	0.8865128837019473	2
+3	0.742083829230211	1
+3	0.9179913208300406	2
 
 
 -- !query 14
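
Reviewer note (not part of the patch): a minimal sketch of the rewrite the new Aggregate branch performs, written against the same Catalyst test DSL the suite uses. The object name PullOutNondeterministicSketch is made up for illustration; everything else mirrors the "aggregate" test above.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Rand
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    object PullOutNondeterministicSketch {
      val a = 'a.int
      val b = 'b.int
      val r = LocalRelation(a, b)

      // rand(10) is aliased exactly once, so the Project below the Aggregate
      // and the rewritten grouping key agree on a single evaluation of the
      // nondeterministic expression.
      val rnd = Rand(10).as('_nondeterministic)
      val rndref = rnd.toAttribute

      // Before analysis: grouping directly on a nondeterministic expression.
      val before = r.groupBy(Rand(10))(Rand(10).as("rnd"))

      // Expected analyzed shape: the expression is pulled into a Project under
      // the Aggregate, and both the grouping key and the output expression now
      // reference the projected attribute.
      val after = r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
    }

End to end, this is what lets "select a, rand(0), sum(b) from data group by a, 2" analyze and run (see the updated group-by-ordinal.sql.out golden file above) instead of failing with "nondeterministic expression rand(0) should not appear in grouping expression".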