From f1fdb23821b89623b592bfb3ef73d61afbe93b0a Mon Sep 17 00:00:00 2001
From: Takuya UESHIN <ueshin@happy-camper.st>
Date: Thu, 21 Apr 2016 21:17:56 -0700
Subject: [PATCH] [SPARK-14793] [SQL] Code generation for large complex type
 exceeds JVM size limit.

## What changes were proposed in this pull request?

Code generation for complex type, `CreateArray`, `CreateMap`, `CreateStruct`, `CreateNamedStruct`, exceeds JVM size limit for large elements.

We should split generated code into multiple `apply` functions if the complex types have large elements,  like `UnsafeProjection` or others for large expressions.

## How was this patch tested?

I added some tests to check if the generated codes for the expressions exceed or not.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #12559 from ueshin/issues/SPARK-14793.
---
 .../codegen/GenerateSafeProjection.scala      |   1 +
 .../expressions/complexTypeCreator.scala      | 139 +++++++++++-------
 .../expressions/CodeGenerationSuite.scala     |  57 +++++++
 3 files changed, 144 insertions(+), 53 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 7be57aca33..ee1a363145 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -68,6 +68,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
       this.$values = new Object[${schema.length}];
       $allFields
       final InternalRow $output = new $rowClass($values);
+      this.$values = null;
     """
 
     ExprCode(code, "false", output)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 3d4819c55a..d986d9dca6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -51,20 +51,27 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val arrayClass = classOf[GenericArrayData].getName
     val values = ctx.freshName("values")
+    ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
     ev.copy(code = s"""
       final boolean ${ev.isNull} = false;
-      final Object[] $values = new Object[${children.size}];""" +
-      children.zipWithIndex.map { case (e, i) =>
-        val eval = e.genCode(ctx)
-        eval.code + s"""
-          if (${eval.isNull}) {
-            $values[$i] = null;
-          } else {
-            $values[$i] = ${eval.value};
-          }
-         """
-      }.mkString("\n") +
-      s"final ArrayData ${ev.value} = new $arrayClass($values);")
+      this.$values = new Object[${children.size}];""" +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        children.zipWithIndex.map { case (e, i) =>
+          val eval = e.genCode(ctx)
+          eval.code + s"""
+            if (${eval.isNull}) {
+              $values[$i] = null;
+            } else {
+              $values[$i] = ${eval.value};
+            }
+           """
+        }) +
+      s"""
+        final ArrayData ${ev.value} = new $arrayClass($values);
+        this.$values = null;
+      """)
   }
 
   override def prettyName: String = "array"
@@ -119,34 +126,46 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     val mapClass = classOf[ArrayBasedMapData].getName
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
+    ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
+    ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
+
     val keyData = s"new $arrayClass($keyArray)"
     val valueData = s"new $arrayClass($valueArray)"
     ev.copy(code = s"""
       final boolean ${ev.isNull} = false;
-      final Object[] $keyArray = new Object[${keys.size}];
-      final Object[] $valueArray = new Object[${values.size}];""" +
-      keys.zipWithIndex.map { case (key, i) =>
-        val eval = key.genCode(ctx)
-        s"""
-          ${eval.code}
-          if (${eval.isNull}) {
-            throw new RuntimeException("Cannot use null as map key!");
-          } else {
-            $keyArray[$i] = ${eval.value};
-          }
-        """
-    }.mkString("\n") + values.zipWithIndex.map {
-      case (value, i) =>
-        val eval = value.genCode(ctx)
-        s"""
-          ${eval.code}
-          if (${eval.isNull}) {
-            $valueArray[$i] = null;
-          } else {
-            $valueArray[$i] = ${eval.value};
-          }
-        """
-    }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);")
+      $keyArray = new Object[${keys.size}];
+      $valueArray = new Object[${values.size}];""" +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        keys.zipWithIndex.map { case (key, i) =>
+          val eval = key.genCode(ctx)
+          s"""
+            ${eval.code}
+            if (${eval.isNull}) {
+              throw new RuntimeException("Cannot use null as map key!");
+            } else {
+              $keyArray[$i] = ${eval.value};
+            }
+          """
+        }) +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        values.zipWithIndex.map { case (value, i) =>
+          val eval = value.genCode(ctx)
+          s"""
+            ${eval.code}
+            if (${eval.isNull}) {
+              $valueArray[$i] = null;
+            } else {
+              $valueArray[$i] = ${eval.value};
+            }
+          """
+        }) +
+      s"""
+        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
+        this.$keyArray = null;
+        this.$valueArray = null;
+      """)
   }
 
   override def prettyName: String = "map"
@@ -182,19 +201,26 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rowClass = classOf[GenericInternalRow].getName
     val values = ctx.freshName("values")
+    ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
     ev.copy(code = s"""
       boolean ${ev.isNull} = false;
