Skip to content
Snippets Groups Projects
Commit f7050376 authored by Dongjoon Hyun, committed by Reynold Xin
Browse files

[SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN

## What changes were proposed in this pull request?

Currently, `SimplifyConditionals` only optimizes branches whose condition is the literal `true` or `false`. This PR extends `SimplifyConditionals` to also handle `null` conditions in `if` and `CaseWhen` expressions: a `null` condition never selects its branch, so it is treated the same way as `false`.

**Before**
```
scala> sql("SELECT IF(null, 1, 0)").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4]
:     +- INPUT
+- Scan OneRowRelation[]
scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14]
:     +- INPUT
+- Scan OneRowRelation[]
```

**After**
```
scala> sql("SELECT IF(null, 1, 0)").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4]
:     +- INPUT
+- Scan OneRowRelation[]
scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4]
:     +- INPUT
+- Scan OneRowRelation[]
```

**Hive**
```
hive> select if(null,1,2);
OK
2
hive> select case when cast(null as boolean) then 1 else 2 end;
OK
2
```

## How was this patch tested?

Pass the Jenkins tests (including new extended test cases).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12122 from dongjoon-hyun/SPARK-14338.
parent a3e29354
No related branches found
No related tags found
No related merge requests found
......@@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
def nonNullLiteral(e: Expression): Boolean = e match {
private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
......@@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
private def falseOrNullLiteral(e: Expression): Boolean = e match {
case FalseLiteral => true
case Literal(null, _) => true
case _ => false
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
val newBranches = branches.filter(_._1 != FalseLiteral)
val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
......
......@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{IntegerType, NullType}
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
......@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal(null, NullType), Literal(30))
test("simplify if") {
assertEquivalent(
......@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
assertEquivalent(
If(FalseLiteral, Literal(10), Literal(20)),
Literal(20))
assertEquivalent(
If(Literal(null, NullType), Literal(10), Literal(20)),
Literal(20))
}
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}
test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
Literal(30))
assertEquivalent(
......@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
CaseWhen(trueBranch :: normalBranch :: Nil, None),
CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
Literal(5))
// Test branch elimination and simplification in combination
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
:: Nil, None),
Literal(5))
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment