diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 955fb4226fc0e45cade91de0c8aff8357b38589e..64a8edc34d681e402491105fd4c82f678c868072 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -239,7 +239,7 @@ public final class UnsafeRow extends MutableRow { @Override public Object get(int ordinal, DataType dataType) { - if (dataType instanceof NullType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { return null; } else if (dataType instanceof BooleanType) { return getBoolean(ordinal); @@ -313,21 +313,13 @@ public final class UnsafeRow extends MutableRow { @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } @Override diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 2834b54e8fb2e71a0c80fb106ea7d2173297ba76..b7bc17f89e82ffa42e1567d98bda4804b38f5b51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -146,8 +146,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getShort(3) === 0) assert(createdFromNull.getInt(4) === 0) assert(createdFromNull.getLong(5) === 0) - assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getFloat(6) === 0.0f) + assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) // assert(createdFromNull.get(10) === null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index ad3bb1744cb3c5066da33e5935cd589f5d0075b1..e72a1bc6c4e20c9cb1438da1965c1363d448348d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -67,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite { assert(bytesFromArrayBackedRow === bytesFromOffheapRow) } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } }