Skip to content
Snippets Groups Projects
Commit cba69aeb authored by Andrew Ray, committed by gatorsmile
Browse files

[SPARK-21110][SQL] Structs, arrays, and other orderable datatypes should be usable in inequalities

## What changes were proposed in this pull request?

Allows `BinaryComparison` operators to work on any data type that actually supports ordering as verified by `TypeUtils.checkForOrderingExpr` instead of relying on the incomplete list `TypeCollection.Ordered` (which is removed by this PR).

## How was this patch tested?

Updated unit tests to cover structs and arrays.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #18818 from aray/SPARK-21110.
parent 7ce11082
No related branches found
No related tags found
No related merge requests found
......@@ -594,6 +594,7 @@ class CodegenContext {
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case NullType => "false"
case _ =>
throw new IllegalArgumentException(
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
......
......@@ -448,6 +448,16 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
abstract class BinaryComparison extends BinaryOperator with Predicate {
// Note that we need to give a superset of allowable input types since orderable types are not
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
override def inputType: AbstractDataType = AnyDataType
// Two-stage type check: the inherited check runs first, and any failure it reports is
// returned unchanged. Only on success is the left operand's data type passed to
// TypeUtils.checkForOrderingExpr, labelled with this operator's simple class name.
// NOTE(review): only left.dataType is inspected here; presumably the inherited check has
// already ensured both operands share a type -- confirm.
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
case failure => failure
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
......@@ -460,7 +470,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
}
protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType)
}
......@@ -478,28 +488,13 @@ object Equality {
}
}
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.")
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = AnyDataType
override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison, remove this type check once we support it.
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
s"input type is ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
case failure => failure
}
}
override def symbol: String = "="
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
......@@ -509,6 +504,8 @@ case class EqualTo(left: Expression, right: Expression)
}
}
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison
@ExpressionDescription(
usage = """
expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands,
......@@ -516,23 +513,6 @@ case class EqualTo(left: Expression, right: Expression)
""")
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
override def inputType: AbstractDataType = AnyDataType
override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison, remove this type check once we support it.
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
s"input type is ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
case failure => failure
}
}
override def symbol: String = "<=>"
override def nullable: Boolean = false
......@@ -564,8 +544,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
override def symbol: String = "<"
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
......@@ -576,8 +554,6 @@ case class LessThan(left: Expression, right: Expression)
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
override def symbol: String = "<="
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
......@@ -588,8 +564,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
override def symbol: String = ">"
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
......@@ -600,8 +574,6 @@ case class GreaterThan(left: Expression, right: Expression)
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
override def symbol: String = ">="
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
......
......@@ -65,6 +65,7 @@ object TypeUtils {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
}
}
......
......@@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
private[sql] object TypeCollection {
/**
* Types that can be ordered/compared. In the long run we should probably make this a trait
* that can be mixed into each data type, and perhaps create an `AbstractDataType`.
*/
// TODO: Should we consolidate this with RowOrdering.isOrderable?
val Ordered = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType,
TimestampType, DateType,
StringType, BinaryType)
/**
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
......
......@@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
right,
joinType = Cross,
condition = Some('b === 'd))
assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil)
}
test("PredicateSubQuery is used outside of a filter") {
......
......@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
import org.apache.spark.sql.types._
class ExpressionTypeCheckingSuite extends SparkFunSuite {
......@@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType")
assertError(EqualNullSafe('mapField, 'mapField),
"EqualNullSafe does not support ordering on type MapType")
assertError(LessThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"LessThan does not support ordering on type MapType")
assertError(LessThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"LessThanOrEqual does not support ordering on type MapType")
assertError(GreaterThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"GreaterThan does not support ordering on type MapType")
assertError(GreaterThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"GreaterThanOrEqual does not support ordering on type MapType")
assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
......
......@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
......@@ -215,14 +218,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
private case class MyStruct(a: Long, b: String)
private case class MyStruct2(a: MyStruct, b: Array[Int])
private val udt = new ExamplePointUDT
private val smallValues =
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val largeValues =
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
private val equalValues1 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val equalValues2 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
test("BinaryComparison consistency check") {
DataTypeTestUtils.ordered.foreach { dt =>
......@@ -285,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral.create(null, IntegerType)
val nullNullType = Literal.create(null, NullType)
// Asserts that the binary expression built by `op` evaluates to null for each null-operand
// combination: a null int on the right, a null int on the left, null ints on both sides,
// and NullType nulls on both sides.
def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
checkEvaluation(op(nullInt, normalInt), null)
checkEvaluation(op(nullInt, nullInt), null)
checkEvaluation(op(nullNullType, nullNullType), null)
}
nullTest(LessThan)
......@@ -301,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
}
test("EqualTo on complex type") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment