Skip to content
Snippets Groups Projects
Commit a0e46a0d authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[SPARK-7952][SPARK-7984][SQL] equality check between boolean type and numeric type is broken.

The origin code has several problems:
* `true <=> 1` will return false as we didn't set a rule to handle it.
* `true = a` where `a` is not `Literal` and its value is 1, will return false as we only handle literal values.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6505 from cloud-fan/tmp1 and squashes the following commits:

77f0f39 [Wenchen Fan] minor fix
b6401ba [Wenchen Fan] add type coercion for CaseKeyWhen and address comments
ebc8c61 [Wenchen Fan] use SQLTestUtils and If
625973c [Wenchen Fan] improve
9ba2130 [Wenchen Fan] address comments
fc0d741 [Wenchen Fan] fix style
2846a04 [Wenchen Fan] fix 7952
parent 91777a1c
No related branches found
No related tags found
No related merge requests found
......@@ -76,7 +76,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanComparisons ::
BooleanEqualization ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
......@@ -119,7 +119,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
val stringNaN = Literal("NaN")
private val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
......@@ -349,17 +349,17 @@ trait HiveTypeCoercion {
import scala.math.{max, min}
// Conversion rules for integer types into fixed-precision decimals
val intTypeToFixed: Map[DataType, DecimalType] = Map(
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
ByteType -> DecimalType(3, 0),
ShortType -> DecimalType(5, 0),
IntegerType -> DecimalType(10, 0),
LongType -> DecimalType(20, 0)
)
def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
// Conversion rules for float and double into fixed-precision decimals
val floatTypeToFixed: Map[DataType, DecimalType] = Map(
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
......@@ -482,30 +482,66 @@ trait HiveTypeCoercion {
}
/**
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
object BooleanComparisons extends Rule[LogicalPlan] {
val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
object BooleanEqualization extends Rule[LogicalPlan] {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
Literal(trueValues.head), booleanExpr,
Literal(falseValues.head), Not(booleanExpr),
Literal(false)))
}
private def transform(booleanExpr: Expression, numericExpr: Expression) = {
If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
Literal.create(null, BooleanType),
buildCaseKeyWhen(booleanExpr, numericExpr))
}
private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
CaseWhen(Seq(
And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
buildCaseKeyWhen(booleanExpr, numericExpr)
))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// Hive treats (true = 1) as true and (false = 0) as true.
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
// No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison if p.left.dataType == BooleanType &&
p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
// Hive treats (true = 1) as true and (false = 0) as true,
// all other cases are considered as false.
// We may simplify the expression if one side is literal numeric values
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => l
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(l)
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => r
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if falseValues.contains(value) => Not(r)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(l), l)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(l), Not(l))
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(r), r)
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(r), Not(r))
case EqualTo(l @ BooleanType(), r @ NumericType()) =>
transform(l , r)
case EqualTo(l @ NumericType(), r @ BooleanType()) =>
transform(r, l)
case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
transformNullSafe(l, r)
case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
transformNullSafe(r, l)
}
}
......@@ -606,7 +642,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
......@@ -625,6 +661,23 @@ trait HiveTypeCoercion {
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = ckw.branches.sliding(2, 2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case s => s
}.reduce(_ ++ _)
val transformedKey = if (ckw.key.dataType != commonType) {
Cast(ckw.key, commonType)
} else {
ckw.key
}
CaseKeyWhen(transformedKey, transformedBranches)
}
}
......
......@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
// both then and else val should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
override def dataType: DataType = {
if (!resolved) {
......@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
override def children: Seq[Expression] = key +: branches
override lazy val resolved: Boolean =
childrenResolved && valueTypesEqual
childrenResolved && valueTypesEqual &&
(key +: whenList).map(_.dataType).distinct.size == 1
/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
......
......@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
class HiveTypeCoercionSuite extends PlanTest {
......@@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
comparePlans(
rule(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
test("coalesce casts") {
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
comparePlans(
fac(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
ruleTest(
ruleTest(fac,
Coalesce(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
......@@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
ruleTest(
ruleTest(fac,
Coalesce(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
......@@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
:: Nil))
}
test("type coercion for CaseKeyWhen") {
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
ruleTest(cwc,
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
// Will remove exception expectation in PR#6405
intercept[RuntimeException] {
ruleTest(cwc,
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
}
}
test("type coercion simplification for equal to") {
val be = new HiveTypeCoercion {}.BooleanEqualization
ruleTest(be,
EqualTo(Literal(true), Literal(1)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(true), Literal(0)),
Not(Literal(true))
)
ruleTest(be,
EqualNullSafe(Literal(true), Literal(1)),
And(IsNotNull(Literal(true)), Literal(true))
)
ruleTest(be,
EqualNullSafe(Literal(true), Literal(0)),
And(IsNotNull(Literal(true)), Not(Literal(true)))
)
}
}
......@@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)
val literalNull = Literal.create(null, BooleanType)
val literalNull = Literal.create(null, IntegerType)
val literalInt = Literal(1)
val literalString = Literal("a")
......@@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}
test("complex type") {
......
......@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.types._
......@@ -32,12 +32,12 @@ import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
val sqlCtx = TestSQLContext
val sqlContext = TestSQLContext
import sqlContext.implicits._
test("SPARK-6743: no columns from cache") {
Seq(
......@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
......@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
val df2 = createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
......@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
val df3 = createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
......@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
......@@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}
test("SPARK-7952: fix the equality check between boolean and numeric types") {
withTempTable("t") {
// numeric field i, boolean field j, result of i = j, result of i <=> j
Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
(1, true, true, true),
(0, false, true, true),
(2, true, false, false),
(2, false, false, false),
(null, true, null, false),
(null, false, null, false),
(0, null, null, false),
(1, null, null, false),
(null, null, null, true)
).toDF("i", "b", "r1", "r2").registerTempTable("t")
checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment