diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d3071c533bef974f40dc6c7935b46cee04200453..efcd45fad779c614aadb5295dc81ea0269508ecf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -133,8 +133,25 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.") + usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr1` equals any of the listed expressions.", + arguments = """ + Arguments: + * expr1, expr2, expr3, ... - the arguments must be the same type. + """, + examples = """ + Examples: + > SELECT 1 _FUNC_(1, 2, 3); + true + > SELECT 1 _FUNC_(2, 3, 4); + false + > SELECT named_struct('a', 1, 'b', 2) _FUNC_(named_struct('a', 1, 'b', 1), named_struct('a', 1, 'b', 3)); + false + > SELECT named_struct('a', 1, 'b', 2) _FUNC_(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 3)); + true + """) +// scalastyle:on line.size.limit case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") @@ -491,7 +508,24 @@ object Equality { // TODO: although map type is not orderable, technically map type should be able to be used // in equality comparison @ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.") + usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.", + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be used in equality comparison. 
Map type is not supported. + For complex types such as array/struct, the data types of fields must be orderable. + """, + examples = """ + Examples: + > SELECT 2 _FUNC_ 2; + true + > SELECT 1 _FUNC_ '1'; + true + > SELECT true _FUNC_ NULL; + NULL + > SELECT NULL _FUNC_ NULL; + NULL + """) case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -510,6 +544,23 @@ case class EqualTo(left: Expression, right: Expression) usage = """ expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands, but returns true if both are null, false if one of the them is null. + """, + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be used in equality comparison. Map type is not supported. + For complex types such as array/struct, the data types of fields must be orderable. + """, + examples = """ + Examples: + > SELECT 2 _FUNC_ 2; + true + > SELECT 1 _FUNC_ '1'; + true + > SELECT true _FUNC_ NULL; + false + > SELECT NULL _FUNC_ NULL; + true """) case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -540,7 +591,27 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } @ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than `expr2`.") + usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than `expr2`.", + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be ordered. For example, map type is not orderable, so it + is not supported. For complex types such as array/struct, the data types of fields must + be orderable. 
+ """, + examples = """ + Examples: + > SELECT 1 _FUNC_ 2; + true + > SELECT 1.1 _FUNC_ '1'; + false + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52'); + false + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52'); + true + > SELECT 1 _FUNC_ NULL; + NULL + """) case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -550,7 +621,27 @@ case class LessThan(left: Expression, right: Expression) } @ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than or equal to `expr2`.") + usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than or equal to `expr2`.", + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be ordered. For example, map type is not orderable, so it + is not supported. For complex types such as array/struct, the data types of fields must + be orderable. + """, + examples = """ + Examples: + > SELECT 2 _FUNC_ 2; + true + > SELECT 1.0 _FUNC_ '1'; + true + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52'); + true + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52'); + true + > SELECT 1 _FUNC_ NULL; + NULL + """) case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -560,7 +651,27 @@ case class LessThanOrEqual(left: Expression, right: Expression) } @ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than `expr2`.") + usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than `expr2`.", + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be ordered. For example, map type is not orderable, so it + is not supported. 
For complex types such as array/struct, the data types of fields must + be orderable. + """, + examples = """ + Examples: + > SELECT 2 _FUNC_ 1; + true + > SELECT 2 _FUNC_ '1.1'; + true + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52'); + false + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52'); + false + > SELECT 1 _FUNC_ NULL; + NULL + """) case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -570,7 +681,27 @@ case class GreaterThan(left: Expression, right: Expression) } @ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than or equal to `expr2`.") + usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than or equal to `expr2`.", + arguments = """ + Arguments: + * expr1, expr2 - the two expressions must be the same type or can be cast to a common type, + and must be a type that can be ordered. For example, map type is not orderable, so it + is not supported. For complex types such as array/struct, the data types of fields must + be orderable. 
+ """, + examples = """ + Examples: + > SELECT 2 _FUNC_ 1; + true + > SELECT 2.0 _FUNC_ '2.1'; + false + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52'); + true + > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52'); + false + > SELECT 1 _FUNC_ NULL; + NULL + """) case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 055c31c2b30185a3de38666b6b377ee9403ac84b..1438a88c19e0b7bc9469cf75a0ca37db07aa9af5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -123,7 +123,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, false, null) :: (null, null, null) :: Nil) - test("IN") { + test("basic IN predicate test") { checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), @@ -151,19 +151,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) - val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, - LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) - primitiveTypes.foreach { t => - val dataGen = RandomDataGenerator.forType(t, nullable = true).get + } + + test("IN with different types") { + def testWithRandomDataGeneration(dataType: DataType, nullable: Boolean): Unit = { + val maybeDataGen = RandomDataGenerator.forType(dataType, 
nullable = nullable) + // Actually we won't pass in unsupported data types, this is a safety check. + val dataGen = maybeDataGen.getOrElse( + fail(s"Failed to create data generator for type $dataType")) val inputData = Seq.fill(10) { val value = dataGen.apply() - value match { + def cleanData(value: Any) = value match { case d: Double if d.isNaN => 0.0d case f: Float if f.isNaN => 0.0f case _ => value } + value match { + case s: Seq[_] => s.map(cleanData(_)) + case m: Map[_, _] => + val pair = m.unzip + val newKeys = pair._1.map(cleanData(_)) + val newValues = pair._2.map(cleanData(_)) + newKeys.zip(newValues).toMap + case _ => cleanData(value) + } } - val input = inputData.map(NonFoldableLiteral.create(_, t)) + val input = inputData.map(NonFoldableLiteral.create(_, dataType)) val expected = if (inputData(0) == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { @@ -175,6 +188,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } checkEvaluation(In(input(0), input.slice(1, 10)), expected) } + + val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => + RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType] + } ++ Seq(DecimalType.USER_DEFAULT) + + val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true)) + + // Basic types: + for ( + dataType <- atomicTypes; + nullable <- Seq(true, false)) { + testWithRandomDataGeneration(dataType, nullable) + } + + // Array types: + for ( + arrayType <- atomicArrayTypes; + nullable <- Seq(true, false) + if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined) { + testWithRandomDataGeneration(arrayType, nullable) + } + + // Struct types: + for ( + colOneType <- atomicTypes; + colTwoType <- atomicTypes; + nullable <- Seq(true, false)) { + val structType = StructType( + StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) + testWithRandomDataGeneration(structType, nullable) + } + + // Map types: not 
supported + for ( + keyType <- atomicTypes; + valueType <- atomicTypes; + nullable <- Seq(true, false)) { + val mapType = MapType(keyType, valueType) + val e = intercept[Exception] { + testWithRandomDataGeneration(mapType, nullable) + } + if (e.getMessage.contains("Code generation of")) { + // If the `value` expression is null, `eval` will be short-circuited. + // Codegen version evaluation will be run then. + assert(e.getMessage.contains("cannot generate equality code for un-comparable type")) + } else { + assert(e.getMessage.contains("Exception evaluating")) + } + } } test("INSET") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql new file mode 100644 index 0000000000000000000000000000000000000000..3b3d4ad64b3ece94384ed1a1cdcc3d3d7b4ba84c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -0,0 +1,36 @@ +-- EqualTo +select 1 = 1; +select 1 = '1'; +select 1.0 = '1'; + +-- GreaterThan +select 1 > '1'; +select 2 > '1.0'; +select 2 > '2.0'; +select 2 > '2.2'; +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'; + +-- GreaterThanOrEqual +select 1 >= '1'; +select 2 >= '1.0'; +select 2 >= '2.0'; +select 2.0 >= '2.2'; +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'; + +-- LessThan +select 1 < '1'; +select 2 < '1.0'; +select 2 < '2.0'; +select 2.0 < '2.2'; +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'; + +-- LessThanOrEqual +select 1 <= '1'; +select 2 <= '1.0'; +select 2 <= '2.0'; +select 2.0 <= '2.2'; +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; diff --git 
a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..8e7e04c8e1c4f343b126bfc72876c1bef8ec5afe --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -0,0 +1,218 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 27 + + +-- !query 0 +select 1 = 1 +-- !query 0 schema +struct<(1 = 1):boolean> +-- !query 0 output +true + + +-- !query 1 +select 1 = '1' +-- !query 1 schema +struct<(1 = CAST(1 AS INT)):boolean> +-- !query 1 output +true + + +-- !query 2 +select 1.0 = '1' +-- !query 2 schema +struct<(1.0 = CAST(1 AS DECIMAL(2,1))):boolean> +-- !query 2 output +true + + +-- !query 3 +select 1 > '1' +-- !query 3 schema +struct<(1 > CAST(1 AS INT)):boolean> +-- !query 3 output +false + + +-- !query 4 +select 2 > '1.0' +-- !query 4 schema +struct<(2 > CAST(1.0 AS INT)):boolean> +-- !query 4 output +true + + +-- !query 5 +select 2 > '2.0' +-- !query 5 schema +struct<(2 > CAST(2.0 AS INT)):boolean> +-- !query 5 output +false + + +-- !query 6 +select 2 > '2.2' +-- !query 6 schema +struct<(2 > CAST(2.2 AS INT)):boolean> +-- !query 6 output +false + + +-- !query 7 +select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52') +-- !query 7 schema +struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean> +-- !query 7 output +false + + +-- !query 8 +select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52' +-- !query 8 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean> +-- !query 8 output +false + + +-- !query 9 +select 1 >= '1' +-- !query 9 schema +struct<(1 >= CAST(1 AS INT)):boolean> +-- !query 9 output +true + + +-- !query 10 +select 2 >= '1.0' +-- !query 10 schema +struct<(2 >= CAST(1.0 AS INT)):boolean> +-- !query 10 output +true + + +-- !query 11 +select 2 >= '2.0' 
+-- !query 11 schema +struct<(2 >= CAST(2.0 AS INT)):boolean> +-- !query 11 output +true + + +-- !query 12 +select 2.0 >= '2.2' +-- !query 12 schema +struct<(2.0 >= CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 12 output +false + + +-- !query 13 +select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52') +-- !query 13 schema +struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean> +-- !query 13 output +true + + +-- !query 14 +select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52' +-- !query 14 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean> +-- !query 14 output +false + + +-- !query 15 +select 1 < '1' +-- !query 15 schema +struct<(1 < CAST(1 AS INT)):boolean> +-- !query 15 output +false + + +-- !query 16 +select 2 < '1.0' +-- !query 16 schema +struct<(2 < CAST(1.0 AS INT)):boolean> +-- !query 16 output +false + + +-- !query 17 +select 2 < '2.0' +-- !query 17 schema +struct<(2 < CAST(2.0 AS INT)):boolean> +-- !query 17 output +false + + +-- !query 18 +select 2.0 < '2.2' +-- !query 18 schema +struct<(2.0 < CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 18 output +true + + +-- !query 19 +select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52') +-- !query 19 schema +struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean> +-- !query 19 output +false + + +-- !query 20 +select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52' +-- !query 20 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean> +-- !query 20 output +true + + +-- !query 21 +select 1 <= '1' +-- !query 21 schema +struct<(1 <= CAST(1 AS INT)):boolean> +-- !query 21 output +true + + +-- !query 22 +select 2 <= '1.0' +-- !query 22 schema +struct<(2 <= CAST(1.0 AS INT)):boolean> +-- !query 22 output +false + + +-- !query 23 +select 2 <= '2.0' +-- !query 23 schema +struct<(2 <= CAST(2.0 AS INT)):boolean> +-- !query 23 output +true + + 
+-- !query 24 +select 2.0 <= '2.2' +-- !query 24 schema +struct<(2.0 <= CAST(2.2 AS DECIMAL(2,1))):boolean> +-- !query 24 output +true + + +-- !query 25 +select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52') +-- !query 25 schema +struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean> +-- !query 25 output +true + + +-- !query 26 +select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' +-- !query 26 schema +struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> +-- !query 26 output +true