From ce8edb8bf4db5f82bcfeb11efbdf5229b0d25dfa Mon Sep 17 00:00:00 2001
From: Ala Luszczak <ala@databricks.com>
Date: Fri, 19 May 2017 13:18:48 +0200
Subject: [PATCH] [SPARK-20798] GenerateUnsafeProjection should check if a
 value is null before calling the getter

## What changes were proposed in this pull request?

GenerateUnsafeProjection.writeStructToBuffer() did not honor the contract that the caller must check that a value is not null before using its getter: the generated code called the field getter first and only performed the null check afterwards. Calling a getter on a null slot can fail or return an unusable value, which could lead to various errors. This change makes the generated code evaluate the null check before calling the getter.

Example of code generated before:
```java
/* 059 */         final UTF8String fieldName = value.getUTF8String(0);
/* 060 */         if (value.isNullAt(0)) {
/* 061 */           rowWriter1.setNullAt(0);
/* 062 */         } else {
/* 063 */           rowWriter1.write(0, fieldName);
/* 064 */         }
```

Example of code generated now:
```java
/* 060 */         boolean isNull1 = value.isNullAt(0);
/* 061 */         UTF8String value1 = isNull1 ? null : value.getUTF8String(0);
/* 062 */         if (isNull1) {
/* 063 */           rowWriter1.setNullAt(0);
/* 064 */         } else {
/* 065 */           rowWriter1.write(0, value1);
/* 066 */         }
```
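
For reference, here is a minimal Scala sketch (not part of this patch; the helper object and method names are hypothetical) of the access pattern the new generated code follows: consult `isNullAt` first, and only invoke the getter for non-null slots.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical helper mirroring the pattern emitted by the fixed codegen:
// getters are only required to behave well for non-null slots, so the null
// check must happen before the getter call.
object SafeFieldAccess {
  def readStringField(row: InternalRow, ordinal: Int): UTF8String = {
    if (row.isNullAt(ordinal)) null        // check first, as the new generated code does
    else row.getUTF8String(ordinal)        // getter is reached only for non-null slots
  }
}
```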

## How was this patch tested?

Adds GenerateUnsafeProjectionSuite, which projects a struct field from a row whose getters always throw, so the projection only passes if the generated code checks for null before calling a getter.

Author: Ala Luszczak <ala@databricks.com>

Closes #18030 from ala/fix-generate-unsafe-projection.
---
 .../codegen/GenerateUnsafeProjection.scala    | 15 +++--
 .../GenerateUnsafeProjectionSuite.scala       | 61 +++++++++++++++++++
 .../execution/vectorized/ColumnarBatch.java   |  6 ++
 3 files changed, 78 insertions(+), 4 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7e4c9089a2..b358102d91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       fieldTypes: Seq[DataType],
       bufferHolder: String): String = {
     val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
-      val fieldName = ctx.freshName("fieldName")
-      val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
-      val isNull = s"$input.isNullAt($i)"
-      ExprCode(code, isNull, fieldName)
+      val javaType = ctx.javaType(dt)
+      val isNullVar = ctx.freshName("isNull")
+      val valueVar = ctx.freshName("value")
+      val defaultValue = ctx.defaultValue(dt)
+      val readValue = ctx.getValue(input, dt, i.toString)
+      val code =
+        s"""
+          boolean $isNullVar = $input.isNullAt($i);
+          $javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
+        """
+      ExprCode(code, isNullVar, valueVar)
     }
 
     s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
new file mode 100644
index 0000000000..e9d21f8a8e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+class GenerateUnsafeProjectionSuite extends SparkFunSuite {
+  test("Test unsafe projection string access pattern") {
+    val dataType = (new StructType).add("a", StringType)
+    val exprs = BoundReference(0, dataType, nullable = true) :: Nil
+    val projection = GenerateUnsafeProjection.generate(exprs)
+    val result = projection.apply(InternalRow(AlwaysNull))
+    assert(!result.isNullAt(0))
+    assert(result.getStruct(0, 1).isNullAt(0))
+  }
+}
+
+object AlwaysNull extends InternalRow {
+  override def numFields: Int = 1
+  override def setNullAt(i: Int): Unit = {}
+  override def copy(): InternalRow = this
+  override def anyNull: Boolean = true
+  override def isNullAt(ordinal: Int): Boolean = true
+  override def update(i: Int, value: Any): Unit = notSupported
+  override def getBoolean(ordinal: Int): Boolean = notSupported
+  override def getByte(ordinal: Int): Byte = notSupported
+  override def getShort(ordinal: Int): Short = notSupported
+  override def getInt(ordinal: Int): Int = notSupported
+  override def getLong(ordinal: Int): Long = notSupported
+  override def getFloat(ordinal: Int): Float = notSupported
+  override def getDouble(ordinal: Int): Double = notSupported
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
+  override def getUTF8String(ordinal: Int): UTF8String = notSupported
+  override def getBinary(ordinal: Int): Array[Byte] = notSupported
+  override def getInterval(ordinal: Int): CalendarInterval = notSupported
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
+  override def getArray(ordinal: Int): ArrayData = notSupported
+  override def getMap(ordinal: Int): MapData = notSupported
+  override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
+  private def notSupported: Nothing = throw new UnsupportedOperationException
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index a6ce4c2edc..8b7b0e655b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -198,21 +198,25 @@ public final class ColumnarBatch {
 
     @Override
     public Decimal getDecimal(int ordinal, int precision, int scale) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       return columns[ordinal].getDecimal(rowId, precision, scale);
     }
 
     @Override
     public UTF8String getUTF8String(int ordinal) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       return columns[ordinal].getUTF8String(rowId);
     }
 
     @Override
     public byte[] getBinary(int ordinal) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       return columns[ordinal].getBinary(rowId);
     }
 
     @Override
     public CalendarInterval getInterval(int ordinal) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
       final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
       return new CalendarInterval(months, microseconds);
@@ -220,11 +224,13 @@ public final class ColumnarBatch {
 
     @Override
     public InternalRow getStruct(int ordinal, int numFields) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       return columns[ordinal].getStruct(rowId);
     }
 
     @Override
     public ArrayData getArray(int ordinal) {
+      if (columns[ordinal].isNullAt(rowId)) return null;
       return columns[ordinal].getArray(rowId);
     }
 
-- 
GitLab