diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 645eb48d5a51bf11866d0b03024d8dcf115c2a17..5f8a6f88717224cf53353894b2c1ab4ad45a9a96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner {
  */
 object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
 
-  def dump(word: Long): String = {
-    Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
-  }
-
   override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
     create(in._1, in._2)
   }
@@ -56,76 +52,47 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
   }
 
   def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
-    val ctx = newCodeGenContext()
     val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+    val getLong = "PlatformDependent.UNSAFE.getLong"
+    val putLong = "PlatformDependent.UNSAFE.putLong"
 
     val bitset1Words = (schema1.size + 63) / 64
     val bitset2Words = (schema2.size + 63) / 64
     val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
     val bitset1Remainder = schema1.size % 64
-    val bitset2Remainder = schema2.size % 64
 
     // The number of words we can reduce when we concat two rows together.
     // The only reduction comes from merging the bitset portion of the two rows, saving 1 word.
     val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
 
-    // --------------------- copy bitset from row 1 ----------------------- //
-    val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
-      s"""
-         |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
-         |  PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
-       """.stripMargin
-    }.mkString
-
-
-    // --------------------- copy bitset from row 2 ----------------------- //
-    var copyBitset2 = ""
-    if (bitset1Remainder == 0) {
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8},
-           |  PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
-         """.stripMargin
-      }.mkString
-    } else {
-      copyBitset2 = Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8});
-           |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1);
-           |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
-         """.stripMargin
-      }.mkString
-
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        val currentOffset = offset + (bitset1Words + i - 1) * 8
-        if (i == 0) {
-          if (bitset1Words > 0) {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
-               |  bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset));
-            """.stripMargin
-          } else {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1);
-            """.stripMargin
-          }
+    // --------------------- copy bitset from row 1 and row 2 --------------------------- //
+    val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
+      // Merge across the word boundary only if row 2 has bitset words. With an empty schema2
+      // there is no word at offset2 to read, so the bitset of row 1 is copied as-is.
+      val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
+        if (i < bitset1Words - 1) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
+        } else if (i == bitset1Words - 1) {
+          // combine last word of bitset1 and first word of bitset2
+          s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)"
+        } else if (i - bitset1Words < bitset2Words - 1) {
+          // combine next two words of bitset2
+          s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" +
+            s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)"
+        } else {
+          // last word of bitset2
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)"
+        }
+      } else {
+        // they are aligned by word
+        if (i < bitset1Words) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
         } else {
-          s"""
-             |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2);
-          """.stripMargin
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
         }
-      }.mkString("\n")
-
-      if (bitset2Words > 0 &&
-        (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) {
-        val lastWord = bitset2Words - 1
-        copyBitset2 +=
-          s"""
-            |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8},
-            |  bs2w${lastWord}p2);
-          """.stripMargin
       }
-    }
+      s"$putLong(buf, ${offset + i * 8}, $bits);"
+    }.mkString("\n")
 
     // --------------------- copy fixed length portion from row 1 ----------------------- //
     var cursor = offset + outputBitsetWords * 8
@@ -149,10 +116,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
     cursor += schema2.size * 8
 
     // --------------------- copy variable length portion from row 1 ----------------------- //
+    val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8
     val copyVariableLengthRow1 = s"""
        |// Copy variable length data for row1
-       |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
-       |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
+       |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
        |PlatformDependent.copyMemory(
        |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
        |  buf, $cursor,
@@ -160,10 +127,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     // --------------------- copy variable length portion from row 2 ----------------------- //
+    val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8
     val copyVariableLengthRow2 = s"""
        |// Copy variable length data for row2
-       |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
-       |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
+       |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
        |PlatformDependent.copyMemory(
        |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
        |  buf, $cursor + numBytesVariableRow1,
@@ -183,12 +150,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
           if (i < schema1.size) {
             s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
           } else {
-            s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1"
+            s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
           }
         val cursor = offset + outputBitsetWords * 8 + i * 8
         s"""
-           |PlatformDependent.UNSAFE.putLong(buf, $cursor,
-           |  PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
+           |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
          """.stripMargin
       }
     }.mkString
@@ -217,8 +183,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |    final Object obj2 = row2.getBaseObject();
        |    final long offset2 = row2.getBaseOffset();
        |
-       |    $copyBitset1
-       |    $copyBitset2
+       |    $copyBitset
        |    $copyFixedLengthRow1
        |    $copyFixedLengthRow2
        |    $copyVariableLengthRow1
@@ -233,7 +198,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
-    // println(CodeFormatter.format(code))
 
     val c = compile(code)
     c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 76d9d991ed0dc902b371e788180479b75b750d42..718a2acc8281d10e87f45d5ca6be3f5aadaf249a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -22,6 +22,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
 
 /**
  * A test suite for the bitset portion of the row concatenation.
@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
   private def createUnsafeRow(numFields: Int): UnsafeRow = {
     val row = new UnsafeRow
     val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
-    val buf = new Array[Byte](sizeInBytes)
-    row.pointTo(buf, numFields, sizeInBytes)
+    val offset = numFields * 8
+    val buf = new Array[Byte](sizeInBytes + offset)
+    row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
     row
   }
 
@@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
          |input1: ${set1.mkString}
          |input2: ${set2.mkString}
          |output: ${out.mkString}
+         |expect: ${set1.mkString}${set2.mkString}
        """.stripMargin
     }
 