Commit 964b507c authored by Bryan Cutler, committed by Takuya UESHIN

[SPARK-21583][SQL] Create a ColumnarBatch from ArrowColumnVectors

## What changes were proposed in this pull request?

This PR allows the creation of a `ColumnarBatch` from `ReadOnlyColumnVector`s, where previously a columnar batch could only allocate its vectors internally. This is useful for wrapping `ArrowColumnVector`s in a batch to do row-based iteration. It also adds `ArrowConverters.fromPayloadIterator`, which converts an `ArrowPayload` iterator to an `InternalRow` iterator and uses a `ColumnarBatch` internally.
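
As a quick illustration (not part of the diff), here is a minimal sketch of the first change, condensed from the new `ColumnarBatchSuite` test below. The field name `"i"` and the allocator name are placeholders, and the snippet is written as it would appear inside a test body under `org.apache.spark.sql.execution`, since these classes are package-private:

```scala
import scala.collection.JavaConverters._

import org.apache.arrow.vector.NullableIntVector

import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types._

// Allocate and fill an Arrow vector outside of ColumnarBatch.
val allocator = ArrowUtils.rootAllocator.newChildAllocator("example", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("i", IntegerType, nullable = true)
  .createVector(allocator).asInstanceOf[NullableIntVector]
vector.allocateNew()
val mutator = vector.getMutator()
(0 until 10).foreach(i => mutator.setSafe(i, i))
mutator.setValueCount(10)

// New in this PR: build a ColumnarBatch directly from pre-populated column vectors.
val schema = StructType(Seq(StructField("i", IntegerType)))
val batch = new ColumnarBatch(schema, Array[ColumnVector](new ArrowColumnVector(vector)), 10)
batch.setNumRows(10)

// Row-based iteration over the Arrow-backed batch.
batch.rowIterator().asScala.foreach(row => assert(row.getInt(0) < 10))

batch.close()
allocator.close()
```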

## How was this patch tested?

Added a new unit test for creating a `ColumnarBatch` with `ReadOnlyColumnVectors` and a test to verify the roundtrip of rows -> ArrowPayload -> rows, using `toPayloadIterator` and `fromPayloadIterator`.
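
For reference, a sketch of that round trip, condensed from the new test in `ArrowConvertersSuite` below (the helpers are `private[sql]`, so this only works inside the `org.apache.spark.sql` package tree, e.g. in a test body):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val inputRows = (0 until 10).map(i => InternalRow(i))
val ctx = TaskContext.empty()

// rows -> ArrowPayload (same arguments as the new test below).
val payloads = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)

// ArrowPayload -> rows, backed internally by a ColumnarBatch of ArrowColumnVectors.
val outputRows = ArrowConverters.fromPayloadIterator(payloads, ctx)
assert(outputRows.schema == schema)  // schema is loaded from the first Arrow batch

var i = 0
outputRows.foreach { row =>
  assert(row.getInt(0) == i)
  i += 1
}
assert(i == inputRows.length)
```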

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #18787 from BryanCutler/arrow-ColumnarBatch-support-SPARK-21583.
parent ecf437a6
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
import scala.collection.JavaConverters._
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.file._
@@ -28,6 +30,7 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -35,7 +38,7 @@ import org.apache.spark.util.Utils
/**
 * Store Arrow data in a form that can be serialized by Spark and served to a Python process.
 */
-private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {

  /**
   * Convert the ArrowPayload to an ArrowRecordBatch.
@@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
  def asPythonSerializable: Array[Byte] = payload
}

/**
 * Iterator interface to iterate over Arrow record batches and return rows
 */
private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {

  /**
   * Return the schema loaded from the Arrow record batch being iterated over
   */
  def schema: StructType
}

private[sql] object ArrowConverters {

  /**
@@ -110,6 +124,66 @@
    }
  }

  /**
   * Maps an iterator of ArrowPayload to an ArrowRowIterator, a row iterator that also exposes
   * the schema loaded from the first batch of Arrow data read.
   */
  private[sql] def fromPayloadIterator(
      payloadIter: Iterator[ArrowPayload],
      context: TaskContext): ArrowRowIterator = {
    val allocator =
      ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)

    new ArrowRowIterator {
      private var reader: ArrowFileReader = null
      private var schemaRead = StructType(Seq.empty)
      private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty

      context.addTaskCompletionListener { _ =>
        closeReader()
        allocator.close()
      }

      override def schema: StructType = schemaRead

      override def hasNext: Boolean = rowIter.hasNext || {
        closeReader()
        if (payloadIter.hasNext) {
          rowIter = nextBatch()
          true
        } else {
          allocator.close()
          false
        }
      }

      override def next(): InternalRow = rowIter.next()

      private def closeReader(): Unit = {
        if (reader != null) {
          reader.close()
          reader = null
        }
      }

      private def nextBatch(): Iterator[InternalRow] = {
        val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
        reader = new ArrowFileReader(in, allocator)
        reader.loadNextBatch()  // throws IOException
        val root = reader.getVectorSchemaRoot  // throws IOException
        schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)

        val columns = root.getFieldVectors.asScala.map { vector =>
          new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
        }.toArray

        val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount)
        batch.setNumRows(root.getRowCount)
        batch.rowIterator().asScala
      }
    }
  }

  /**
   * Convert a byte array to an ArrowRecordBatch.
   */
...

@@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
    }
  }
test("roundtrip payloads") {
val inputRows = (0 until 9).map { i =>
InternalRow(i)
} :+ InternalRow(null)
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
assert(schema.equals(outputRowIter.schema))
var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
if (i != 9) {
assert(row.getInt(0) == i)
} else {
assert(row.isNullAt(0))
}
count += 1
}
assert(count == inputRows.length)
}

  /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
  private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
    // NOTE: coalesce to single partition because can only load 1 batch in validator
...

@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
import org.apache.arrow.vector.NullableIntVector
import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
@@ -1261,4 +1264,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
        s"vectorized reader"))
    }
  }
test("create columnar batch from Arrow column vectors") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector1.allocateNew()
val mutator1 = vector1.getMutator()
val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector2.allocateNew()
val mutator2 = vector2.getMutator()
(0 until 10).foreach { i =>
mutator1.setSafe(i, i)
mutator2.setSafe(i + 1, i)
}
mutator1.setNull(10)
mutator1.setValueCount(11)
mutator2.setNull(0)
mutator2.setValueCount(11)
val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))
val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11)
batch.setNumRows(11)
assert(batch.numCols() == 2)
assert(batch.numRows() == 11)
val rowIter = batch.rowIterator().asScala
rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
if (i == 0) {
assert(row.isNullAt(1))
} else {
assert(row.getInt(1) == i - 1)
}
}
intercept[java.lang.AssertionError] {
batch.getRow(100)
}
batch.close()
allocator.close()
}
}