Commit c564b274 authored by Yin Huai, committed by Reynold Xin

[SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow

https://issues.apache.org/jira/browse/SPARK-9753

This PR makes TungstenAggregate accept `InternalRow` instead of just `UnsafeRow`. It also adds a `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`, which is useful when the grouping keys are already stored in `UnsafeRow`s. Finally, it wraps the `InputStream` and `OutputStream` used by `UnsafeRowSerializer` in a `BufferedInputStream` and a `BufferedOutputStream`, respectively.

Author: Yin Huai <yhuai@databricks.com>

Closes #8041 from yhuai/joinedRowForProjection and squashes the following commits:

7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream.
d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner.
e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the given groupingKeyRow is already an UnsafeRow.
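
As background for the serializer hunks below, here is a minimal sketch of the buffering pattern described in the commit message, using in-memory java.io streams in place of Spark's shuffle streams (the values are illustrative). DataOutputStream and DataInputStream issue many small operations (one writeInt/readInt per row), and interposing buffered streams keeps those small calls off the underlying stream:

    import java.io._

    // Sketch only: in-memory streams stand in for the shuffle streams.
    val sink = new ByteArrayOutputStream()
    val dOut = new DataOutputStream(new BufferedOutputStream(sink))
    dOut.writeInt(42)   // absorbed by the in-memory buffer instead of hitting `sink`
    dOut.flush()        // pushes the buffered bytes through to `sink`

    val dIn = new DataInputStream(
      new BufferedInputStream(new ByteArrayInputStream(sink.toByteArray)))
    assert(dIn.readInt() == 42)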
parent 998f4ff9
UnsafeFixedWidthAggregationMap.java

@@ -121,6 +121,10 @@ public final class UnsafeFixedWidthAggregationMap {
   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
     final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+    return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
+  }
+
+  public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
     // Probe our map using the serialized key
     final BytesToBytesMap.Location loc = map.lookup(
       unsafeGroupingKeyRow.getBaseObject(),
...
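A hypothetical sketch of the two-entry-point pattern this hunk introduces; Row, Key, toKey, and the mutable map are illustrative stand-ins, not Spark's types. Callers holding an arbitrary row pay for a projection into the key format, while callers that already hold a key in that format probe the map directly:

    case class Row(values: Seq[Any])
    case class Key(values: Seq[Any])

    class AggregationMapSketch(toKey: Row => Key) {
      private val buffers = scala.collection.mutable.Map.empty[Key, Array[Long]]

      // Entry point for arbitrary rows: project into the key format first.
      def getAggregationBuffer(row: Row): Array[Long] =
        getAggregationBufferFromKey(toKey(row))

      // Entry point for callers that already hold a key: skip the projection.
      def getAggregationBufferFromKey(key: Key): Array[Long] =
        buffers.getOrElseUpdate(key, Array.fill(1)(0L))
    }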
UnsafeRowSerializer.scala

@@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
-    // When `out` is backed by ChainedBufferOutputStream, we will get an
-    // UnsupportedOperationException when we call dOut.writeInt because it internally calls
-    // ChainedBufferOutputStream's write(b: Int), which is not supported.
-    // To workaround this issue, we create an array for sorting the int value.
-    // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
-    // run SparkSqlSerializer2SortMergeShuffleSuite.
-    private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
-    private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+    private[this] val dOut: DataOutputStream =
+      new DataOutputStream(new BufferedOutputStream(out))

     override def writeValue[T: ClassTag](value: T): SerializationStream = {
       val row = value.asInstanceOf[UnsafeRow]
-      val size = row.getSizeInBytes
-      // This part is based on DataOutputStream's writeInt.
-      // It is for dOut.writeInt(row.getSizeInBytes).
-      intBuffer(0) = ((size >>> 24) & 0xFF).toByte
-      intBuffer(1) = ((size >>> 16) & 0xFF).toByte
-      intBuffer(2) = ((size >>> 8) & 0xFF).toByte
-      intBuffer(3) = ((size >>> 0) & 0xFF).toByte
-      dOut.write(intBuffer, 0, 4)
-      row.writeToStream(out, writeBuffer)
+
+      dOut.writeInt(row.getSizeInBytes)
+      row.writeToStream(dOut, writeBuffer)
       this
     }
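The deleted workaround existed because ChainedBufferOutputStream threw UnsupportedOperationException on single-byte write(b: Int) calls, which DataOutputStream.writeInt uses internally; with a BufferedOutputStream in between, writeInt only ever sees the buffer. As a quick sanity check (plain java.io, nothing Spark-specific), writeInt emits exactly the four big-endian bytes the removed code assembled by hand:

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    val size = 123456789
    val manual = Array(                    // the removed workaround's byte logic
      ((size >>> 24) & 0xFF).toByte,
      ((size >>> 16) & 0xFF).toByte,
      ((size >>> 8) & 0xFF).toByte,
      ((size >>> 0) & 0xFF).toByte)

    val sink = new ByteArrayOutputStream()
    new DataOutputStream(sink).writeInt(size)
    assert(sink.toByteArray.sameElements(manual))   // identical big-endian encoding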
@@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
     override def close(): Unit = {
       writeBuffer = null
-      intBuffer = null
       dOut.writeInt(EOF)
       dOut.close()
     }
@@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
-      private[this] val dIn: DataInputStream = new DataInputStream(in)
+      private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
       // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
@@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
           if (rowBuffer.length < rowSize) {
             rowBuffer = new Array[Byte](rowSize)
           }
-          ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+          ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
           row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
           rowSize = dIn.readInt() // read the next row's size
           if (rowSize == EOF) { // We are returning the last row in this stream
@@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
         if (rowBuffer.length < rowSize) {
           rowBuffer = new Array[Byte](rowSize)
         }
-        ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+        ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
         row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
         row.asInstanceOf[T]
       }
...
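Note that both read paths above switch from `in` to `dIn`. This is required, not cosmetic: once a BufferedInputStream wraps `in`, it may pull more bytes out of `in` than the caller has consumed, so bypassing the wrapper and reading the raw `in` would skip data. A self-contained sketch of the failure mode (plain java.io, made-up byte values):

    import java.io.{BufferedInputStream, ByteArrayInputStream}

    val in = new ByteArrayInputStream(Array.tabulate[Byte](100)(_.toByte))
    val buffered = new BufferedInputStream(in, 64)
    assert(buffered.read() == 0)  // one logical byte consumed...
    assert(in.read() == 64)       // ...but the buffer already drained 64 bytes from `in`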
TungstenAggregate.scala

@@ -39,7 +39,7 @@ case class TungstenAggregate(
   override def canProcessUnsafeRows: Boolean = true

-  override def canProcessSafeRows: Boolean = false
+  override def canProcessSafeRows: Boolean = true

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -77,7 +77,7 @@ case class TungstenAggregate(
         resultExpressions,
         newMutableProjection,
         child.output,
-        iter.asInstanceOf[Iterator[UnsafeRow]],
+        iter,
         testFallbackStartsAt)

       if (!hasInput && groupingExpressions.isEmpty) {
...
TungstenAggregationIterator.scala

@@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.types.StructType
@@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType
  *    processing input rows from inputIter, and generating output
  *    rows.
  *  - Part 3: Methods and fields used by hash-based aggregation.
- *  - Part 4: The function used to switch this iterator from hash-based
- *    aggregation to sort-based aggregation.
+ *  - Part 4: Methods and fields used when we switch to sort-based aggregation.
  *  - Part 5: Methods and fields used by sort-based aggregation.
  *  - Part 6: Loads input and process input rows.
  *  - Part 7: Public methods of this iterator.
@@ -82,7 +82,7 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
-    inputIter: Iterator[UnsafeRow],
+    inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[Int])
   extends Iterator[UnsafeRow] with Logging {
@@ -174,13 +174,10 @@ class TungstenAggregationIterator(
   // Creates a function used to process a row based on the given inputAttributes.
   private def generateProcessRow(
-      inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+      inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
     val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
-    val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
-    val inputSchema = StructType.fromAttributes(inputAttributes)
-    val unsafeRowJoiner =
-      GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+    val joinedRow = new JoinedRow()

     aggregationMode match {
       // Partial-only
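For intuition, a rough sketch of the trade-off behind replacing the codegen'd joiner with JoinedRow; the classes here are illustrative, not Spark's. GenerateUnsafeRowJoiner produces code that copies two UnsafeRows into one new fixed-layout buffer, so both inputs must be UnsafeRows; a JoinedRow-style object is just a mutable view over any two rows, so it also accepts a plain InternalRow on the right side:

    trait RowLike { def numFields: Int; def get(i: Int): Any }

    final class SeqRow(values: IndexedSeq[Any]) extends RowLike {
      def numFields: Int = values.length
      def get(i: Int): Any = values(i)
    }

    // A view over two rows: no copying, field lookups are routed by index.
    final class JoinedRowSketch extends RowLike {
      private var left: RowLike = _
      private var right: RowLike = _
      def apply(l: RowLike, r: RowLike): this.type = { left = l; right = r; this }
      def numFields: Int = left.numFields + right.numFields
      def get(i: Int): Any =
        if (i < left.numFields) left.get(i) else right.get(i - left.numFields)
    }

    val joinedRow = new JoinedRowSketch
    val view = joinedRow(new SeqRow(Vector(1L, 2L)), new SeqRow(Vector("a")))
    assert(view.get(2) == "a")   // third field comes from the right-hand row

The cost is an extra indirection per field access instead of one up-front copy, which is consistent with the result-projection hunk further down keeping the codegen'd joiner in the Partial branch, where both sides are known to be UnsafeRows.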
@@ -189,9 +186,9 @@ class TungstenAggregationIterator(
         val algebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           algebraicUpdateProjection.target(currentBuffer)
-          algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+          algebraicUpdateProjection(joinedRow(currentBuffer, row))
         }

       // PartialMerge-only or Final-only
@@ -203,10 +200,10 @@ class TungstenAggregationIterator(
             mergeExpressions,
             aggregationBufferAttributes ++ inputAttributes)()

-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           // Process all algebraic aggregate functions.
           algebraicMergeProjection.target(currentBuffer)
-          algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+          algebraicMergeProjection(joinedRow(currentBuffer, row))
         }

       // Final-Complete
@@ -233,8 +230,8 @@ class TungstenAggregationIterator(
         val completeAlgebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
-          val input = unsafeRowJoiner.join(currentBuffer, row)
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
+          val input = joinedRow(currentBuffer, row)
           // For all aggregate functions with mode Complete, update the given currentBuffer.
           completeAlgebraicUpdateProjection.target(currentBuffer)(input)
@@ -253,14 +250,14 @@ class TungstenAggregationIterator(
         val completeAlgebraicUpdateProjection =
           newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

-        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+        (currentBuffer: UnsafeRow, row: InternalRow) => {
           completeAlgebraicUpdateProjection.target(currentBuffer)
           // For all aggregate functions with mode Complete, update the given currentBuffer.
-          completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+          completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
         }

       // Grouping only.
-      case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+      case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}

       case other =>
         throw new IllegalStateException(
@@ -272,15 +269,16 @@ class TungstenAggregationIterator(
   private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
     val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
-    val bufferSchema = StructType.fromAttributes(bufferAttributes)
-    val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)

     aggregationMode match {
       // Partial-only or PartialMerge-only: every output row is basically the values of
       // the grouping expressions and the corresponding aggregation buffer.
       case (Some(Partial), None) | (Some(PartialMerge), None) =>
+        val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+        val bufferSchema = StructType.fromAttributes(bufferAttributes)
+        val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
           unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
         }
@@ -288,11 +286,12 @@ class TungstenAggregationIterator(
       // Final-only, Complete-only and Final-Complete: a output row is generated based on
       // resultExpressions.
       case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+        val joinedRow = new JoinedRow()
         val resultProjection =
           UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
         (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
-          resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
+          resultProjection(joinedRow(currentGroupingKey, currentBuffer))
         }

       // Grouping-only: a output row is generated from values of grouping expressions.
@@ -316,7 +315,7 @@ class TungstenAggregationIterator(
   // A function used to process a input row. Its first argument is the aggregation buffer
   // and the second argument is the input row.
-  private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+  private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
     generateProcessRow(originalInputAttributes)

   // A function used to generate output rows based on the grouping keys (first argument)
@@ -354,7 +353,7 @@ class TungstenAggregationIterator(
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
       val groupingKey = groupProjection.apply(newInput)
-      val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+      val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       if (buffer == null) {
         // buffer == null means that we could not allocate more memory.
         // Now, we need to spill the map and switch to sort-based aggregation.
@@ -374,7 +373,7 @@ class TungstenAggregationIterator(
         val newInput = inputIter.next()
         val groupingKey = groupProjection.apply(newInput)
         val buffer: UnsafeRow = if (i < fallbackStartsAt) {
-          hashMap.getAggregationBuffer(groupingKey)
+          hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
         } else {
           null
         }
@@ -397,7 +396,7 @@ class TungstenAggregationIterator(
   private[this] var mapIteratorHasNext: Boolean = false

   ///////////////////////////////////////////////////////////////////////////
-  // Part 3: Methods and fields used by sort-based aggregation.
+  // Part 4: Methods and fields used when we switch to sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////

   // This sorter is used for sort-based aggregation. It is initialized as soon as
@@ -407,7 +406,7 @@ class TungstenAggregationIterator(
   /**
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
-  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
+  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()
...
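The fallback that switchToSortBasedAggregation implements can be summarized with a generic sketch (the names and the toy key/value data model are illustrative, not Spark's): aggregate into a hash map while buffers can be allocated, and when allocation fails, fall back to sorting so that equal keys become adjacent and can be merged with bounded memory:

    // Toy model: maxKeys plays the role of "no more memory for new buffers".
    def hashThenSortAggregate(
        input: Iterator[(String, Long)],
        maxKeys: Int): Seq[(String, Long)] = {
      val map = scala.collection.mutable.Map.empty[String, Long]
      val overflow = scala.collection.mutable.ArrayBuffer.empty[(String, Long)]
      for ((k, v) <- input) {
        if (map.contains(k) || map.size < maxKeys) map(k) = map.getOrElse(k, 0L) + v
        else overflow += (k -> v)  // "buffer == null": defer this row to the sort phase
      }
      // Sort phase: spill the map, sort everything by key, merge adjacent equal keys.
      (map.toSeq ++ overflow)
        .sortBy(_._1)
        .foldLeft(List.empty[(String, Long)]) {
          case ((pk, pv) :: rest, (k, v)) if pk == k => (k, pv + v) :: rest
          case (acc, kv) => kv :: acc
        }
        .reverse
    }

    assert(hashThenSortAggregate(Iterator("a" -> 1L, "b" -> 2L, "a" -> 3L), maxKeys = 1)
      == Seq("a" -> 4L, "b" -> 2L))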