Skip to content
Snippets Groups Projects
Commit 191bf268 authored by Davies Liu's avatar Davies Liu Committed by Reynold Xin
Browse files

[SPARK-9518] [SQL] cleanup generated UnsafeRowJoiner and fix bug

Currently, when copy the bitsets, we didn't consider that the row1 may not sit in the beginning of byte array.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #7892 from davies/clean_join and squashes the following commits:

14cce9e [Davies Liu] cleanup generated UnsafeRowJoiner and fix bug
parent 137f4786
No related branches found
No related tags found
No related merge requests found
......@@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner {
*/
object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
def dump(word: Long): String = {
Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
}
override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
create(in._1, in._2)
}
......@@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
}
def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
val ctx = newCodeGenContext()
val offset = PlatformDependent.BYTE_ARRAY_OFFSET
val getLong = "PlatformDependent.UNSAFE.getLong"
val putLong = "PlatformDependent.UNSAFE.putLong"
val bitset1Words = (schema1.size + 63) / 64
val bitset2Words = (schema2.size + 63) / 64
val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
val bitset1Remainder = schema1.size % 64
val bitset2Remainder = schema2.size % 64
// The number of words we can reduce when we concat two rows together.
// The only reduction comes from merging the bitset portion of the two rows, saving 1 word.
val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
// --------------------- copy bitset from row 1 ----------------------- //
val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
s"""
|PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
| PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
""".stripMargin
}.mkString
// --------------------- copy bitset from row 2 ----------------------- //
var copyBitset2 = ""
if (bitset1Remainder == 0) {
copyBitset2 += Seq.tabulate(bitset2Words) { i =>
s"""
|PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8},
| PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
""".stripMargin
}.mkString
} else {
copyBitset2 = Seq.tabulate(bitset2Words) { i =>
s"""
|long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8});
|long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1);
|long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
""".stripMargin
}.mkString
copyBitset2 += Seq.tabulate(bitset2Words) { i =>
val currentOffset = offset + (bitset1Words + i - 1) * 8
if (i == 0) {
if (bitset1Words > 0) {
s"""
|PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
| bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset));
""".stripMargin
} else {
s"""
|PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1);
""".stripMargin
}
// --------------------- copy bitset from row 1 and row 2 --------------------------- //
val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
val bits = if (bitset1Remainder > 0) {
if (i < bitset1Words - 1) {
s"$getLong(obj1, offset1 + ${i * 8})"
} else if (i == bitset1Words - 1) {
// combine last work of bitset1 and first word of bitset2
s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)"
} else if (i - bitset1Words < bitset2Words - 1) {
// combine next two words of bitset2
s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" +
s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)"
} else {
// last word of bitset2
s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)"
}
} else {
// they are aligned by word
if (i < bitset1Words) {
s"$getLong(obj1, offset1 + ${i * 8})"
} else {
s"""
|PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2);
""".stripMargin
s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
}
}.mkString("\n")
if (bitset2Words > 0 &&
(bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) {
val lastWord = bitset2Words - 1
copyBitset2 +=
s"""
|PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8},
| bs2w${lastWord}p2);
""".stripMargin
}
}
s"$putLong(buf, ${offset + i * 8}, $bits);"
}.mkString("\n")
// --------------------- copy fixed length portion from row 1 ----------------------- //
var cursor = offset + outputBitsetWords * 8
......@@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
cursor += schema2.size * 8
// --------------------- copy variable length portion from row 1 ----------------------- //
val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8
val copyVariableLengthRow1 = s"""
|// Copy variable length data for row1
|long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
|long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
|long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
|PlatformDependent.copyMemory(
| obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
| buf, $cursor,
......@@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
""".stripMargin
// --------------------- copy variable length portion from row 2 ----------------------- //
val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8
val copyVariableLengthRow2 = s"""
|// Copy variable length data for row2
|long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
|long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
|long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
|PlatformDependent.copyMemory(
| obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
| buf, $cursor + numBytesVariableRow1,
......@@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
if (i < schema1.size) {
s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
} else {
s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1"
s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
s"""
|PlatformDependent.UNSAFE.putLong(buf, $cursor,
| PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
|$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
""".stripMargin
}
}.mkString
......@@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| final Object obj2 = row2.getBaseObject();
| final long offset2 = row2.getBaseOffset();
|
| $copyBitset1
| $copyBitset2
| $copyBitset
| $copyFixedLengthRow1
| $copyFixedLengthRow2
| $copyVariableLengthRow1
......@@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
""".stripMargin
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
// println(CodeFormatter.format(code))
val c = compile(code)
c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
......
......@@ -22,6 +22,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
/**
* A test suite for the bitset portion of the row concatenation.
......@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
private def createUnsafeRow(numFields: Int): UnsafeRow = {
val row = new UnsafeRow
val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
val buf = new Array[Byte](sizeInBytes)
row.pointTo(buf, numFields, sizeInBytes)
val offset = numFields * 8
val buf = new Array[Byte](sizeInBytes + offset)
row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
row
}
......@@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
|input1: ${set1.mkString}
|input2: ${set2.mkString}
|output: ${out.mkString}
|expect: ${set1.mkString}${set2.mkString}
""".stripMargin
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment