From 964b507c7511cf3f4383cb0fc4026a573034b8cc Mon Sep 17 00:00:00 2001
From: Bryan Cutler <cutlerb@gmail.com>
Date: Thu, 31 Aug 2017 13:08:52 +0900
Subject: [PATCH] [SPARK-21583][SQL] Create a ColumnarBatch from
 ArrowColumnVectors

## What changes were proposed in this pull request?

This PR allows a `ColumnarBatch` to be created from existing `ReadOnlyColumnVectors`; previously a columnar batch could only allocate its vectors internally. This is useful for wrapping `ArrowColumnVectors` in a batch to do row-based iteration. It also adds `ArrowConverters.fromPayloadIterator`, which converts an `ArrowPayload` iterator to an `InternalRow` iterator and uses a `ColumnarBatch` internally (a short usage sketch follows the patch below).

## How was this patch tested?

Added a new unit test for creating a `ColumnarBatch` with `ReadOnlyColumnVectors`, and a test that verifies the roundtrip rows -> ArrowPayload -> rows using `toPayloadIterator` and `fromPayloadIterator`.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #18787 from BryanCutler/arrow-ColumnarBatch-support-SPARK-21583.
---
 .../sql/execution/arrow/ArrowConverters.scala | 78 +++++++++++++++++++-
 .../arrow/ArrowConvertersSuite.scala          | 29 +++++++-
 .../vectorized/ColumnarBatchSuite.scala       | 54 ++++++++++++++
 3 files changed, 159 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index fa45822311..561a067a2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
 import java.io.ByteArrayOutputStream
 import java.nio.channels.Channels
 
+import scala.collection.JavaConverters._
+
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.file._
@@ -28,6 +30,7 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -35,7 +38,7 @@ import org.apache.spark.util.Utils
 /**
  * Store Arrow data in a form that can be serialized by Spark and served to a Python process.
  */
-private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {
 
   /**
    * Convert the ArrowPayload to an ArrowRecordBatch.
@@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se
   def asPythonSerializable: Array[Byte] = payload
 }
 
+/**
+ * Iterator interface to iterate over Arrow record batches and return rows.
+ */
+private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
+
+  /**
+   * Return the schema loaded from the Arrow record batch being iterated over.
+   */
+  def schema: StructType
+}
+
 private[sql] object ArrowConverters {
 
   /**
@@ -110,6 +124,68 @@ private[sql] object ArrowConverters {
     }
   }
 
+  /**
+   * Maps an iterator of ArrowPayloads to an iterator of InternalRows. The returned
+   * ArrowRowIterator also exposes the schema loaded from the first batch of Arrow data.
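+   * The iterator owns a child Arrow allocator, which is released once the payloads are
+   * fully consumed or when the task completion listener registered on `context` fires.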
+   */
+  private[sql] def fromPayloadIterator(
+      payloadIter: Iterator[ArrowPayload],
+      context: TaskContext): ArrowRowIterator = {
+    val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
+
+    new ArrowRowIterator {
+      private var reader: ArrowFileReader = null
+      private var schemaRead = StructType(Seq.empty)
+      private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
+
+      context.addTaskCompletionListener { _ =>
+        closeReader()
+        allocator.close()
+      }
+
+      override def schema: StructType = schemaRead
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        closeReader()
+        if (payloadIter.hasNext) {
+          rowIter = nextBatch()
+          true
+        } else {
+          allocator.close()
+          false
+        }
+      }
+
+      override def next(): InternalRow = rowIter.next()
+
+      private def closeReader(): Unit = {
+        if (reader != null) {
+          reader.close()
+          reader = null
+        }
+      }
+
+      private def nextBatch(): Iterator[InternalRow] = {
+        val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
+        reader = new ArrowFileReader(in, allocator)
+        reader.loadNextBatch()  // throws IOException
+        val root = reader.getVectorSchemaRoot  // throws IOException
+        schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)
+
+        val columns = root.getFieldVectors.asScala.map { vector =>
+          new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
+        }.toArray
+
+        val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount)
+        batch.setNumRows(root.getRowCount)
+        batch.rowIterator().asScala
+      }
+    }
+  }
+
   /**
    * Convert a byte array to an ArrowRecordBatch.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 4893b52f24..30422b6577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     }
   }
 
+  test("roundtrip payloads") {
+    val inputRows = (0 until 9).map { i =>
+      InternalRow(i)
+    } :+ InternalRow(null)
+
+    val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
+
+    val ctx = TaskContext.empty()
+    val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
+    val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+
+    assert(schema.equals(outputRowIter.schema))
+
+    var count = 0
+    outputRowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i != 9) {
+        assert(row.getInt(0) == i)
+      } else {
+        assert(row.isNullAt(0))
+      }
+      count += 1
+    }
+
+    assert(count == inputRows.length)
+  }
+
   /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
   private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
     // NOTE: coalesce to single partition because can only load 1 batch in validator
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 08ccbd628c..1f21d3c0db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
 
+import org.apache.arrow.vector.NullableIntVector
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -1261,4 +1264,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
         s"vectorized reader"))
     }
   }
+
+  test("create columnar batch from Arrow column vectors") {
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
+    val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
+      .createVector(allocator).asInstanceOf[NullableIntVector]
+    vector1.allocateNew()
+    val mutator1 = vector1.getMutator()
+    val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
+      .createVector(allocator).asInstanceOf[NullableIntVector]
+    vector2.allocateNew()
+    val mutator2 = vector2.getMutator()
+
+    (0 until 10).foreach { i =>
+      mutator1.setSafe(i, i)
+      mutator2.setSafe(i + 1, i)
+    }
+    mutator1.setNull(10)
+    mutator1.setValueCount(11)
+    mutator2.setNull(0)
+    mutator2.setValueCount(11)
+
+    val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))
+
+    val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
+    val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11)
+    batch.setNumRows(11)
+
+    assert(batch.numCols() == 2)
+    assert(batch.numRows() == 11)
+
+    val rowIter = batch.rowIterator().asScala
+    rowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i == 10) {
+        assert(row.isNullAt(0))
+      } else {
+        assert(row.getInt(0) == i)
+      }
+      if (i == 0) {
+        assert(row.isNullAt(1))
+      } else {
+        assert(row.getInt(1) == i - 1)
+      }
+    }
+
+    intercept[java.lang.AssertionError] {
+      batch.getRow(100)
+    }
+
+    batch.close()
+    allocator.close()
+  }
 }
-- 
GitLab
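
For reference, a minimal sketch of how the two new pieces compose, modeled on the `roundtrip payloads` test above. Since `ArrowConverters` and `ArrowPayload` are `private[sql]`, it assumes calling code inside the `org.apache.spark.sql` package; `TaskContext.empty()` and `maxRecordsPerBatch = 0` (no per-batch record limit) are used exactly as in the suite:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val rows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null)

val ctx = TaskContext.empty()
// rows -> ArrowPayload iterator (each payload wraps one serialized Arrow record batch)
val payloads = ArrowConverters.toPayloadIterator(rows.toIterator, schema, 0, ctx)
// ArrowPayload iterator -> row iterator; the schema is loaded from the first batch
val rowIter = ArrowConverters.fromPayloadIterator(payloads, ctx)
assert(rowIter.schema == schema)

// Rows are surfaced batch-by-batch through an internal ColumnarBatch, so consume
// them as you iterate; the row object is reused across next() calls.
var count = 0
rowIter.zipWithIndex.foreach { case (row, i) =>
  if (i < 9) assert(row.getInt(0) == i) else assert(row.isNullAt(0))
  count += 1
}
assert(count == rows.length)
```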