diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4b99030d1046f151cdfdfd0fcd2314eb44d1e195..87294a0e21441f1f35ab2708086e4cb5d24daf9d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -246,11 +246,6 @@ public final class UnsafeRow extends MutableRow { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public int size() { - return numFields; - } - /** * Returns the object for column `i`, which should not be primitive type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0f2fd6a86d1770f49eaf077f2bfddd0f9f1b2fd9..5f0592dc1d77ba4365401559157ddc46d5d57773 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -151,7 +152,7 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def apply(i: Int): Any + def apply(i: Int): Any = get(i) /** * Returns the value at position i. If the value is null, null is returned. The following @@ -176,10 +177,10 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def get(i: Int): Any = apply(i) + def get(i: Int): Any /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean = apply(i) == null + def isNullAt(i: Int): Boolean = get(i) == null /** * Returns the value at position i as a primitive boolean. @@ -311,7 +312,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + def getAs[T](i: Int): T = get(i).asInstanceOf[T] /** * Returns the value of a given fieldName. @@ -363,6 +364,41 @@ trait Row extends Serializable { false } + protected def canEqual(other: Any) = + other.isInstanceOf[Row] && !other.isInstanceOf[InternalRow] + + override def equals(o: Any): Boolean = { + if (o == null || !canEqual(o)) return false + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + i += 1 + } + return true + } + /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 57de0f26a9720881973e5e0084cef4b4b018ec6c..e2fafb88ee43eb93b52851d8f0a3f12cebd7945a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -53,41 +53,8 @@ abstract class InternalRow extends Row { // A default implementation to change the return type override def copy(): InternalRow = this - override def apply(i: Int): Any = get(i) - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[Row]) { - return false - } - - val other = o.asInstanceOf[Row] - if (length != other.length) { - return false - } - - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = apply(i) - val o2 = other.apply(i) - if (o1.isInstanceOf[Array[Byte]]) { - // handle equality of Array[Byte] - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - } else if (o1 != o2) { - return false - } - } - i += 1 - } - true - } + protected override def canEqual(other: Any) = other.isInstanceOf[InternalRow] // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { @@ -98,7 +65,7 @@ abstract class InternalRow extends Row { if (isNullAt(i)) { 0 } else { - apply(i) match { + get(i) match { case b: Boolean => if (b) 0 else 1 case b: Byte => b.toInt case s: Short => s.toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 83d5b3b76b0a3a14ef42d82d053a072cc80f2ebb..65ae87fe6d166d9a20c07506c2b78de8adb7cb1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -56,7 +56,6 @@ object Cast { case (_, DateType) => true case (StringType, IntervalType) => true - case (IntervalType, StringType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 886a486bf5ee03db75f992840975f2cef2e2d6b7..bf47a6c75b80974d6730c13d3e7f401d6e7f99f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -110,7 +110,7 @@ class JoinedRow extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -204,7 +204,7 @@ class JoinedRow2 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -292,7 +292,7 @@ class JoinedRow3 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -380,7 +380,7 @@ class JoinedRow4 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -468,7 +468,7 @@ class JoinedRow5 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -556,7 +556,7 @@ class JoinedRow6 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index efa24710a5a67fe263c61914d11ba6c031cd74e0..6f291d2c86c1e0868211345724f16af94af5b0fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def apply(i: Int): Any = values(i).boxed + override def get(i: Int): Any = values(i).boxed override def isNullAt(i: Int): Boolean = values(i).isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 094904bbf9c15048a6e2decfba11836a1728a715..d78be5a5958f98e46764921adf6fa321097d1624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -66,7 +66,7 @@ trait ArrayBackedRow { def length: Int = values.length - override def apply(i: Int): Any = values(i) + override def get(i: Int): Any = values(i) def setNullAt(i: Int): Unit = { values(i) = null} @@ -84,27 +84,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBa def this(size: Int) = this(new Array[Any](size)) - // This is used by test or outside - override def equals(o: Any): Boolean = o match { - case other: Row if other.length == length => - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - val equal = (apply(i), other.apply(i)) match { - case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) - case (a, b) => a == b - } - if (!equal) { - return false - } - i += 1 - } - true - case _ => false - } - override def copy(): Row = this }