Skip to content
Snippets Groups Projects
Commit 78d18fdb authored by Michael Armbrust's avatar Michael Armbrust Committed by Reynold Xin
Browse files

[SPARK-2658][SQL] Add rule for true = 1.

Author: Michael Armbrust <michael@databricks.com>

Closes #1556 from marmbrus/fixBooleanEqualsOne and squashes the following commits:

ad8edd4 [Michael Armbrust] Add rule for true = 1 and false = 0.
parent 9e7725c8
No related branches found
No related tags found
No related merge requests found
......@@ -231,10 +231,20 @@ trait HiveTypeCoercion {
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
*/
object BooleanComparisons extends Rule[LogicalPlan] {
val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_))
val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_))
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// No need to change EqualTo operators as that actually makes sense for boolean types.
// Hive treats (true = 1) as true and (false = 0) as true.
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
// No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
......
true true true true true true false false false false false false false false false false false false true true true true true true false false false false false false false false false false false false
......@@ -30,6 +30,18 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {
createQueryTest("boolean = number",
"""
|SELECT
| 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y,
| 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y,
| 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y,
| 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y,
| 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y,
| 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y
|FROM src LIMIT 1
""".stripMargin)
test("CREATE TABLE AS runs once") {
hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect()
assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment