Commit 148a84b3 authored by Kazuaki Ishizaki, committed by Davies Liu

[SPARK-17912] [SQL] Refactor code generation to get data for ColumnVector/ColumnarBatch

## What changes were proposed in this pull request?

This PR refactors the code generation that reads data from `ColumnVector` and `ColumnarBatch` into a trait, `ColumnarBatchScan`, for ease of reuse. This part will be reused by several components (e.g. the Parquet reader, `Dataset.cache`, and others) once `ColumnarBatch` becomes a first-class citizen.

This PR is a part of https://github.com/apache/spark/pull/15219. As a first step, it makes the code generation for `ColumnVector` and `ColumnarBatch` reusable as a trait, which is also useful to other components from a reusability standpoint.
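To illustrate the intended reuse, here is a minimal, hypothetical sketch of how a physical operator would pick up the whole-stage codegen loop just by mixing in the new trait. The operator name `MyColumnarScanExec` and the RDD plumbing are invented for this example and are not part of this PR:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

// Hypothetical operator, invented for this example; assumes it lives in
// package org.apache.spark.sql.execution so that LeafExecNode and
// ColumnarBatchScan are in scope.
case class MyColumnarScanExec(
    output: Seq[Attribute],
    batchRDD: RDD[InternalRow]) // elements are really ColumnarBatch instances
  extends LeafExecNode with ColumnarBatchScan {

  // Whole-stage codegen pulls batches from this RDD; the generated code
  // (from the trait's doProduce) casts each element to ColumnarBatch.
  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(batchRDD)

  // Interpreted (non-codegen) fallback, elided in this sketch.
  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException("sketch: codegen path only")
}
```

The operator only supplies the batch-producing RDD; `doProduce()` comes from the trait.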
## How was this patch tested?

Tested with existing test suites.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15467 from kiszk/columnarrefactor.
parent 63d83902
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.DataType


/**
 * Helper trait for abstracting scan functionality using
 * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es.
 */
private[sql] trait ColumnarBatchScan extends CodegenSupport {

  val inMemoryTableScan: InMemoryTableScanExec = null

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))

  /**
   * Generate [[ColumnVector]] expressions for our parent to consume as rows.
   * This is called once per [[ColumnarBatch]].
   */
  private def genCodeColumnVector(
      ctx: CodegenContext,
      columnVar: String,
      ordinal: String,
      dataType: DataType,
      nullable: Boolean): ExprCode = {
    val javaType = ctx.javaType(dataType)
    val value = ctx.getValue(columnVar, dataType, ordinal)
    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
    val valueVar = ctx.freshName("value")
    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
      s"""
        boolean $isNullVar = $columnVar.isNullAt($ordinal);
        $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
      """
    } else {
      s"$javaType $valueVar = $value;"
    }).trim
    ExprCode(code, isNullVar, valueVar)
  }

  /**
   * Produce code to process the input iterator as [[ColumnarBatch]]es.
   * This produces an [[UnsafeRow]] for each row in each batch.
   */
  // TODO: return ColumnarBatch.Rows instead
  override protected def doProduce(ctx: CodegenContext): String = {
    val input = ctx.freshName("input")
    // PhysicalRDD always just has one input
    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")

    // metrics
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    val scanTimeMetric = metricTerm(ctx, "scanTime")
    val scanTimeTotalNs = ctx.freshName("scanTime")
    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")

    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
    val batch = ctx.freshName("batch")
    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")

    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
    val idx = ctx.freshName("batchIdx")
    ctx.addMutableState("int", idx, s"$idx = 0;")
    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
      s"$name = $batch.column($i);"
    }

    val nextBatch = ctx.freshName("nextBatch")
    ctx.addNewFunction(nextBatch,
      s"""
         |private void $nextBatch() throws java.io.IOException {
         |  long getBatchStart = System.nanoTime();
         |  if ($input.hasNext()) {
         |    $batch = ($columnarBatchClz)$input.next();
         |    $numOutputRows.add($batch.numRows());
         |    $idx = 0;
         |    ${columnAssigns.mkString("", "\n", "\n")}
         |  }
         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
         |}""".stripMargin)

    ctx.currentVars = null
    val rowidx = ctx.freshName("rowIdx")
    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
    }
    s"""
       |if ($batch == null) {
       |  $nextBatch();
       |}
       |while ($batch != null) {
       |  int numRows = $batch.numRows();
       |  while ($idx < numRows) {
       |    int $rowidx = $idx++;
       |    ${consume(ctx, columnsBatchInput).trim}
       |    if (shouldStop()) return;
       |  }
       |  $batch = null;
       |  $nextBatch();
       |}
       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
       |$scanTimeTotalNs = 0;
     """.stripMargin
  }
}
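For intuition, here is a rough, hypothetical illustration of what this trait emits. Assuming a schema with a single nullable `IntegerType` column, the Java assembled by `doProduce` and `genCodeColumnVector` looks approximately like the string below. All identifier suffixes come from `ctx.freshName`, so the real names differ; this is a sketch, not captured codegen output:

```scala
object GeneratedCodeSketch {
  // Illustration only: approximate Java that the trait's doProduce emits for
  // one nullable IntegerType column. Identifiers such as colInstance0,
  // batchIdx, and rowIdx mimic ctx.freshName output; real suffixes vary.
  val illustrativeGeneratedJava: String =
    """
      |if (batch == null) { nextBatch(); }
      |while (batch != null) {
      |  int numRows = batch.numRows();
      |  while (batchIdx < numRows) {
      |    int rowIdx = batchIdx++;
      |    // from genCodeColumnVector: null check, then the typed getter
      |    boolean isNull = colInstance0.isNullAt(rowIdx);
      |    int value = isNull ? -1 : (colInstance0.getInt(rowIdx));
      |    /* parent operator's consume() code is inlined here */
      |    if (shouldStop()) return;
      |  }
      |  batch = null;
      |  nextBatch();
      |}""".stripMargin
}
```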
@@ -145,7 +145,7 @@ case class FileSourceScanExec(
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Filter],
     override val metastoreTableIdentifier: Option[TableIdentifier])
-  extends DataSourceScanExec {
+  extends DataSourceScanExec with ColumnarBatchScan {
 
   val supportsBatch: Boolean = relation.fileFormat.supportBatch(
     relation.sparkSession, StructType.fromAttributes(output))
@@ -312,7 +312,7 @@ case class FileSourceScanExec(
   override protected def doProduce(ctx: CodegenContext): String = {
     if (supportsBatch) {
-      return doProduceVectorized(ctx)
+      return super.doProduce(ctx)
     }
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     // PhysicalRDD always just has one input
@@ -336,88 +336,6 @@ case class FileSourceScanExec(
     """.stripMargin
   }
 
-  // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
-  // never requires UnsafeRow as input.
-  private def doProduceVectorized(ctx: CodegenContext): String = {
-    val input = ctx.freshName("input")
-    // PhysicalRDD always just has one input
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-
-    // metrics
-    val numOutputRows = metricTerm(ctx, "numOutputRows")
-    val scanTimeMetric = metricTerm(ctx, "scanTime")
-    val scanTimeTotalNs = ctx.freshName("scanTime")
-    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
-
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
-    val batch = ctx.freshName("batch")
-    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
-
-    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
-    val idx = ctx.freshName("batchIdx")
-    ctx.addMutableState("int", idx, s"$idx = 0;")
-    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
-    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
-      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-      s"$name = $batch.column($i);"
-    }
-
-    val nextBatch = ctx.freshName("nextBatch")
-    ctx.addNewFunction(nextBatch,
-      s"""
-         |private void $nextBatch() throws java.io.IOException {
-         |  long getBatchStart = System.nanoTime();
-         |  if ($input.hasNext()) {
-         |    $batch = ($columnarBatchClz)$input.next();
-         |    $numOutputRows.add($batch.numRows());
-         |    $idx = 0;
-         |    ${columnAssigns.mkString("", "\n", "\n")}
-         |  }
-         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
-         |}""".stripMargin)
-
-    ctx.currentVars = null
-    val rowidx = ctx.freshName("rowIdx")
-    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
-      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
-    }
-    s"""
-       |if ($batch == null) {
-       |  $nextBatch();
-       |}
-       |while ($batch != null) {
-       |  int numRows = $batch.numRows();
-       |  while ($idx < numRows) {
-       |    int $rowidx = $idx++;
-       |    ${consume(ctx, columnsBatchInput).trim}
-       |    if (shouldStop()) return;
-       |  }
-       |  $batch = null;
-       |  $nextBatch();
-       |}
-       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
-       |$scanTimeTotalNs = 0;
-     """.stripMargin
-  }
-
-  private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
-      dataType: DataType, nullable: Boolean): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(columnVar, dataType, ordinal)
-    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
-    val valueVar = ctx.freshName("value")
-    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
-    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
-      s"""
-        boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
-        $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
-      """
-    } else {
-      s"$javaType ${valueVar} = $value;"
-    }).trim
-    ExprCode(code, isNullVar, valueVar)
-  }
-
   /**
    * Create an RDD for bucketed reads.
    * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
...