Skip to content
Snippets Groups Projects
Commit 32495289 authored by Takuya UESHIN's avatar Takuya UESHIN Committed by Reynold Xin
Browse files

[SPARK-2196] [SQL] Fix nullability of CaseWhen.

`CaseWhen` should use `branches.length` to check if `elseValue` is provided or not.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #1133 from ueshin/issues/SPARK-2196 and squashes the following commits:

510f12d [Takuya UESHIN] Add some tests.
dc25e8d [Takuya UESHIN] Fix nullable of CaseWhen to be nullable if the elseValue is nullable.
4f049cc [Takuya UESHIN] Fix nullability of CaseWhen.
parent f46e02fc
No related branches found
No related tags found
No related merge requests found
......@@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
@transient private[this] lazy val values =
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
@transient private[this] lazy val elseValue =
if (branches.length % 2 == 0) None else Option(branches.last)
override def nullable = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
values.exists(_.nullable) || (values.length % 2 == 0)
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
override lazy val resolved = {
......
......@@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite {
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}
test("case when") {
val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
val c1 = 'a.boolean.at(0)
val c2 = 'a.boolean.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.string.at(3)
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)
checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row)
checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row)
checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row)
checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
val c4_notNull = 'a.boolean.notNull.at(3)
val c5_notNull = 'a.boolean.notNull.at(4)
val c6_notNull = 'a.boolean.notNull.at(5)
assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
}
test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment