diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 287718fab7f0dba4de361bb5da917112bb93c7eb..d58c4756938c712e6d01f4cdda68e755fe8cf94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } +/** + * A predicate that is evaluated to be true if there are at least `n` null values. + */ +case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: InternalRow): Boolean = { + var numNulls = 0 + var i = 0 + while (i < childrenArray.length && numNulls < n) { + val evalC = childrenArray(i).eval(input) + if (evalC == null) { + numNulls += 1 + } + i += 1 + } + numNulls >= n + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numNulls = ctx.freshName("numNulls") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($numNulls < $n) { + ${eval.code} + if (${eval.isNull}) { + $numNulls += 1; + } + } + """ + }.mkString("\n") + s""" + int $numNulls = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $numNulls >= $n; + """ + } +} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb39a7db75cb8fec873effcbf42d12f68..e4b6294dc7b8ec93e2b50ba0861dba269039e4d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -object DefaultOptimizer extends Optimizer { - val batches = +class DefaultOptimizer extends Optimizer { + + /** + * Override to provide additional rules for the "Operator Optimizations" batch. + */ + val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches = // SubQueries are only needed for analysis and can be removed before execution. 
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown, - SamplePushDown, - PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - ColumnPruning, + SetOperationPushDown :: + SamplePushDown :: + PushPredicateThroughJoin :: + PushPredicateThroughProject :: + PushPredicateThroughGenerate :: + ColumnPruning :: // Operator combine - ProjectCollapsing, - CombineFilters, - CombineLimits, + ProjectCollapsing :: + CombineFilters :: + CombineLimits :: // Constant folding - NullPropagation, - OptimizeIn, - ConstantFolding, - LikeSimplification, - BooleanSimplification, - RemovePositive, - SimplifyFilters, - SimplifyCasts, - SimplifyCaseConversionExpressions) :: + NullPropagation :: + OptimizeIn :: + ConstantFolding :: + LikeSimplification :: + BooleanSimplification :: + RemovePositive :: + SimplifyFilters :: + SimplifyCasts :: + SimplifyCaseConversionExpressions :: + extendedOperatorOptimizationRules.toList : _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + // We need to preserve the nullability of c's output. + // So, we first create a outputMap and if a reference is from the output of + // c, we use that output attribute from c. + val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) + val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq + Project(projectList, c) } else { c } + } } /** @@ -517,6 +530,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => + // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and + // Not(AtLeastNNulls(1, e2)) + // (this is used to make sure there is no null in the result of e1 and e2 and + // they are added by FilterNullsInJoinKey optimziation rule), we can + // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). 
+ Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index aacfc86ab0e49df4d9d76dc725248cac7595edc9..54b5f497726643567ec4484a7c1988537b8b935f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,7 +86,37 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + /** + * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children + * have at least one null value and atLeastNNulls.children are all attributes. + */ + private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { + val expressions = atLeastNNulls.children + val n = atLeastNNulls.n + if (n != 1) { + // AtLeastNNulls is not used to check if atLeastNNulls.children have + // at least one null value. + false + } else { + // AtLeastNNulls is used to check if atLeastNNulls.children have + // at least one null value. We need to make sure all atLeastNNulls.children + // are attributes. + expressions.forall(_.isInstanceOf[Attribute]) + } + } + + override def output: Seq[Attribute] = condition match { + case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => + // The condition is used to make sure that there is no null value in + // a.children. + val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) + child.output.map { + case attr if nonNullableAttributes.contains(attr) => + attr.withNullability(false) + case attr => attr + } + case _ => child.output + } } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a41185b4d8754b11161b119283176a3d267eea41..3e5515129874129414421d68784de3947c47a79b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => + protected val defaultOptimizer = new DefaultOptimizer + protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -186,7 +188,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 
9fcb548af6bbbd2c32c2082a6510b26f277783be..649a5b44dc0368def5bebe23058f6e766fe52267 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -149,7 +148,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = defaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ace6c15dc8418384dd7507ce703696378adeec5e..bf197124d8dbc77a4bf372cc6401e87a15423224 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNulls") { + test("AtLeastNNonNullNans") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,11 +96,46 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) + } + + test("AtLeastNNull") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.USER_DEFAULT), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNulls(0, mix), true, 
EmptyRow) + checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330b36b5fae21165e0736b6ee34d13d..ea85f0657a726b76a2c9be20587554b355e5f52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6644e85d4a037d488f12a565d0c830bc0a314a35..387960c4b482bb86dd4eb7009044f150b2a322a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -413,6 +413,10 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) + val ADVANCED_SQL_OPTIMIZATION = booleanConf( + "spark.sql.advancedOptimization", + defaultValue = Some(true), isPublic = false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -484,6 +488,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) + private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846548e505fe8a9a7571aad86ac86ee0c..31e2b508d485e03ce138e10a728ba9cf7c20514c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -156,7 +157,9 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + protected[sql] lazy val 
optimizer: Optimizer = new DefaultOptimizer { + override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil + } @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala new file mode 100644 index 0000000000000000000000000000000000000000..5a4dde575696415a058d1671705c47fff1627bb4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * An optimization rule used to insert Filters to filter out rows whose equal join keys + * have at least one null value. Such rows will not contribute to + * the results of equal joins because a null does not equal another null. We can + * filter them out before shuffling join input rows. For example, we have two tables + * + * table1(key String, value Int) + * "str1"|1 + * null |2 + * + * table2(key String, value Int) + * "str1"|3 + * null |4 + * + * For an inner equal join, the result will be + * "str1"|1|"str1"|3 + * + * The two rows having null as the key value will not contribute to the result. + * So, we can filter them out early. + * + * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. + * + */ +case class FilterNullsInJoinKey( + sqlContext: SQLContext) + extends Rule[LogicalPlan] { + + /** + * Checks if we need to add a Filter operator. We will add a Filter when + * `keys` contains an attribute whose corresponding attribute + * in `plan.output` is still nullable (its `nullable` field is `true`). + */ + private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { + val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) + plan.output.filter(keyAttributeSet.contains).exists(_.nullable) + } + + /** + * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. + */ + private def addFilterIfNecessary( + keys: Seq[Expression], + child: LogicalPlan): LogicalPlan = { + // We get all attributes from keys.
+ val attributes = keys.filter(_.isInstanceOf[Attribute]) + + // Then, we create a Filter to make sure these attributes are non-nullable. + val filter = + if (attributes.nonEmpty) { + Filter(Not(AtLeastNNulls(1, attributes)), child) + } else { + child + } + + filter + } + + /** + * Reconstructs the join condition. + */ + private def reconstructJoinCondition( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + otherPredicate: Option[Expression]): Expression = { + // First, we rewrite the equal condition part. When we extract those keys, + // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). + val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { + case (l, r) => EqualTo(l, r) + }.reduce(And) + + // Then, we add otherPredicate. When we extract the equal condition part, + // we use splitConjunctivePredicates. So, it is safe to use + // And(rewrittenEqualJoinCondition, c). + val rewrittenJoinCondition = otherPredicate + .map(c => And(rewrittenEqualJoinCondition, c)) + .getOrElse(rewrittenEqualJoinCondition) + + rewrittenJoinCondition + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!sqlContext.conf.advancedSqlOptimizations) { + plan + } else { + plan transform { + case join: Join => join match { + // For an inner join having an equal join condition part, we can add filters + // to both sides of the join operator. + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) + + // For a left outer join having an equal join condition part, we can add a filter + // to the right side of the join operator. + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(rightKeys, right) => + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) + + // For a right outer join having an equal join condition part, we can add a filter + // to the left side of the join operator. + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) + + // For a left semi join having an equal join condition part, we can add filters + // to both sides of the join operator.
+ case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => + val withLeftFilter = addFilterIfNecessary(leftKeys, left) + val withRightFilter = addFilterIfNecessary(rightKeys, right) + val rewrittenJoinCondition = + reconstructJoinCondition(leftKeys, rightKeys, condition) + + Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) + + case other => other + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f98e4acafbf2c5d8183eb92406649715b22ad7d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.test.TestSQLContext + +/** This is the test suite for the FilterNullsInJoinKey optimization rule. */ +class FilterNullsInJoinKeySuite extends PlanTest { + + // We add predicate pushdown rules here to make sure we do not + // create redundant Filter operators. Also, because the attribute ordering of + // the Project operator added by ColumnPruning may not be deterministic + // (the ordering may depend on the testing environment), + // we first construct the plan with expected Filter operators and then + // run the optimizer to add the Project for column pruning. + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Operator Optimizations", FixedPoint(100), + FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite.
+ CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning, + ProjectCollapsing) :: Nil + } + + val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) + + val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) + + test("inner join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For an inner join, FilterNullsInJoinKey adds filters to both sides. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("make sure we do not keep adding filters") { + val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some('a === 'e)) + .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) + + val optimized = Optimize.execute(joinedPlan.analyze) + val conditions = optimized.collect { + case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs + } + + // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. + assert(conditions.length === 3) + + // Make sure attributes are indeed a, b, e, i, and j. + assert( + conditions.flatMap(exprs => exprs).toSet === + joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) + } + + test("inner join (partially optimized)") { + val joinCondition = + ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // We cannot extract an attribute from the left join key. + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .join(correctRight, Inner, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("inner join (not optimized)") { + val nonOptimizedJoinConditions = + Some('c - 100 + 'd === 'g + 1 - 'h) :: + Some('d > 'h || 'c === 'g) :: + Some('d + 'g + 'c > 'd - 'h) :: Nil + + nonOptimizedJoinConditions.foreach { joinCondition => + val joinedPlan = + leftRelation + .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) + .select('a, 'c, 'f, 'd, 'h, 'g) + + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) + } + } + + test("left outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left outer join, FilterNullsInJoinKey adds a filter to the right side.
+ val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + leftRelation + .join(correctRight, LeftOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("right outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a right outer join, FilterNullsInJoinKey adds a filter to the left side. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(rightRelation, RightOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } + + test("full outer join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, FullOuter, Some(joinCondition)) + .select('a, 'f, 'd, 'h) + + // FilterNullsInJoinKey does not fire for a full outer join. + val optimized = Optimize.execute(joinedPlan.analyze) + + comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) + } + + test("left semi join") { + val joinCondition = + ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) + + val joinedPlan = + leftRelation + .join(rightRelation, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + val optimized = Optimize.execute(joinedPlan.analyze) + + // For a left semi join, FilterNullsInJoinKey adds filters to both sides. + val correctLeft = + leftRelation + .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) + + val correctRight = + rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) + + val correctAnswer = + correctLeft + .join(correctRight, LeftSemi, Some(joinCondition)) + .select('a, 'd) + + comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) + } +}
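
Editor's note, not part of the patch: the two predicates introduced above differ only in what they count. The standalone sketch below (plain Scala, not Spark code; the object and method names are ours) models the counting semantics of AtLeastNNulls and AtLeastNNonNullNans over a row of already-evaluated values and checks the same expectations as the `mix` case in NullFunctionsSuite; the real expressions additionally short-circuit once the threshold `n` is reached.

object NullCountingSemanticsSketch {
  // Models AtLeastNNulls(n, children): true if at least `n` of the evaluated values are null.
  def atLeastNNulls(n: Int, values: Seq[Any]): Boolean =
    values.count(_ == null) >= n

  // Models AtLeastNNonNullNans(n, children): true if at least `n` values are neither null nor NaN.
  def atLeastNNonNullNans(n: Int, values: Seq[Any]): Boolean =
    values.count {
      case null => false
      case d: Double => !d.isNaN
      case f: Float => !f.isNaN
      case _ => true
    } >= n

  def main(args: Array[String]): Unit = {
    // Mirrors the `mix` sequence in NullFunctionsSuite: two nulls, one NaN, two ordinary values.
    val mix: Seq[Any] = Seq("x", null, null, Double.NaN, 5f)
    assert(atLeastNNulls(2, mix) && !atLeastNNulls(3, mix))
    assert(atLeastNNonNullNans(2, mix) && !atLeastNNonNullNans(3, mix))
  }
}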
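For the optimizer rule itself, the following is a hedged sketch (not part of the patch) of the plan rewrite FilterNullsInJoinKey performs on an inner equi-join, written with the Catalyst test DSL that FilterNullsInJoinKeySuite uses. It assumes the spark-catalyst test classpath; the relation and key names are illustrative only.

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AtLeastNNulls, Not}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object FilterNullsInJoinKeyRewriteSketch {
  val left  = LocalRelation('a.int, 'b.int)
  val right = LocalRelation('e.int, 'f.int)

  // Before the rule: rows with a null join key still travel through the shuffle,
  // even though they can never satisfy 'a === 'e.
  val original = left.join(right, Inner, Some('a === 'e))

  // After the rule: each side is filtered with Not(AtLeastNNulls(1, keys)), so rows whose
  // join key is null are dropped below the join, and (via the Filter.output change above)
  // the key attributes become non-nullable.
  val rewritten =
    left.where(Not(AtLeastNNulls(1, 'a.expr :: Nil)))
      .join(right.where(Not(AtLeastNNulls(1, 'e.expr :: Nil))), Inner, Some('a === 'e))
}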