Skip to content
Snippets Groups Projects
Commit 9897cc5e authored by Reynold Xin's avatar Reynold Xin Committed by Yin Huai
Browse files

[SPARK-9736] [SQL] JoinedRow.anyNull should delegate to the underlying rows.

JoinedRow.anyNull currently loops through every field to check for null, which is inefficient if the underlying rows are UnsafeRows. It should just delegate to the underlying implementation.

Author: Reynold Xin <rxin@databricks.com>

Closes #8027 from rxin/SPARK-9736 and squashes the following commits:

03a2e92 [Reynold Xin] Include all files.
90f1add [Reynold Xin] [SPARK-9736][SQL] JoinedRow.anyNull should delegate to the underlying rows.
parent 2432c2e2
No related branches found
No related tags found
No related merge requests found
...@@ -37,15 +37,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable { ...@@ -37,15 +37,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
def copy(): InternalRow def copy(): InternalRow
/** Returns true if there are any NULL values in this row. */ /** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = { def anyNull: Boolean
val len = numFields
var i = 0
while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
/* ---------------------- utility methods for Scala ---------------------- */ /* ---------------------- utility methods for Scala ---------------------- */
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
 * Presents two [[InternalRow]]s as one logical row: the fields of the left row come first,
 * followed by the fields of the right row. The wrapper is mutable so that the wrapped rows can
 * be swapped in place; it is designed to be instantiated once per thread and reused.
 */
class JoinedRow extends InternalRow {
  // Left-hand and right-hand rows; both stay null until a constructor or updater sets them.
  private[this] var row1: InternalRow = _
  private[this] var row2: InternalRow = _

  /** Creates a joined view over `left` followed by `right`. */
  def this(left: InternalRow, right: InternalRow) = {
    this()
    row1 = left
    row2 = right
  }

  /** Re-points this wrapper at a new pair of base rows and returns this same instance. */
  def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
    row1 = r1
    row2 = r2
    this
  }

  /** Replaces only the left base row and returns this same instance. */
  def withLeft(newLeft: InternalRow): InternalRow = {
    row1 = newLeft
    this
  }

  /** Replaces only the right base row and returns this same instance. */
  def withRight(newRight: InternalRow): InternalRow = {
    row2 = newRight
    this
  }

  /**
   * Converts both halves to Scala sequences and concatenates them. `fieldTypes` must describe
   * every field of the left row followed by every field of the right row.
   */
  override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
    assert(fieldTypes.length == row1.numFields + row2.numFields)
    val leftWidth = row1.numFields
    val (leftTypes, rightTypes) = fieldTypes.splitAt(leftWidth)
    row1.toSeq(leftTypes) ++ row2.toSeq(rightTypes)
  }

  /** Total field count: left width plus right width. */
  override def numFields: Int = row1.numFields + row2.numFields

  // Every accessor below routes the ordinal the same way: ordinals smaller than the left row's
  // width are answered by the left row, everything else by the right row after shifting the
  // ordinal down by the left row's width.

  override def get(i: Int, dt: DataType): AnyRef = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.get(i, dt) else row2.get(i - leftWidth, dt)
  }

  override def isNullAt(i: Int): Boolean = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.isNullAt(i) else row2.isNullAt(i - leftWidth)
  }

  override def getBoolean(i: Int): Boolean = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getBoolean(i) else row2.getBoolean(i - leftWidth)
  }

  override def getByte(i: Int): Byte = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getByte(i) else row2.getByte(i - leftWidth)
  }

  override def getShort(i: Int): Short = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getShort(i) else row2.getShort(i - leftWidth)
  }

  override def getInt(i: Int): Int = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getInt(i) else row2.getInt(i - leftWidth)
  }

  override def getLong(i: Int): Long = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getLong(i) else row2.getLong(i - leftWidth)
  }

  override def getFloat(i: Int): Float = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getFloat(i) else row2.getFloat(i - leftWidth)
  }

  override def getDouble(i: Int): Double = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getDouble(i) else row2.getDouble(i - leftWidth)
  }

  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getDecimal(i, precision, scale)
    else row2.getDecimal(i - leftWidth, precision, scale)
  }

  override def getUTF8String(i: Int): UTF8String = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getUTF8String(i) else row2.getUTF8String(i - leftWidth)
  }

  override def getBinary(i: Int): Array[Byte] = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getBinary(i) else row2.getBinary(i - leftWidth)
  }

  override def getArray(i: Int): ArrayData = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getArray(i) else row2.getArray(i - leftWidth)
  }

  override def getInterval(i: Int): CalendarInterval = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getInterval(i) else row2.getInterval(i - leftWidth)
  }

  override def getMap(i: Int): MapData = {
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getMap(i) else row2.getMap(i - leftWidth)
  }

  override def getStruct(i: Int, numFields: Int): InternalRow = {
    // The `numFields` parameter shadows this row's own field count, so the left width is read
    // explicitly from row1.
    val leftWidth = row1.numFields
    if (i < leftWidth) row1.getStruct(i, numFields)
    else row2.getStruct(i - leftWidth, numFields)
  }

  /**
   * True when either base row reports a null. The check is delegated to the base rows, and the
   * right row is only consulted when the left row reports no null.
   */
  override def anyNull: Boolean = if (row1.anyNull) true else row2.anyNull

  /** Returns a new JoinedRow built from copies of both base rows. */
  override def copy(): InternalRow = new JoinedRow(row1.copy(), row2.copy())

  override def toString: String = {
    // Either side may still be unset, and toString must never throw NullPointerException.
    val leftMissing = row1 eq null
    val rightMissing = row2 eq null
    if (leftMissing && rightMissing) {
      "[ empty row ]"
    } else if (leftMissing) {
      row2.toString
    } else if (rightMissing) {
      row1.toString
    } else {
      s"{${row1.toString} + ${row2.toString}}"
    }
  }
}
...@@ -169,122 +169,3 @@ object FromUnsafeProjection { ...@@ -169,122 +169,3 @@ object FromUnsafeProjection {
GenerateSafeProjection.generate(exprs) GenerateSafeProjection.generate(exprs)
} }
} }
/**
* A mutable wrapper that makes two rows appear as a single concatenated row: the fields of
* the first row are followed by the fields of the second row. Designed to be instantiated
* once per thread and reused.
*/
class JoinedRow extends InternalRow {
// Left and right base rows; both are null until a constructor or updater assigns them.
private[this] var row1: InternalRow = _
private[this] var row2: InternalRow = _
/** Creates a JoinedRow that wraps `left` followed by `right`. */
def this(left: InternalRow, right: InternalRow) = {
this()
row1 = left
row2 = right
}
/** Updates this JoinedRow to point at two new base rows. Returns itself. */
def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
row1 = r1
row2 = r2
this
}
/** Updates this JoinedRow by updating its left base row. Returns itself. */
def withLeft(newLeft: InternalRow): InternalRow = {
row1 = newLeft
this
}
/** Updates this JoinedRow by updating its right base row. Returns itself. */
def withRight(newRight: InternalRow): InternalRow = {
row2 = newRight
this
}
/**
* Converts both base rows to sequences and concatenates them. `fieldTypes` must list the
* types of the left row's fields followed by the types of the right row's fields.
*/
override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
assert(fieldTypes.length == row1.numFields + row2.numFields)
// Split the type list at the left row's width so each base row gets its own types.
val (left, right) = fieldTypes.splitAt(row1.numFields)
row1.toSeq(left) ++ row2.toSeq(right)
}
/** Total number of fields: the left row's count plus the right row's count. */
override def numFields: Int = row1.numFields + row2.numFields
// All accessors below route the ordinal the same way: an ordinal smaller than
// row1.numFields is read from row1, otherwise it is read from row2 after subtracting
// row1.numFields.
override def get(i: Int, dt: DataType): AnyRef =
if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt)
override def isNullAt(i: Int): Boolean =
if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
override def getBoolean(i: Int): Boolean =
if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
override def getByte(i: Int): Byte =
if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
override def getShort(i: Int): Short =
if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
override def getInt(i: Int): Int =
if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
override def getFloat(i: Int): Float =
if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def getDouble(i: Int): Double =
if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
if (i < row1.numFields) {
row1.getDecimal(i, precision, scale)
} else {
row2.getDecimal(i - row1.numFields, precision, scale)
}
}
override def getUTF8String(i: Int): UTF8String =
if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
override def getBinary(i: Int): Array[Byte] =
if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
override def getArray(i: Int): ArrayData =
if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields)
override def getInterval(i: Int): CalendarInterval =
if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields)
override def getMap(i: Int): MapData =
if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields)
// The `numFields` parameter shadows this row's own numFields; the routing still compares
// against row1.numFields explicitly.
override def getStruct(i: Int, numFields: Int): InternalRow = {
if (i < row1.numFields) {
row1.getStruct(i, numFields)
} else {
row2.getStruct(i - row1.numFields, numFields)
}
}
/** Returns a new JoinedRow that wraps copies of both base rows. */
override def copy(): InternalRow = {
val copy1 = row1.copy()
val copy2 = row2.copy()
new JoinedRow(copy1, copy2)
}
/** Renders whichever base rows are set; a row that is still null is left out. */
override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
} else if (row1 eq null) {
row2.toString
} else if (row2 eq null) {
row1.toString
} else {
s"{${row1.toString} + ${row2.toString}}"
}
}
}
...@@ -49,7 +49,17 @@ trait BaseGenericInternalRow extends InternalRow { ...@@ -49,7 +49,17 @@ trait BaseGenericInternalRow extends InternalRow {
override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
override def toString(): String = { override def anyNull: Boolean = {
val len = numFields
var i = 0
while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
override def toString: String = {
if (numFields == 0) { if (numFields == 0) {
"[empty row]" "[empty row]"
} else { } else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment