diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java new file mode 100644 index 0000000000000000000000000000000000000000..5f28d52a94bd7c7bafea170c9bc1906f71996b2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.UTF8String; + +public interface SpecializedGetters { + + boolean isNullAt(int ordinal); + + boolean getBoolean(int ordinal); + + byte getByte(int ordinal); + + short getShort(int ordinal); + + int getInt(int ordinal); + + long getLong(int ordinal); + + float getFloat(int ordinal); + + double getDouble(int ordinal); + + Decimal getDecimal(int ordinal); + + UTF8String getUTF8String(int ordinal); + + byte[] getBinary(int ordinal); + + Interval getInterval(int ordinal); + + InternalRow getStruct(int ordinal, int numFields); + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 9a11de3840ce2d002190f3271d946b5272630788..e395a67434fa71d028306f478515054c09f9ce78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.{Interval, UTF8String} * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Serializable { +abstract class InternalRow extends Serializable with SpecializedGetters { def numFields: Int @@ -38,29 +38,30 @@ abstract class InternalRow extends Serializable { def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] - def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) + override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) + override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) + override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) + override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) + override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) + override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) + override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) + override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) + override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int): Decimal = + getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) - def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + override def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString @@ -71,7 +72,8 @@ abstract class InternalRow extends Serializable { * @param ordinal position to get the struct from. * @param numFields number of fields the struct type has */ - def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + getAs[InternalRow](ordinal, null) override def toString: String = s"[${this.mkString(",")}]"