diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 55da0e094d1320bd95b01db886dc6592b1ca271e..b6e2c30fbf1047ede3dbe9c3c3820a3efd10bb9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -174,8 +174,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { case row: InternalRow => - require(row.length == 7, - s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + require(row.numFields == 7, + s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7") val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9067b3ba9a7bb69f8b8577f01a7bcaf6deea33e9..c884aad08889feb5b1885b1f870ae438ddb527f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -203,8 +203,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { case row: InternalRow => - require(row.length == 4, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + require(row.numFields == 4, + s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4") val tpe = row.getByte(0) tpe match { case 0 => diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fa1216b455a9e2aac9a010813c35fbb1301b4174..a8986608855e296d7442d4f67d4e554be8408689 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -64,7 +64,8 @@ public final class UnsafeRow extends MutableRow { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - public int length() { return numFields; } + @Override + public int numFields() { return numFields; } /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -218,12 +219,12 @@ public final class UnsafeRow extends MutableRow { } @Override - public int size() { - return numFields; + public Object get(int i) { + throw new UnsupportedOperationException(); } @Override - public Object get(int i) { + public <T> T getAs(int i) { throw new UnsupportedOperationException(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index bfaee04f33b7ffa6d2a87baf4cb6181625704d1c..5c3072a77aeba3f3a7a9808983683b091ff278ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -140,14 +140,14 @@ object CatalystTypeConverters { private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = 
catalystValue - override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column)) } /** Converter for arrays, sequences, and Java iterables. */ @@ -184,7 +184,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row(column).asInstanceOf[Seq[Any]]) + toScala(row.get(column).asInstanceOf[Seq[Any]]) } private case class MapConverter( @@ -227,7 +227,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row(column).asInstanceOf[Map[Any, Any]]) + toScala(row.get(column).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( @@ -260,9 +260,9 @@ object CatalystTypeConverters { if (row == null) { null } else { - val ar = new Array[Any](row.size) + val ar = new Array[Any](row.numFields) var idx = 0 - while (idx < row.size) { + while (idx < row.numFields) { ar(idx) = converters(idx).toScala(row, idx) idx += 1 } @@ -271,7 +271,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row(column).asInstanceOf[InternalRow]) + toScala(row.get(column).asInstanceOf[InternalRow]) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index c7ec49b3d6c3dfe6125ef20d7fa90e3881fae791..efc4faea569b207f58e38b9c52b210b9f1e76c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -25,48 +25,139 @@ import org.apache.spark.unsafe.types.UTF8String * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Row { +abstract class InternalRow extends Serializable { - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + def numFields: Int - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + def get(i: Int): Any - // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString - - // These expensive API should not be used internally. 
- final override def getDecimal(i: Int): java.math.BigDecimal = - throw new UnsupportedOperationException - final override def getDate(i: Int): java.sql.Date = - throw new UnsupportedOperationException - final override def getTimestamp(i: Int): java.sql.Timestamp = - throw new UnsupportedOperationException - final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException - final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException - final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = - throw new UnsupportedOperationException - final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = - throw new UnsupportedOperationException - final override def getStruct(i: Int): Row = throw new UnsupportedOperationException - final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException - final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = - throw new UnsupportedOperationException - - // A default implementation to change the return type - override def copy(): InternalRow = this + // TODO: Remove this. + def apply(i: Int): Any = get(i) + + def getAs[T](i: Int): T = get(i).asInstanceOf[T] + + def isNullAt(i: Int): Boolean = get(i) == null + + def getBoolean(i: Int): Boolean = getAs[Boolean](i) + + def getByte(i: Int): Byte = getAs[Byte](i) + + def getShort(i: Int): Short = getAs[Short](i) + + def getInt(i: Int): Int = getAs[Int](i) + + def getLong(i: Int): Long = getAs[Long](i) + + def getFloat(i: Int): Float = getAs[Float](i) + + def getDouble(i: Int): Double = getAs[Double](i) + + override def toString: String = s"[${this.mkString(",")}]" + + /** + * Make a copy of the current [[InternalRow]] object. + */ + def copy(): InternalRow = this + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + /* ---------------------- utility methods for Scala ---------------------- */ /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on internal row with external row. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + + // This is only use for test + def getString(i: Int): String = getAs[UTF8String](i).toString // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 var i = 0 - while (i < length) { + val len = numFields + while (i < len) { val update: Int = if (isNullAt(i)) { 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d52c50bd27b2959b7592cc87b0e5e19d8a..47ad3e089e4c759a5262fd9b88431572e90c5122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -382,8 +382,8 @@ case class Cast(child: Expression, dataType: DataType) val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 - while (i < row.length) { - newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) + while (i < row.numFields) { + newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i))) i += 1 } newRow.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 04872fbc8b0914056f1b97fe9c81b8818d254108..dbda05a792cbf54b7da4d021f1be03fe620c1f83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -176,49 +176,49 @@ class JoinedRow extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - 
row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -278,49 +278,49 @@ class JoinedRow2 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: 
Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -374,50 +374,50 @@ class JoinedRow3 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 
1 } new GenericInternalRow(copiedValues) @@ -471,50 +471,50 @@ class JoinedRow4 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -568,50 +568,50 @@ class JoinedRow5 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < 
row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -665,50 +665,50 @@ class JoinedRow6 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - 
row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 6f291d2c86c1e0868211345724f16af94af5b0fb..4b4833bd06a3bf3510d83755d4b9a588c101766a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -211,7 +211,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this() = this(Seq.empty) - override def length: Int = values.length + override def numFields: Int = values.length override def toSeq: Seq[Any] = values.map(_.boxed).toSeq @@ -245,7 +245,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = apply(ordinal).toString + override def getString(ordinal: Int): String = get(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 73fde4e9164d7a6ea05d51038c578883576b41b5..62b6cc834c9c94947d196830c6b8286ce6477301 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -675,7 +675,7 @@ case class CombineSetsAndSumFunction( val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { - seen.add(inputIterator.next) + seen.add(inputIterator.next()) } } @@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction( null } else { Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( + casted.iterator.map(f => f.get(0)).reduceLeft( base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), base.dataType).eval(null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 405d6b0e3bc76b18267e66c77e720626eadc9831..f0efc4bff12bac63351fb294a9aaa7e398f50cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -178,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int length() { return ${expressions.length};} + public int numFields() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5504781edca1b0ca9a7f0a0c3b0da0b3518e8faf..c91122cda2a4126a48b0a3b1a8777392d1e74a0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -110,7 +110,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow](ordinal) + input.asInstanceOf[InternalRow].get(ordinal) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -142,7 +142,7 @@ case class GetArrayStructFields( protected override def nullSafeEval(input: Any): Any = { input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row(ordinal) + if (row == null) null else row.get(ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index d78be5a5958f98e46764921adf6fa321097d1624..53779dd4049d1a7cbd0b8eeb041b13c9a0a78f4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,9 +44,10 @@ abstract class MutableRow extends InternalRow { } override def copy(): InternalRow = { - val arr = new Array[Any](length) + val n = numFields + val arr = new Array[Any](n) var i = 0 - while (i < length) { + while (i < n) { arr(i) = get(i) i += 1 } @@ -54,36 +55,23 @@ abstract class MutableRow extends InternalRow { } } -/** - * A row implementation that uses an array of objects as the underlying storage. - */ -trait ArrayBackedRow { - self: Row => - - protected val values: Array[Any] - - override def toSeq: Seq[Any] = values.toSeq - - def length: Int = values.length - - override def get(i: Int): Any = values(i) - - def setNullAt(i: Int): Unit = { values(i) = null} - - def update(i: Int, value: Any): Unit = { values(i) = value } -} - /** * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. 
*/ -class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def length: Int = values.length + + override def get(i: Int): Any = values(i) + + override def toSeq: Seq[Any] = values.toSeq + override def copy(): Row = this } @@ -101,34 +89,49 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(protected[sql] val values: Array[Any]) - extends InternalRow with ArrayBackedRow { +class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + override def copy(): InternalRow = this } /** * This is used for serialization of Python DataFrame */ -class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) +class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) extends GenericInternalRow(values) { /** No-arg constructor for serialization. */ protected def this() = this(null, null) - override def fieldIndex(name: String): Int = schema.fieldIndex(name) + def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { +class GenericMutableRow(val values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + + override def setNullAt(i: Int): Unit = { values(i) = null} + + override def update(i: Int, value: Any): Unit = { values(i) = value } + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 878a1bb9b7e6d82c0a96a0709fc7d6c962353022..01ff84cb56054ccbbc34ebb1a6a2fb7a4357af69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -83,15 +83,5 @@ class RowTest extends FunSpec with Matchers { it("equality check for internal rows") { internalRow shouldEqual internalRow2 } - - it("throws an exception when check equality between external and internal rows") { - def assertError(f: => Unit): Unit = { - val e = intercept[UnsupportedOperationException](f) - e.getMessage.contains("cannot check equality between external and internal rows") - } - - assertError(internalRow.equals(externalRow)) - assertError(externalRow.equals(internalRow)) - } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index facf65c15514829c1ec772315b820b6b5ff104da..408353cf70a49b29406e5474aea146b57a51b8d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for data type casting expression [[Cast]]. 
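The surrounding hunks migrate callers onto the InternalRow surface defined earlier in this patch: numFields replaces the removed length/size, get(i) replaces the Row-style row(i), and string fields are carried as UTF8String (which is why the struct literals in the next hunk now wrap their values in UTF8String.fromString). A minimal usage sketch, assuming only those definitions:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.unsafe.types.UTF8String

    val row = InternalRow(1, UTF8String.fromString("abc"), null)
    var i = 0
    while (i < row.numFields) {                               // was: row.length / row.size
      val value = if (row.isNullAt(i)) null else row.get(i)   // was: row(i)
      println(s"field $i = $value")
      i += 1
    }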
@@ -580,14 +581,21 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from struct") { val struct = Literal.create( - InternalRow("123", "abc", "", null), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString(""), + null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal.create( - InternalRow("123", "abc", ""), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString("")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -676,8 +684,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( InternalRow( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), + Map( + UTF8String.fromString("a") -> UTF8String.fromString("123"), + UTF8String.fromString("b") -> UTF8String.fromString("abc"), + UTF8String.fromString("c") -> UTF8String.fromString("")), InternalRow(0)), StructType(Seq( StructField("a", @@ -700,7 +711,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, InternalRow( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map( + UTF8String.fromString("a") -> true, + UTF8String.fromString("b") -> true, + UTF8String.fromString("c") -> false), InternalRow(0L))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index a8aee8f634e0321808a1b77f8108d96a7f3e6270..fc842772f3480d5fe16cea20647e79680826a61f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -150,12 +151,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateNamedStruct with literal field") { val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), + InternalRow(1, UTF8String.fromString("y")), row) } test("CreateNamedStruct from all literal fields") { checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + CreateNamedStruct(Seq("a", "x", "b", 2.0)), + InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) } test("test dsl for complex type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9d8415f06399c8309b20b65f1fb763d73f9b3d06..ac42bde07c37d257c35deeb6a2ac033dbc327dd0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -309,7 +309,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getString(ordinal).getBytes("utf-8").length + 4 + row.getUTF8String(ordinal).numBytes() + 4 } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 38720968c1313b506c2671b7756111cdf9ad7310..5d5b0697d70169ccfa83473c750c154ac98ac0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -134,13 +134,13 @@ private[sql] case class InMemoryRelation( // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat // hard to decipher. assert( - row.size == columnBuilders.size, - s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. - |Row content: $row - """.stripMargin) + row.numFields == columnBuilders.size, + s"Row column number mismatch, expected ${output.size} columns, " + + s"but got ${row.numFields}." + + s"\nRow content: $row") var i = 0 - while (i < row.length) { + while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) i += 1 } @@ -304,7 +304,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[InternalRow] { - private[this] val rowLen = nextRow.length + private[this] val rowLen = nextRow.numFields override def next(): InternalRow = { var i = 0 while (i < rowLen) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c87e2064a8f334e3da043a72f4575d7c48433728..83c4e8733f15f26b0887c25a20e07204a532b71f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -25,7 +25,6 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ @@ -53,7 +52,7 @@ private[sql] class Serializer2SerializationStream( private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[Row, Row]] + val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] writeKey(kv._1) writeValue(kv._2) @@ -66,7 +65,7 @@ private[sql] class Serializer2SerializationStream( } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[InternalRow]) this } @@ -205,8 +204,9 @@ private[sql] object SparkSqlSerializer2 { /** * The util function to create the serialization function based on the given schema. 
*/ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { - (row: Row) => + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) + : InternalRow => Unit = { + (row: InternalRow) => // If the schema is null, the returned function does nothing when it get called. if (schema != null) { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2b400926177fee6ff85201db494fc0414797ad72..7f452daef33c56b79b5e9645e160efef14313887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -206,7 +206,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => var i = 0 - while (i < mutableRow.length) { + while (i < mutableRow.numFields) { mergers(i)(mutableRow, dataRow, i) i += 1 } @@ -315,7 +315,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index cd2aa7f7433c5738bbc474ec6460a56b7cc80f90..d551f386eee6eb47766f2a35cc906b629a486775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -174,14 +174,19 @@ private[sql] case class InsertIntoHadoopFsRelation( try { writerContainer.executorSideSetup(taskContext) - val converter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow) + .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) + } } writerContainer.commitTask() @@ -248,17 +253,23 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionCasts, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - val dataConverter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = converter(dataProj(internalRow)) + 
writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataConverter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = dataProj(internalRow) + writerContainer.outputWriterForRow(partitionPart) + .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) + } } writerContainer.commitTask() @@ -530,8 +541,12 @@ private[sql] class DynamicPartitionWriterContainer( while (i < partitionColumns.length) { val col = partitionColumns(i) val partitionValueString = { - val string = row.getString(i) - if (string.eq(null)) defaultPartitionName else PartitioningUtils.escapePathName(string) + val string = row.getUTF8String(i) + if (string.eq(null)) { + defaultPartitionName + } else { + PartitioningUtils.escapePathName(string.toString) + } } if (i > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c8033d3c0470adac758c02c16e73127411ab58c0..1f2797ec5527a6568ade2ec9d034210616bc2a24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,11 +23,11 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -415,12 +415,12 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[InternalRow] = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } @@ -432,20 +432,20 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] 
= { + override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) @@ -464,7 +464,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[InternalRow] + Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index e6e27a87c715159c4d5cfca735a75d78aa6dbeeb..40bf03a3f1a62aa7e4cbfd9bf41dc3689d4a8b00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -126,9 +126,9 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.size) + val values = new Array[Any](row.numFields) var i = 0 - while (i < row.size) { + while (i < row.numFields) { values(i) = toJava(row(i), struct.fields(i).dataType) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala index 6c49a906c848a1e79713d305363c9ea771385b88..46f0fac861282100e339dcea0bc4b7ccf9a77a90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -148,7 +148,7 @@ class InputAggregationBuffer private[sql] ( toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingInputBuffer: Row) + var underlyingInputBuffer: InternalRow) extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { override def get(i: Int): Any = { @@ -156,6 +156,7 @@ class InputAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + // TODO: Use buffer schema to avoid using generic getter. 
toScalaConverters(i)(underlyingInputBuffer(offsets(i))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4d3aac464c53849361ebafd2ba380d3891d20199..41d0ecb4bbfbf9bf9f15033b8e6f581b294ec6d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -128,6 +128,7 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverRegistry.getDriverClassName(url) + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, @@ -137,7 +138,7 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts).map(_.asInstanceOf[Row]) + parts).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 922794ac9aac5f0238068bed9054e7ec9d006b75..562b058414d079ef39fe6d271feb99881b74c002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -154,17 +154,19 @@ private[sql] class JSONRelation( } override def buildScan(): RDD[Row] = { + // Rely on type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 0c3d8fdab6bd25256dca1cffee887f2e08c30935..b5e4263008f5661fb9843d40b761b36ecc4098ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -28,7 +28,7 @@ import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveCo import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -55,8 +55,8 @@ private[parquet] trait ParentContainerUpdater { private[parquet] object NoopUpdater extends ParentContainerUpdater /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since - * any Parquet record is also a struct, this converter can also be used as root converter. + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. 
+ * Since any Parquet record is also a struct, this converter can also be used as root converter. * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. @@ -108,7 +108,7 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = { var i = 0 - while (i < currentRow.length) { + while (i < currentRow.numFields) { currentRow.setNullAt(i) i += 1 } @@ -178,7 +178,7 @@ private[parquet] class CatalystRowConverter( case t: StructType => new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { - override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) case t: UserDefinedType[_] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cba5e54d69e4ab8e1ca5c6d4fb8a3f3bcf8754..8cab27d6e1c46d524dbea70c2dc02f06cba4aa7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -178,7 +178,7 @@ private[sql] case class ParquetTableScan( val row = iter.next()._2.asInstanceOf[InternalRow] var i = 0 - while (i < row.size) { + while (i < row.numFields) { mutableRow(i) = row(i) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index d1040bf5562a211613fba81795bb26fc2ba19192..c7c58e69d42ef1ce136908aa19af9554d8ccfcd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -208,9 +208,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 @@ -378,9 +378,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] class MutableRowWriteSupport extends RowWriteSupport { override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c384697c0ee62f24d7622611290bda02bd19ed7f..8ec228c2b25bcd8d2577c55ec6c4532fc5f50d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -61,7 +61,7 @@ private[sql] class DefaultSource extends 
HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { + extends OutputWriterInternal { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -86,7 +86,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } @@ -324,7 +324,7 @@ private[sql] class ParquetRelation2( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values.map(_.asInstanceOf[Row]) + }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7cd005b9594883ef15093db8a52f8dcf271420b8..119bac786d478ff3ce5e9be49ca15f7112e975b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -344,6 +344,18 @@ abstract class OutputWriter { def close(): Unit } +/** + * This is an internal, private version of [[OutputWriter]] with a writeInternal method that + * accepts an [[InternalRow]] rather than a [[Row]]. Data sources that return this must have + * the conversion flag set to false. + */ +private[sql] abstract class OutputWriterInternal extends OutputWriter { + + override def write(row: Row): Unit = throw new UnsupportedOperationException + + def writeInternal(row: InternalRow): Unit +} + /** * ::Experimental:: * A [[BaseRelation]] that provides much of the common code required for formats that store their @@ -592,12 +604,12 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - val rdd = buildScan(inputFiles) - val converted = + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = if (needConversion) { RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } converted.mapPartitions { rows => val buildProjection = if (codegenEnabled) { @@ -606,8 +618,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r).asInstanceOf[Row]) - } + rows.map(r => mutableProjection(r)) + }.asInstanceOf[RDD[Row]] } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 7cc6ffd7548d0d7bb2819cbe2da5fe80db07a71f..0e5c5abff85f6d65bcf7b599b85087cc7d11df95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -35,14 +35,14 @@ class RowSuite extends SparkFunSuite { expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) +
assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected(3) === actual1(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index da53ec16b5c417ee7b8d9a4a985e2923c7982a11..84855ce45e91804af6877bf9ddc82b4d7d23ffee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -61,9 +61,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row - } + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 257526feab9452501819c2e48042cbe7dbfcb99b..0d5183444af78599caf8c9324b5e79f4536e5380 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -131,7 +131,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 143aadc08b1c48a329cc8ca3c0c69def3cb14d69..5e189c3563ca88bd1465edd21a9e6087e8326e3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -93,7 +93,7 @@ case class AllDataTypesScan( InternalRow(i, UTF8String.fromString(i.toString)), InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - } + }.asInstanceOf[RDD[Row]] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8202e553afbfe03a8c9a0e6329a24be500380e68..34b629403e128a0791b243e68e84f8307ec004f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -122,7 +122,7 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. 
*/ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc @@ -252,13 +252,12 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[InternalRow] + Seq.empty[Row] } - override def executeCollect(): Array[Row] = - sideEffectResult.toArray + override def executeCollect(): Array[Row] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ecc78a5f8d32150d423d0086112a12b323bb2467..8850e060d2a737993f9086367a67c878c97f9f59 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ @@ -94,7 +95,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. 
@@ -197,7 +200,8 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index de63ee56dd8e6adc0696c9aabca4c109da7df4bc..10623dc820316540dc1c499c32a343ff9382002a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -119,9 +119,9 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = { + override def writeInternal(row: InternalRow): Unit = { var i = 0 - while (i < row.length) { + while (i < row.numFields) { reusableOutputBuffer(i) = wrappers(i)(row(i)) i += 1 } @@ -192,7 +192,7 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] } override def prepareJobForWrite(job: Job): OutputWriterFactory = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e976125b3706dca57fdea737bfc2d27b7a6813e3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override val sqlContext = TestHive + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purposes. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents a race condition that can happen when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..d280543a071d9f55cff060047bde4df590d4fec8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[AnalysisException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. 
+ range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..d761909d60e21b38715a669e284e983372d4a33e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/* +This is commented out due to a bug in the data source API (SPARK-9291).
+ + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} +*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a8748d913569ae105f5e42909114d8a33ed9e4d..dd274023a1cf576b6b407a7713486fbc52539e9f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql.sources -import java.io.File - import scala.collection.JavaConversions._ -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.hadoop.ParquetOutputCommitter -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -581,165 +577,3 @@ class AlwaysFailParquetOutputCommitter( sys.error("Intentional job commitment failure for testing purpose.") } } - -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} - -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive - - // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
- val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName - - test("SPARK-7684: commitTask() failure should fallback to abortTask()") { - withTempPath { file => - // Here we coalesce partition number to 1 to ensure that only a single task is issued. This - // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` - // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = sqlContext.range(0, 10).coalesce(1) - intercept[SparkException] { - df.write.format(dataSourceName).save(file.getCanonicalPath) - } - - val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) - assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) - } - } -} - -class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName - - import sqlContext._ - import sqlContext.implicits._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) - .toDF("a", "b", "p1") - .write.parquet(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } - - test("SPARK-7868: _temporary directories should be ignored") { - withTempPath { dir => - val df = Seq("a", "b", "c").zipWithIndex.toDF() - - df.write - .format("parquet") - .save(dir.getCanonicalPath) - - df.write - .format("parquet") - .save(s"${dir.getCanonicalPath}/_temporary") - - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) - } - } - - test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { - withTempDir { dir => - val path = dir.getCanonicalPath - val df = Seq(1 -> "a").toDF() - - // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw - // since it's not a valid Parquet file. - val emptyFile = new File(path, "empty") - Files.createParentDirs(emptyFile) - Files.touch(emptyFile) - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Ignore).save(path) - - // This should only complain that the destination directory already exists, rather than file - // "empty" is not a Parquet file. - assert { - intercept[AnalysisException] { - df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - }.getMessage.contains("already exists") - } - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) - } - } - - test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { - withTempPath { dir => - intercept[AnalysisException] { - // Parquet doesn't allow field names with spaces. Here we are intentionally making an - // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger - // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) - .withColumnRenamed("id", "a b") - .write - .format("parquet") - .save(dir.getCanonicalPath) - } - } - } - - test("SPARK-8604: Parquet data source should write summary file while doing appending") { - withTempPath { dir => - val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5) - df.write.mode(SaveMode.Overwrite).parquet(path) - - val summaryPath = new Path(path, "_metadata") - val commonSummaryPath = new Path(path, "_common_metadata") - - val fs = summaryPath.getFileSystem(configuration) - fs.delete(summaryPath, true) - fs.delete(commonSummaryPath, true) - - df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) - - assert(fs.exists(summaryPath)) - assert(fs.exists(commonSummaryPath)) - } - } -}
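The loops updated above in EvaluatePython.toJava, CatalystRowConverter.start, and ParquetTableScan all take the same shape: bound the iteration by numFields and read columns through the generic getter. Purely as an illustration (the object and method names below are made up and are not part of this patch), that pattern in isolation looks like this:

import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical helper, for illustration only; not part of this patch.
object InternalRowExample {
  // Copies every column of an InternalRow into an Array[Any] through the generic getter,
  // bounding the loop with numFields exactly as the updated call sites do.
  def valuesOf(row: InternalRow): Array[Any] = {
    val values = new Array[Any](row.numFields)
    var i = 0
    while (i < row.numFields) {
      values(i) = row.get(i)
      i += 1
    }
    values
  }
}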
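The "type erasure hack" comments added to JDBCRelation, JSONRelation, ParquetRelation2, OrcRelation, and the test relations all describe one pattern: a relation that reports needConversion = false may produce InternalRow values and simply cast the resulting RDD, because an RDD's element type parameter is erased at runtime. The sketch below is illustrative only, not a definitive implementation; the class name, schema, and data are made up, and the shape mirrors SimpleDDLScan above.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical relation, for illustration only; not part of this patch.
case class ExampleInternalRowScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(StructField("name", StringType) :: StructField("id", IntegerType) :: Nil)

  // Signals that rows produced by buildScan() are already in the internal representation,
  // so the planner does not wrap this scan in a Scala-to-Catalyst conversion.
  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]: the element
    // type does not exist at runtime, and downstream operators treat the values as
    // InternalRow because needConversion is false.
    sqlContext.sparkContext.parallelize(from to to).map { i =>
      InternalRow(UTF8String.fromString(s"name_$i"), i)
    }.asInstanceOf[RDD[Row]]
  }
}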
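Likewise, a writer that opts into the OutputWriterInternal class added to interfaces.scala implements only writeInternal; the inherited write(row: Row) already throws UnsupportedOperationException. The hypothetical writer below is a sketch for illustration only; a real implementation would forward to an underlying RecordWriter, as ParquetOutputWriter and OrcOutputWriter do.

package org.apache.spark.sql.sources

import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical writer, for illustration only. It lives under org.apache.spark.sql.sources
// because OutputWriterInternal is private[sql].
private[sql] class ExampleCountingOutputWriter extends OutputWriterInternal {

  private var rowsWritten = 0L

  // write(row: Row) is already overridden by OutputWriterInternal to throw
  // UnsupportedOperationException, so only the InternalRow path is implemented here.
  override def writeInternal(row: InternalRow): Unit = {
    // A real writer would serialize the row's numFields columns to the output format;
    // here we only count rows to keep the sketch small.
    rowsWritten += 1
  }

  override def close(): Unit = ()
}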