From 148a84b37082697c7f61c6a621010abe4b12f2eb Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Date: Thu, 19 Jan 2017 15:16:05 -0800
Subject: [PATCH] [SPARK-17912] [SQL] Refactor code generation to get data for
 ColumnVector/ColumnarBatch

## What changes were proposed in this pull request?

This PR refactors the code generation part to get data from `ColumnarVector` and `ColumnarBatch` by using a trait `ColumnarBatchScan` for ease of reuse. This is because this part will be reused by several components (e.g. parquet reader, Dataset.cache, and others) since `ColumnarBatch` will be first citizen.

This PR is a part of https://github.com/apache/spark/pull/15219. In advance, this PR makes the code generation for  `ColumnarVector` and `ColumnarBatch` reuseable as a trait. In general, this is very useful for other components from the reuseability view, too.
## How was this patch tested?

tested existing test suites

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15467 from kiszk/columnarrefactor.
---
 .../sql/execution/ColumnarBatchScan.scala     | 133 ++++++++++++++++++
 .../sql/execution/DataSourceScanExec.scala    |  86 +----------
 2 files changed, 135 insertions(+), 84 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
new file mode 100644
index 0000000000..04fba17be4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.DataType
+
+
+/**
+ * Helper trait for abstracting scan functionality using
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es.
+ */
+private[sql] trait ColumnarBatchScan extends CodegenSupport {
+
+  val inMemoryTableScan: InMemoryTableScanExec = null
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+
+  /**
+   * Generate [[ColumnVector]] expressions for our parent to consume as rows.
+   * This is called once per [[ColumnarBatch]].
+   */
+  private def genCodeColumnVector(
+      ctx: CodegenContext,
+      columnVar: String,
+      ordinal: String,
+      dataType: DataType,
+      nullable: Boolean): ExprCode = {
+    val javaType = ctx.javaType(dataType)
+    val value = ctx.getValue(columnVar, dataType, ordinal)
+    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
+    val valueVar = ctx.freshName("value")
+    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
+      s"""
+        boolean $isNullVar = $columnVar.isNullAt($ordinal);
+        $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
+      """
+    } else {
+      s"$javaType $valueVar = $value;"
+    }).trim
+    ExprCode(code, isNullVar, valueVar)
+  }
+
+  /**
+   * Produce code to process the input iterator as [[ColumnarBatch]]es.
+   * This produces an [[UnsafeRow]] for each row in each batch.
+   */
+  // TODO: return ColumnarBatch.Rows instead
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val input = ctx.freshName("input")
+    // PhysicalRDD always just has one input
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+    // metrics
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val scanTimeMetric = metricTerm(ctx, "scanTime")
+    val scanTimeTotalNs = ctx.freshName("scanTime")
+    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+
+    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+    val batch = ctx.freshName("batch")
+    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+
+    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
+    val idx = ctx.freshName("batchIdx")
+    ctx.addMutableState("int", idx, s"$idx = 0;")
+    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
+    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
+      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
+      s"$name = $batch.column($i);"
+    }
+
+    val nextBatch = ctx.freshName("nextBatch")
+    ctx.addNewFunction(nextBatch,
+      s"""
+         |private void $nextBatch() throws java.io.IOException {
+         |  long getBatchStart = System.nanoTime();
+         |  if ($input.hasNext()) {
+         |    $batch = ($columnarBatchClz)$input.next();
+         |    $numOutputRows.add($batch.numRows());
+         |    $idx = 0;
+         |    ${columnAssigns.mkString("", "\n", "\n")}
+         |  }
+         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
+         |}""".stripMargin)
+
+    ctx.currentVars = null
+    val rowidx = ctx.freshName("rowIdx")
+    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+    }
+    s"""
+       |if ($batch == null) {
+       |  $nextBatch();
+       |}
+       |while ($batch != null) {
+       |  int numRows = $batch.numRows();
+       |  while ($idx < numRows) {
+       |    int $rowidx = $idx++;
+       |    ${consume(ctx, columnsBatchInput).trim}
+       |    if (shouldStop()) return;
+       |  }
+       |  $batch = null;
+       |  $nextBatch();
+       |}
+       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
+       |$scanTimeTotalNs = 0;
+     """.stripMargin
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 7616164397..39b010efec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -145,7 +145,7 @@ case class FileSourceScanExec(
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Filter],
     override val metastoreTableIdentifier: Option[TableIdentifier])
-  extends DataSourceScanExec {
+  extends DataSourceScanExec with ColumnarBatchScan  {
 
   val supportsBatch: Boolean = relation.fileFormat.supportBatch(
     relation.sparkSession, StructType.fromAttributes(output))
@@ -312,7 +312,7 @@ case class FileSourceScanExec(
 
   override protected def doProduce(ctx: CodegenContext): String = {
     if (supportsBatch) {
-      return doProduceVectorized(ctx)
+      return super.doProduce(ctx)
     }
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     // PhysicalRDD always just has one input
@@ -336,88 +336,6 @@ case class FileSourceScanExec(
      """.stripMargin
   }
 
-  // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
-  // never requires UnsafeRow as input.
-  private def doProduceVectorized(ctx: CodegenContext): String = {
-    val input = ctx.freshName("input")
-    // PhysicalRDD always just has one input
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-
-    // metrics
-    val numOutputRows = metricTerm(ctx, "numOutputRows")
-    val scanTimeMetric = metricTerm(ctx, "scanTime")
-    val scanTimeTotalNs = ctx.freshName("scanTime")
-    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
-
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
-    val batch = ctx.freshName("batch")
-    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
-
-    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
-    val idx = ctx.freshName("batchIdx")
-    ctx.addMutableState("int", idx, s"$idx = 0;")
-    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
-    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
-      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-      s"$name = $batch.column($i);"
-    }
-
-    val nextBatch = ctx.freshName("nextBatch")
-    ctx.addNewFunction(nextBatch,
-      s"""
-         |private void $nextBatch() throws java.io.IOException {
-         |  long getBatchStart = System.nanoTime();
-         |  if ($input.hasNext()) {
-         |    $batch = ($columnarBatchClz)$input.next();
-         |    $numOutputRows.add($batch.numRows());
-         |    $idx = 0;
-         |    ${columnAssigns.mkString("", "\n", "\n")}
-         |  }
-         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
-         |}""".stripMargin)
-
-    ctx.currentVars = null
-    val rowidx = ctx.freshName("rowIdx")
-    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
-      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
-    }
-    s"""
-       |if ($batch == null) {
-       |  $nextBatch();
-       |}
-       |while ($batch != null) {
-       |  int numRows = $batch.numRows();
-       |  while ($idx < numRows) {
-       |    int $rowidx = $idx++;
-       |    ${consume(ctx, columnsBatchInput).trim}
-       |    if (shouldStop()) return;
-       |  }
-       |  $batch = null;
-       |  $nextBatch();
-       |}
-       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
-       |$scanTimeTotalNs = 0;
-     """.stripMargin
-  }
-
-  private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
-    dataType: DataType, nullable: Boolean): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(columnVar, dataType, ordinal)
-    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
-    val valueVar = ctx.freshName("value")
-    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
-    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
-      s"""
-        boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
-        $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
-      """
-    } else {
-      s"$javaType ${valueVar} = $value;"
-    }).trim
-    ExprCode(code, isNullVar, valueVar)
-  }
-
   /**
    * Create an RDD for bucketed reads.
    * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
-- 
GitLab