diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c3f0abac244cf7138e5befa9bf4982072f5a4322..d205547698c5ba4deef90ea8473126a542d7686c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -578,12 +578,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); - } else if (!(other instanceof InternalRow)) { - return false; - } else { - throw new IllegalArgumentException( - "Cannot compare UnsafeRow to " + other.getClass().getName()); } + return false; } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9c3c6d3b2a7f21a9e5187ef485f9870675c47119..09007b7c89fe3210d987b9e3c27b20288e80e03c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -481,8 +481,13 @@ class CodegenContext { case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" + case array: ArrayType => genComp(array, c1, c2) + " == 0" + case struct: StructType => genComp(struct, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) - case other => s"$c1.equals($c2)" + case _ => + throw new IllegalArgumentException( + "cannot generate equality code for un-comparable type: " + dataType.simpleString) } /** @@ -512,6 +517,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { + // when comparing unsafe arrays, try equals first as it compares the binary directly + // which is very fast. + if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) { + return 0; + } int lengthA = a.numElements(); int lengthB = b.numElements(); int $minLength = (lengthA > lengthB) ? lengthB : lengthA; @@ -551,6 +561,11 @@ class CodegenContext { val funcCode: String = s""" public int $compareFunc(InternalRow a, InternalRow b) { + // when comparing unsafe rows, try equals first as it compares the binary directly + // which is very fast. + if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) { + return 0; + } InternalRow i = null; $comparisons return 0; @@ -561,7 +576,8 @@ class CodegenContext { case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => - throw new IllegalArgumentException("cannot generate compare code for un-comparable type") + throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type: " + dataType.simpleString) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2ad452b6a90cac2d920ccd6664ed9576f1863b35..3fcbb05372d87107dd8c335319583c3f601b90b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } + + protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) } @@ -429,17 +431,7 @@ case class EqualTo(left: Expression, right: Expression) override def symbol: String = "=" - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } - } + protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) @@ -482,15 +474,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - if (left.dataType == FloatType) { - Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 - } else if (left.dataType == DoubleType) { - Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 - } else if (left.dataType != BinaryType) { - input1 == input2 - } else { - java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) - } + ordering.equiv(input1, input2) } } @@ -513,8 +497,6 @@ case class LessThan(left: Expression, right: Expression) override def symbol: String = "<" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -527,8 +509,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -541,8 +521,6 @@ case class GreaterThan(left: Expression, right: Expression) override def symbol: String = ">" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -555,7 +533,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 2a445b8cdb0917302302638079347b9ba1e40cd5..f9f6799e6e72ff55d198bf7f8502165d9eec4078 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(nullInt, normalInt), false) checkEvaluation(EqualNullSafe(nullInt, nullInt), true) } + + test("EqualTo on complex type") { + val array = new GenericArrayData(Array(1, 2, 3)) + val struct = create_row("a", 1L, array) + + val arrayType = ArrayType(IntegerType) + val structType = new StructType() + .add("1", StringType) + .add("2", LongType) + .add("3", ArrayType(IntegerType)) + + val projection = UnsafeProjection.create( + new StructType().add("array", arrayType).add("struct", structType)) + + val unsafeRow = projection(InternalRow(array, struct)) + + val unsafeArray = unsafeRow.getArray(0) + val unsafeStruct = unsafeRow.getStruct(1, 3) + + checkEvaluation(EqualTo( + Literal.create(array, arrayType), + Literal.create(unsafeArray, arrayType)), true) + + checkEvaluation(EqualTo( + Literal.create(struct, structType), + Literal.create(unsafeStruct, structType)), true) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6b517bc70f7d21e3cf624b2aea4b7f5a30c25c3c..806381008aba61176ce7de26e2f143522bada9de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2476,4 +2476,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } + } }