diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 74a47da2deef2d941f1f02d88063026704f6df1c..1afe83ea3539e6218a2ed424beea652cd49d958a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -33,6 +33,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
 
   val inMemoryTableScan: InMemoryTableScanExec = null
 
+  def vectorTypes: Option[Seq[String]] = None
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -79,17 +81,19 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     val scanTimeTotalNs = ctx.freshName("scanTime")
     ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
 
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+    val columnarBatchClz = classOf[ColumnarBatch].getName
     val batch = ctx.freshName("batch")
     ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
 
-    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
     val idx = ctx.freshName("batchIdx")
     ctx.addMutableState("int", idx, s"$idx = 0;")
     val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
-    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
-      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-      s"$name = $batch.column($i);"
+    val columnVectorClzs = vectorTypes.getOrElse(
+      Seq.fill(colVars.size)(classOf[ColumnVector].getName))
+    val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map {
+      case ((name, columnVectorClz), i) =>
+        ctx.addMutableState(columnVectorClz, name, s"$name = null;")
+        s"$name = ($columnVectorClz) $batch.column($i);"
     }
 
     val nextBatch = ctx.freshName("nextBatch")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 588c937a13e4560f8b1b02efffd742c84a8f776d..77e6dbf6364762cb7c2fbf7299be6b45a81c9ccf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -174,6 +174,11 @@ case class FileSourceScanExec(
       false
   }
 
+  override def vectorTypes: Option[Seq[String]] =
+    relation.fileFormat.vectorTypes(
+      requiredSchema = requiredSchema,
+      partitionSchema = relation.partitionSchema)
+
   @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = {
     val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
     val startTime = System.nanoTime()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index dacf4629535202a57bdf4cdf879bee0845e50361..e5a7aee64a4f4cba1f412ac743d489a716a5494c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -64,6 +64,16 @@ trait FileFormat {
     false
   }
 
+  /**
+   * Returns concrete column vector class names for each column to be used in a columnar batch
+   * if this format supports returning columnar batches.
+   */
+  def vectorTypes(
+      requiredSchema: StructType,
+      partitionSchema: StructType): Option[Seq[String]] = {
+    None
+  }
+
   /**
    * Returns whether a file with `path` could be splitted or not.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 64eea26a9f98e19852c17e2715cc3873807367bc..e1e740500205a1b62d54de5a5951e72067672031 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -272,6 +273,13 @@ class ParquetFileFormat
     schema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
+  override def vectorTypes(
+      requiredSchema: StructType,
+      partitionSchema: StructType): Option[Seq[String]] = {
+    Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
+      classOf[OnHeapColumnVector].getName))
+  }
+
   override def isSplitable(
       sparkSession: SparkSession,
       options: Map[String, String],
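Usage sketch (not part of the patch): a file format can plug into the new hook by returning the concrete ColumnVector subclass names its reader actually places in the ColumnarBatch, which lets the generated code in ColumnarBatchScan cast each column to that class instead of going through the abstract ColumnVector interface. The class name MyOffHeapParquetFileFormat and the choice of OffHeapColumnVector below are illustrative assumptions, not something this patch adds; the names returned must match the vectors the reader really produces, or the generated "(clazz) batch.column(i)" cast will fail at runtime.

// Hypothetical sketch against a Spark build that contains this patch.
// MyOffHeapParquetFileFormat is an illustrative name; OffHeapColumnVector is a real
// Spark class, but whether a given reader fills the batch with it is up to that reader.
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
import org.apache.spark.sql.types.StructType

class MyOffHeapParquetFileFormat extends ParquetFileFormat {
  override def vectorTypes(
      requiredSchema: StructType,
      partitionSchema: StructType): Option[Seq[String]] = {
    // One class name per output column: required data columns first, then partition
    // columns, mirroring the column order of the batch produced by the scan.
    Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
      classOf[OffHeapColumnVector].getName))
  }
}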