Commit 6a4bfcd6 authored by Liang-Chi Hsieh, committed by Michael Armbrust

[SPARK-13658][SQL] BooleanSimplification rule is slow with large boolean expressions

JIRA: https://issues.apache.org/jira/browse/SPARK-13658

## What changes were proposed in this pull request?

Quoted from the JIRA description: When running TPCDS Q3 [1] with lots of predicates to filter out the partitions, the optimizer rule BooleanSimplification takes about 2 seconds (it makes many calls to `semanticEquals`, which requires copying the whole tree).

It would be great if we could speed it up.

[1] https://github.com/cloudera/impala-tpcds-kit/blob/master/queries/q3.sql

How to speed it up:

When we ask for the canonicalized expression in `Expression`, it calls `Canonicalize.execute` on itself. `Canonicalize.execute` basically transforms up all expressions contained in this expression. However, we don't keep the canonicalized versions of these child expressions. So the next time we ask for the canonicalized versions of the child expressions (e.g., in `BooleanSimplification`), we rerun `Canonicalize.execute` on each of them, which wastes a lot of time.

By forcing the child expressions to compute and keep their canonicalized versions first, we can avoid re-canonicalizing these expressions.

I benchmarked it with an expression that is part of the WHERE clause in TPCDS Q3:
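The caching idea described above can be illustrated with a minimal, self-contained sketch. It does not use Spark's classes: `Expr`, `Attr`, `Eq`, `And`, `rewriteNode`, and the `rewrites` counter are hypothetical names introduced only for illustration. Each node canonicalizes its children through their own cached `lazy val` and then rewrites only itself, so every node is rewritten at most once.

```scala
// Toy expression tree, NOT Spark's Expression class; every name here is illustrative.
object CanonicalizeSketch {
  // Counts per-node rewrites so the effect of caching is observable.
  var rewrites = 0

  sealed trait Expr {
    def children: Seq[Expr]
    def withNewChildren(cs: Seq[Expr]): Expr

    // Children are canonicalized (and cached) first; only this node is rewritten here.
    lazy val canonicalized: Expr = {
      val canonicalChildren = children.map(_.canonicalized)
      rewriteNode(withNewChildren(canonicalChildren))
    }
  }

  case class Attr(name: String, id: Int) extends Expr {
    def children: Seq[Expr] = Nil
    def withNewChildren(cs: Seq[Expr]): Expr = this
  }

  case class Eq(l: Expr, r: Expr) extends Expr {
    def children: Seq[Expr] = Seq(l, r)
    def withNewChildren(cs: Seq[Expr]): Expr = Eq(cs(0), cs(1))
  }

  case class And(l: Expr, r: Expr) extends Expr {
    def children: Seq[Expr] = Seq(l, r)
    def withNewChildren(cs: Seq[Expr]): Expr = And(cs(0), cs(1))
  }

  // Non-recursive rewrite of a single node: strip names, order Eq operands by hashCode.
  def rewriteNode(e: Expr): Expr = {
    rewrites += 1
    e match {
      case Attr(_, id) => Attr("none", id)
      case Eq(l, r) if l.hashCode() > r.hashCode() => Eq(r, l)
      case other => other
    }
  }
}
```

In this sketch, asking for `canonicalized` on a parent and later on one of its children reuses the child's cached value instead of walking the subtree again, which is the behaviour the patch aims for.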

    val testRelation = LocalRelation('ss_sold_date_sk.int, 'd_moy.int, 'i_manufact_id.int, 'ss_item_sk.string, 'i_item_sk.string, 'd_date_sk.int)

    val input = ('d_date_sk === 'ss_sold_date_sk) && ('ss_item_sk === 'i_item_sk) && ('i_manufact_id === 436) && ('d_moy === 12) && (('ss_sold_date_sk > 2415355 && 'ss_sold_date_sk < 2415385) || ('ss_sold_date_sk > 2415720 && 'ss_sold_date_sk < 2415750) || ('ss_sold_date_sk > 2416085 && 'ss_sold_date_sk < 2416115) || ('ss_sold_date_sk > 2416450 && 'ss_sold_date_sk < 2416480) || ('ss_sold_date_sk > 2416816 && 'ss_sold_date_sk < 2416846) || ('ss_sold_date_sk > 2417181 && 'ss_sold_date_sk < 2417211) || ('ss_sold_date_sk > 2417546 && 'ss_sold_date_sk < 2417576) || ('ss_sold_date_sk > 2417911 && 'ss_sold_date_sk < 2417941) || ('ss_sold_date_sk > 2418277 && 'ss_sold_date_sk < 2418307) || ('ss_sold_date_sk > 2418642 && 'ss_sold_date_sk < 2418672) || ('ss_sold_date_sk > 2419007 && 'ss_sold_date_sk < 2419037) || ('ss_sold_date_sk > 2419372 && 'ss_sold_date_sk < 2419402) || ('ss_sold_date_sk > 2419738 && 'ss_sold_date_sk < 2419768) || ('ss_sold_date_sk > 2420103 && 'ss_sold_date_sk < 2420133) || ('ss_sold_date_sk > 2420468 && 'ss_sold_date_sk < 2420498) || ('ss_sold_date_sk > 2420833 && 'ss_sold_date_sk < 2420863) || ('ss_sold_date_sk > 2421199 && 'ss_sold_date_sk < 2421229) || ('ss_sold_date_sk > 2421564 && 'ss_sold_date_sk < 2421594) || ('ss_sold_date_sk > 2421929 && 'ss_sold_date_sk < 2421959) || ('ss_sold_date_sk > 2422294 && 'ss_sold_date_sk < 2422324) || ('ss_sold_date_sk > 2422660 && 'ss_sold_date_sk < 2422690) || ('ss_sold_date_sk > 2423025 && 'ss_sold_date_sk < 2423055) || ('ss_sold_date_sk > 2423390 && 'ss_sold_date_sk < 2423420) || ('ss_sold_date_sk > 2423755 && 'ss_sold_date_sk < 2423785) || ('ss_sold_date_sk > 2424121 && 'ss_sold_date_sk < 2424151) || ('ss_sold_date_sk > 2424486 && 'ss_sold_date_sk < 2424516) || ('ss_sold_date_sk > 2424851 && 'ss_sold_date_sk < 2424881) || ('ss_sold_date_sk > 2425216 && 'ss_sold_date_sk < 2425246) || ('ss_sold_date_sk > 2425582 && 'ss_sold_date_sk < 2425612) || ('ss_sold_date_sk > 2425947 && 'ss_sold_date_sk < 2425977) || 
('ss_sold_date_sk > 2426312 && 'ss_sold_date_sk < 2426342) || ('ss_sold_date_sk > 2426677 && 'ss_sold_date_sk < 2426707) || ('ss_sold_date_sk > 2427043 && 'ss_sold_date_sk < 2427073) || ('ss_sold_date_sk > 2427408 && 'ss_sold_date_sk < 2427438) || ('ss_sold_date_sk > 2427773 && 'ss_sold_date_sk < 2427803) || ('ss_sold_date_sk > 2428138 && 'ss_sold_date_sk < 2428168) || ('ss_sold_date_sk > 2428504 && 'ss_sold_date_sk < 2428534) || ('ss_sold_date_sk > 2428869 && 'ss_sold_date_sk < 2428899) || ('ss_sold_date_sk > 2429234 && 'ss_sold_date_sk < 2429264) || ('ss_sold_date_sk > 2429599 && 'ss_sold_date_sk < 2429629) || ('ss_sold_date_sk > 2429965 && 'ss_sold_date_sk < 2429995) || ('ss_sold_date_sk > 2430330 && 'ss_sold_date_sk < 2430360) || ('ss_sold_date_sk > 2430695 && 'ss_sold_date_sk < 2430725) || ('ss_sold_date_sk > 2431060 && 'ss_sold_date_sk < 2431090) || ('ss_sold_date_sk > 2431426 && 'ss_sold_date_sk < 2431456) || ('ss_sold_date_sk > 2431791 && 'ss_sold_date_sk < 2431821) || ('ss_sold_date_sk > 2432156 && 'ss_sold_date_sk < 2432186) || ('ss_sold_date_sk > 2432521 && 'ss_sold_date_sk < 2432551) || ('ss_sold_date_sk > 2432887 && 'ss_sold_date_sk < 2432917) || ('ss_sold_date_sk > 2433252 && 'ss_sold_date_sk < 2433282) || ('ss_sold_date_sk > 2433617 && 'ss_sold_date_sk < 2433647) || ('ss_sold_date_sk > 2433982 && 'ss_sold_date_sk < 2434012) || ('ss_sold_date_sk > 2434348 && 'ss_sold_date_sk < 2434378) || ('ss_sold_date_sk > 2434713 && 'ss_sold_date_sk < 2434743)))

    val plan = testRelation.where(input).analyze
    val actual = Optimize.execute(plan)

With this patch:

    352 milliseconds
    346 milliseconds
    340 milliseconds

Without this patch:

    585 milliseconds
    880 milliseconds
    677 milliseconds
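
The commit message does not say how these timings were collected. One plausible way, shown purely as an assumption, is a small wall-clock helper; `TimingSketch` and `timed` are hypothetical names and are not part of Spark.

```scala
// Hypothetical timing helper; the patch does not state how its numbers were measured.
object TimingSketch {
  // Evaluates `body` once and returns its result with the elapsed wall-clock milliseconds.
  def timed[A](body: => A): (A, Long) = {
    val start = System.nanoTime()
    val result = body
    val elapsedMs = (System.nanoTime() - start) / 1000000L
    (result, elapsedMs)
  }
}
```

With the benchmark above, this could be used as `val (actual, ms) = TimingSketch.timed(Optimize.execute(plan))`. A single run includes JIT warm-up, so repeating the measurement several times, as the three figures per case suggest, gives a more stable picture.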

## How was this patch tested?

Existing tests should pass.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11647 from viirya/improve-expr-canonicalize.
parent 63f642ae
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.catalyst.rules._
-
 /**
  * Rewrites an expression using rules that are guaranteed preserve the result while attempting
  * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -30,26 +28,23 @@ import org.apache.spark.sql.catalyst.rules._
  * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
  * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
  *   by `hashCode`.
  * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
  * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
  */
-object Canonicalize extends RuleExecutor[Expression] {
-  override protected def batches: Seq[Batch] =
-    Batch(
-      "Expression Canonicalization", FixedPoint(100),
-      IgnoreNamesTypes,
-      Reorder) :: Nil
+object Canonicalize extends {
+  def execute(e: Expression): Expression = {
+    expressionReorder(ignoreNamesTypes(e))
+  }

   /** Remove names and nullability from types. */
-  protected object IgnoreNamesTypes extends Rule[Expression] {
-    override def apply(e: Expression): Expression = e transformUp {
-      case a: AttributeReference =>
-        AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
-    }
+  private def ignoreNamesTypes(e: Expression): Expression = e match {
+    case a: AttributeReference =>
+      AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+    case _ => e
   }

   /** Collects adjacent commutative operations. */
-  protected def gatherCommutative(
+  private def gatherCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
     case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
@@ -57,25 +52,25 @@ object Canonicalize extends RuleExecutor[Expression] {
   }

   /** Orders a set of commutative operations by their hash code. */
-  protected def orderCommutative(
+  private def orderCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
     gatherCommutative(e, f).sortBy(_.hashCode())

   /** Rearrange expressions that are commutative or associative. */
-  protected object Reorder extends Rule[Expression] {
-    override def apply(e: Expression): Expression = e transformUp {
-      case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
-      case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+  private def expressionReorder(e: Expression): Expression = e match {
+    case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+    case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+    case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+    case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)

-      case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
-      case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+    case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+    case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)

-      case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
-      case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+    case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+    case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)

-      case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
-      case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
-    }
+    case _ => e
   }
 }
@@ -152,7 +152,10 @@ abstract class Expression extends TreeNode[Expression] {
    * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
    * evaluate to the same result.
    */
-  lazy val canonicalized: Expression = Canonicalize.execute(this)
+  lazy val canonicalized: Expression = {
+    val canonicalizedChildred = children.map(_.canonicalized)
+    Canonicalize.execute(withNewChildren(canonicalizedChildred))
+  }

   /**
    * Returns true when two expressions will always compute the same result, even if they differ
@@ -161,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] {
    * See [[Canonicalize]] for more details.
    */
   def semanticEquals(other: Expression): Boolean =
-    deterministic && other.deterministic && canonicalized == other.canonicalized
+    deterministic && other.deterministic && canonicalized == other.canonicalized

   /**
    * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard