diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 900def59d23a50d9d04157ef5455baa16dfa3880..320451c52c70650501c284c3b78e7df790500351 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -368,12 +368,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
-        +-----+--------------------------------------------------------+
-        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
-        +-----+--------------------------------------------------------+
-        |Alice|                                                      -1|
-        |  Bob|                                                       1|
-        +-----+--------------------------------------------------------+
+        +-----+------------------------------------------------------------+
+        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+        +-----+------------------------------------------------------------+
+        |Alice|                                                          -1|
+        |  Bob|                                                           1|
+        +-----+------------------------------------------------------------+
         """
         if not isinstance(condition, Column):
             raise TypeError("condition should be a Column")
@@ -393,12 +393,12 @@ class Column(object):
 
         >>> from pyspark.sql import functions as F
         >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
-        +-----+---------------------------------+
-        | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
-        +-----+---------------------------------+
-        |Alice|                                0|
-        |  Bob|                                1|
-        +-----+---------------------------------+
+        +-----+-------------------------------------+
+        | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+        +-----+-------------------------------------+
+        |Alice|                                    0|
+        |  Bob|                                    1|
+        +-----+-------------------------------------+
         """
         v = value._jc if isinstance(value, Column) else value
         jc = self._jc.otherwise(v)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index c87b6c8e9543605149a4ca7e25b4c3d5a3199924..d0fbdacf6eafdc8a0320f58920c3141f657612d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* Case statements */
     case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
-      CaseWhen(branches.map(nodeToExpr))
+      CaseWhen.createFromParser(branches.map(nodeToExpr))
     case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
       val keyExpr = nodeToExpr(branches.head)
      CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 6ec408a673c796883f2410ab5ff14add52acb65f..85ff4ea0c946b696b888d0bb90b4bd8a3257407b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
           throw new AnalysisException(s"invalid function approximate($s) $udfName")
         }
       }
-    | CASE ~> whenThenElse ^^ CaseWhen
+    | CASE ~> whenThenElse ^^
+      { case branches => CaseWhen.createFromParser(branches) }
     | CASE ~> expression ~ whenThenElse ^^
       { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
     )
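Reviewer note: both parser changes route through the new `CaseWhen.createFromParser` factory (defined in conditionalExpressions.scala below). A minimal sketch of its contract, assuming a Spark 1.6-era Catalyst build; the literal values are illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, Literal}

// "CASE WHEN a THEN b ELSE e END" reaches the parser as the flat list
// Seq(a, b, e); createFromParser pairs up (condition, value) and treats a
// trailing odd expression as the else value.
val a: Expression = Literal(true)
val b: Expression = Literal(1)
val e: Expression = Literal(0)

assert(CaseWhen.createFromParser(Seq(a, b, e)) == CaseWhen(Seq((a, b)), Some(e)))
assert(CaseWhen.createFromParser(Seq(a, b)) == CaseWhen(Seq((a, b)), None))
```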
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 980b5d52fa8f77de53ba75dea2c83368a1ffc0d2..2737fe32cd086ad5ef458c2cc91d556e7bd6f5eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -621,14 +621,24 @@ object HiveTypeCoercion {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
-          val castedBranches = c.branches.grouped(2).map {
-            case Seq(when, value) if value.dataType != commonType =>
-              Seq(when, Cast(value, commonType))
-            case Seq(elseVal) if elseVal.dataType != commonType =>
-              Seq(Cast(elseVal, commonType))
-            case other => other
-          }.reduce(_ ++ _)
-          CaseWhen(castedBranches)
+          var changed = false
+          val newBranches = c.branches.map { case (condition, value) =>
+            if (value.dataType.sameType(commonType)) {
+              (condition, value)
+            } else {
+              changed = true
+              (condition, Cast(value, commonType))
+            }
+          }
+          val newElseValue = c.elseValue.map { value =>
+            if (value.dataType.sameType(commonType)) {
+              value
+            } else {
+              changed = true
+              Cast(value, commonType)
+            }
+          }
+          if (changed) CaseWhen(newBranches, newElseValue) else c
         }.getOrElse(c)
     }
   }
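The rewritten CaseWhenCoercion is easier to see on a concrete input (types borrowed from the HiveTypeCoercionSuite changes further down): only values whose type differs from the wider common type are wrapped in a Cast, and the expression is rebuilt only when something actually changed.

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Literal}
import org.apache.spark.sql.types.{DecimalType, DoubleType}

// Branch value is Double, else value is Decimal(7, 2); the wider common type
// is Double, so only the else value is cast.
val before = CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2)))
val after = CaseWhen(Seq((Literal(true), Literal(1.2))),
  Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
// HiveTypeCoercion.CaseWhenCoercion rewrites `before` into `after`.
```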
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 5a1462433d583a46d71c17801f8d2f8d05ba3b51..8cc7bc1da2fc3735cd7b11f3279a72deda5876f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 /**
  * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
  * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
  */
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
-
-  // Use private[this] Array to speed up evaluation.
-  @transient private[this] lazy val branchesArr = branches.toArray
-
-  override def children: Seq[Expression] = branches
-
-  @transient lazy val whenList =
-    branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
-
-  @transient lazy val thenList =
-    branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
+  extends Expression {
 
-  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+  override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
   // both then and else expressions should be considered.
-  def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
+
   def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
     case Seq(dt1, dt2) => dt1.sameType(dt2)
   }
 
-  override def dataType: DataType = thenList.head.dataType
+  override def dataType: DataType = branches.head._2.dataType
 
   override def nullable: Boolean = {
-    // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
-    thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
+    // Result is nullable if any of the branch values is nullable, or if the else value is nullable
+    branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
+    // Make sure all branch conditions are boolean types.
     if (valueTypesEqual) {
-      if (whenList.forall(_.dataType == BooleanType)) {
+      if (branches.forall(_._1.dataType == BooleanType)) {
         TypeCheckResult.TypeCheckSuccess
       } else {
-        val index = whenList.indexWhere(_.dataType != BooleanType)
+        val index = branches.indexWhere(_._1.dataType != BooleanType)
         TypeCheckResult.TypeCheckFailure(
           s"WHEN expressions in CaseWhen should all be boolean type, " +
-            s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+            s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
       }
     } else {
       TypeCheckResult.TypeCheckFailure(
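To make the new shape concrete, a small sketch of constructing the refactored expression and reading the fields defined above (assumes a Spark 1.6-era Catalyst build; the literals are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}
import org.apache.spark.sql.types.IntegerType

// CASE WHEN false THEN 1 WHEN true THEN 2 ELSE 3 END
val cw = CaseWhen(Seq((Literal(false), Literal(1)), (Literal(true), Literal(2))), Literal(3))

assert(cw.dataType == IntegerType)  // type of the first branch value
assert(cw.valueTypesEqual)          // 1, 2 and 3 share a type
assert(!cw.nullable)                // non-null values plus an else branch
```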
@@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
   }
 
   override def eval(input: InternalRow): Any = {
-    // Written in imperative fashion for performance considerations
-    val len = branchesArr.length
     var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (branchesArr(i).eval(input) == true) {
-        return branchesArr(i + 1).eval(input)
+    while (i < branches.size) {
+      if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
+        return branches(i)._2.eval(input)
       }
-      i += 2
+      i += 1
     }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
+    if (elseValue.isDefined) {
+      return elseValue.get.eval(input)
+    } else {
+      return null
     }
-    return res
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val len = branchesArr.length
     val got = ctx.freshName("got")
 
-    val cases = (0 until len/2).map { i =>
-      val cond = branchesArr(i * 2).gen(ctx)
-      val res = branchesArr(i * 2 + 1).gen(ctx)
+    val cases = branches.map { case (condition, value) =>
+      val cond = condition.gen(ctx)
+      val res = value.gen(ctx)
       s"""
         if (!$got) {
           ${cond.code}
@@ -165,17 +155,19 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       """
     }.mkString("\n")
 
-    val other = if (len % 2 == 1) {
-      val res = branchesArr(len - 1).gen(ctx)
-      s"""
+    val elseCase = {
+      if (elseValue.isDefined) {
+        val res = elseValue.get.gen(ctx)
+        s"""
         if (!$got) {
           ${res.code}
           ${ev.isNull} = ${res.isNull};
           ${ev.value} = ${res.value};
         }
-      """
-    } else {
-      ""
+        """
+      } else {
+        ""
+      }
     }
 
     s"""
@@ -183,32 +175,42 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
       boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
       $cases
-      $other
+      $elseCase
     """
   }
 
   override def toString: String = {
-    "CASE" + branches.sliding(2, 2).map {
-      case Seq(cond, value) => s" WHEN $cond THEN $value"
-      case Seq(elseValue) => s" ELSE $elseValue"
-    }.mkString
+    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+    "CASE" + cases + elseCase + " END"
   }
 
   override def sql: String = {
-    val branchesSQL = branches.map(_.sql)
-    val (cases, maybeElse) = if (branches.length % 2 == 0) {
-      (branchesSQL, None)
-    } else {
-      (branchesSQL.init, Some(branchesSQL.last))
-    }
+    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+}
 
-    val head = s"CASE "
-    val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
-    val body = cases.grouped(2).map {
-      case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
-    }.mkString(" ")
+/** Factory methods for CaseWhen. */
+object CaseWhen {
 
-    head + body + tail
+  def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
+    CaseWhen(branches, Option(elseValue))
+  }
+
+  /**
+   * A factory method to facilitate the creation of this expression when used in parsers.
+   * @param branches Expressions at even position are the branch conditions, and expressions at odd
+   *                 position are branch values.
+   */
+  def createFromParser(branches: Seq[Expression]): CaseWhen = {
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((cond, value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
 
@@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
  */
 object CaseKeyWhen {
   def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
-    val newBranches = branches.zipWithIndex.map { case (expr, i) =>
-      if (i % 2 == 0 && i != branches.size - 1) {
-        // If this expression is at even position, then it is either a branch condition, or
-        // the very last value that is the "else value". The "i != branches.size - 1" makes
-        // sure we are not adding an EqualTo to the "else value".
-        EqualTo(key, expr)
-      } else {
-        expr
-      }
-    }
-    CaseWhen(newBranches)
+    val cases = branches.grouped(2).flatMap {
+      case cond :: value :: Nil => Some((EqualTo(key, cond), value))
+      case value :: Nil => None
+    }.toArray.toSeq  // force materialization to make the seq serializable
+    val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+    CaseWhen(cases, elseValue)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d4be545a35ab2e650519a99bfc6529f17f4e85f9..d0b29aa01f640f3477347cda80d256c76f35aff2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         } else {
           arg
         }
+      case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
+        val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
+        val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+        if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+          changed = true
+          (newChild1, newChild2)
+        } else {
+          tuple
+        }
       case other => other
     }
     case nonChild: AnyRef => nonChild
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index cf84855885a37082f171cd5c0c7aeaf11638d03e..975cd87d090e4c8dac38b46bf67ef231ac6eaaf4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest {
 
   test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
     val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
-    val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
+    val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
     assertAnalysisSuccess(plan)
   }
 }
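The TreeNode addition above is what keeps tree transformations working after the refactoring: CaseWhen's branches are now a Seq of (Expression, Expression) tuples, and transformChildren walks constructor arguments via productIterator, so tuples of TreeNodes must be descended into. A hedged sketch of the behavior it enables:

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, Literal}

val cw: Expression = CaseWhen(Seq((Literal(true), Literal(1))), Literal(0))

// Without the new tuple case, transform would silently skip expressions
// nested inside the (condition, value) pairs.
val rewritten = cw transform {
  case Literal(1, dt) => Literal.create(42, dt)
}
assert(rewritten == CaseWhen(Seq((Literal(true), Literal(42))), Literal(0)))
```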
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 0521ed848c793fe87bc6c688a7943b19e3716aa2..59549e3998e7eb1327fdc3a0b94d3fe570bfc727 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
 
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
       CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
+      CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))),
       "WHEN expressions in CaseWhen should all be boolean type")
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 40378c6727667c2418e29fb677c2f6a6e791e565..b1f6c0b802d8e958993be75c371c9f6bcd7286f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest {
       CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)))
+      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(1.2))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
     )
     ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))),
-      CaseWhen(Seq(
-        Literal(true), Cast(Literal(100L), DecimalType(22, 2)),
-        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))))
+      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
     )
   }
 
@@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest {
     val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
       DecimalType(25, 5), DoubleType, DoubleType)
 
-    rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+    rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
       val plan2 = LocalRelation(
         AttributeReference("r", rType)())
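As a worked note on the DecimalType(22, 2) expected in the coercion test above: it follows the usual union rule for decimal widening. LongType widens as DecimalType(20, 0), and the wider type of Decimal(p1, s1) and Decimal(p2, s2) is Decimal(max(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)); here that is max(20 - 0, 7 - 2) + 2 = 22 with scale 2.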
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 4029da5925580f5b3bf14801a530e69d829fdfab..3c581ecdaf068f559552fd8e543e076274acd8cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -80,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     val c5 = 'a.string.at(4)
     val c6 = 'a.string.at(5)
 
-    checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row)
-
-    checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
-    checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
-    checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
-
-    assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+    checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row)
+
+    checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row)
+    checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row)
+    checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row)
+
+    assert(CaseWhen(Seq((c2, c4)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true)
 
     val c4_notNull = 'a.boolean.notNull.at(3)
     val c5_notNull = 'a.boolean.notNull.at(4)
     val c6_notNull = 'a.boolean.notNull.at(5)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true)
 
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
-    assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true)
+    assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
   }
 
   test("case key when") {
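All of the nullability assertions in this suite reduce to the single rule in the rewritten expression; restated in isolation (a paraphrase of the nullable override above, not new logic):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Nullable iff any branch value is nullable, the else value is nullable,
// or the else value is absent (fall-through produces null).
def caseWhenNullable(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression]): Boolean = {
  branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
```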
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e8c61d6e01dc31e977bda09930de1b9a7e539144..6a020f9f2883e5c174ccad06338880e8fba01622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "when() cannot be applied once otherwise() is applied")
     case _ =>
       throw new IllegalArgumentException(
         "when() can only be applied on a Column previously generated by when() function")
@@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.4.0
    */
   def otherwise(value: Any): Column = this.expr match {
-    case CaseWhen(branches: Seq[Expression]) =>
-      if (branches.size % 2 == 0) {
-        withExpr { CaseWhen(branches :+ lit(value).expr) }
-      } else {
-        throw new IllegalArgumentException(
-          "otherwise() can only be applied once on a Column previously generated by when()")
-      }
+    case CaseWhen(branches, None) =>
+      withExpr { CaseWhen(branches, Option(lit(value).expr)) }
+    case CaseWhen(branches, Some(_)) =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied once on a Column previously generated by when()")
     case _ =>
       throw new IllegalArgumentException(
         "otherwise() can only be applied on a Column previously generated by when()")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 71fea2716bd9f9ccc2a1f5840cc6f79c473658ac..b8ea2261e94e263644bedd819be37fd1930eae76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions {
    * @since 1.4.0
    */
   def when(condition: Column, value: Any): Column = withExpr {
-    CaseWhen(Seq(condition.expr, lit(value).expr))
+    CaseWhen(Seq((condition.expr, lit(value).expr)))
   }
 
   /**
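Finally, a usage sketch of the public Scala API after this change (df and its age/name columns are illustrative): chaining is unchanged, but misuse now fails fast on the branches/elseValue structure instead of on a malformed flat list.

```scala
import org.apache.spark.sql.functions._

// Builds CaseWhen(Seq((age > 4, 1), (age < 3, -1)), Some(0)) underneath.
val graded = when(df("age") > 4, 1).when(df("age") < 3, -1).otherwise(0)
df.select(df("name"), graded).show()

// Both of these now throw IllegalArgumentException:
//   when(df("age") > 4, 1).otherwise(0).when(df("age") < 3, -1)
//   when(df("age") > 4, 1).otherwise(0).otherwise(-1)
```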