-      final Object[] $values = new Object[${children.size}];""" +
-      children.zipWithIndex.map { case (e, i) =>
-        val eval = e.genCode(ctx)
-        eval.code + s"""
-          if (${eval.isNull}) {
-            $values[$i] = null;
-          } else {
-            $values[$i] = ${eval.value};
-          }"""
-      }.mkString("\n") +
-      s"final InternalRow ${ev.value} = new $rowClass($values);")
+      this.$values = new Object[${children.size}];""" +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        children.zipWithIndex.map { case (e, i) =>
+          val eval = e.genCode(ctx)
+          eval.code + s"""
+            if (${eval.isNull}) {
+              $values[$i] = null;
+            } else {
+              $values[$i] = ${eval.value};
+            }"""
+        }) +
+      s"""
+        final InternalRow ${ev.value} = new $rowClass($values);
+        this.$values = null;
+      """)
   }
 
   override def prettyName: String = "struct"
@@ -261,19 +287,26 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rowClass = classOf[GenericInternalRow].getName
     val values = ctx.freshName("values")
+    ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
     ev.copy(code = s"""
       boolean ${ev.isNull} = false;
-      final Object[] $values = new Object[${valExprs.size}];""" +
-      valExprs.zipWithIndex.map { case (e, i) =>
-        val eval = e.genCode(ctx)
-        eval.code + s"""
+      $values = new Object[${valExprs.size}];""" +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        valExprs.zipWithIndex.map { case (e, i) =>
+          val eval = e.genCode(ctx)
+          eval.code + s"""
           if (${eval.isNull}) {
             $values[$i] = null;
           } else {
             $values[$i] = ${eval.value};
           }"""
-      }.mkString("\n") +
-      s"final InternalRow ${ev.value} = new $rowClass($values);")
+        }) +
+      s"""
+        final InternalRow ${ev.value} = new $rowClass($values);
+        this.$values = null;
+      """)
   }
 
   override def prettyName: String = "named_struct"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b682e7d2b1..2082cea0f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ThreadUtils
@@ -80,6 +81,62 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual(0) == cases)
   }
 
+  test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") {
+    val length = 5000
+    val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
+    val plan = GenerateMutableProjection.generate(expressions)
+    val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
+    val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))
+
+    if (!checkResult(actual, expected)) {
+      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+    }
+  }
+
+  test("SPARK-14793: split wide map creation into blocks due to JVM code size limit") {
+    val length = 5000
+    val expressions = Seq(CreateMap(
+      List.fill(length)(EqualTo(Literal(1), Literal(1))).zipWithIndex.flatMap {
+        case (expr, i) => Seq(Literal(i), expr)
+      }))
+    val plan = GenerateMutableProjection.generate(expressions)
+    val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
+    val expected = Seq(new ArrayBasedMapData(
+      new GenericArrayData(0 until length),
+      new GenericArrayData(Seq.fill(length)(true))))
+
+    if (!checkResult(actual, expected)) {
+      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+    }
+  }
+
+  test("SPARK-14793: split wide struct creation into blocks due to JVM code size limit") {
+    val length = 5000
+    val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
+    val plan = GenerateMutableProjection.generate(expressions)
+    val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
+    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
+
+    if (!checkResult(actual, expected)) {
+      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+    }
+  }
+
+  test("SPARK-14793: split wide named struct creation into blocks due to JVM code size limit") {
+    val length = 5000
+    val expressions = Seq(CreateNamedStruct(
+      List.fill(length)(EqualTo(Literal(1), Literal(1))).flatMap {
+        expr => Seq(Literal(expr.toString), expr)
+      }))
+    val plan = GenerateMutableProjection.generate(expressions)
+    val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
+    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
+
+    if (!checkResult(actual, expected)) {
+      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+    }
+  }
+
   test("test generated safe and unsafe projection") {
     val schema = new StructType(Array(
       StructField("a", StringType, true),
-- 
GitLab