diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 131efea20f31e0f2f58e4a998bdbcd3819c79399..4ca2d85406bb77e40b9cd4a6eeb1214e0b82789d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -38,6 +38,7 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
+    case _: BroadcastHashJoin => "bhj"
     case _ => nodeName.toLowerCase
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 943ad31c0cef52bd950081b0a21fbb30cebf938c..cbd549763ac95876b14bfc388de0530b2c0307a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -90,8 +90,14 @@ case class BroadcastHashJoin(
       // The following line doesn't run in a job so we cannot track the metric value. However, we
       // have already tracked it in the above lines. So here we can use
       // `SQLMetrics.nullLongMetric` to ignore it.
-      val hashed = HashedRelation(
-        input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+      // TODO: move this check into HashedRelation
+      val hashed = if (canJoinKeyFitWithinLong) {
+        LongHashedRelation(
+          input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+      } else {
+        HashedRelation(
+          input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+      }
       sparkContext.broadcast(hashed)
     }
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -112,15 +118,12 @@ case class BroadcastHashJoin(
 
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashedRelation = broadcastRelation.value
-      hashedRelation match {
-        case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-        case _ =>
-      }
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
       hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
     }
   }
 
+  private var broadcastRelation: Broadcast[HashedRelation] = _
   // the term for hash relation
   private var relationTerm: String = _
 
@@ -129,16 +132,15 @@ case class BroadcastHashJoin(
   }
 
   override def doProduce(ctx: CodegenContext): String = {
-    // create a name for HashRelation
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    // create a name for HashedRelation
+    broadcastRelation = Await.result(broadcastFuture, timeout)
    val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
     relationTerm = ctx.freshName("relation")
-    // TODO: create specialized HashRelation for single join key
-    val clsName = classOf[UnsafeHashedRelation].getName
+    val clsName = broadcastRelation.value.getClass.getName
     ctx.addMutableState(clsName, relationTerm,
       s"""
         | $relationTerm = ($clsName) $broadcast.value();
-        | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+        | incPeakExecutionMemory($relationTerm.getMemorySize());
       """.stripMargin)
 
     s"""
@@ -147,23 +149,24 @@ case class BroadcastHashJoin(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    // generate the key as UnsafeRow
+    // generate the key as UnsafeRow or Long
     ctx.currentVars = input
-    val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
-    val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
-    val keyTerm = keyVal.value
-    val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+    val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+      val expr = rewriteKeyExpr(streamedKeys).head
+      val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+      (ev, ev.isNull)
+    } else {
+      val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+      val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+      (ev, s"${ev.value}.anyNull()")
+    }
 
     // find the matches from HashedRelation
-    val matches = ctx.freshName("matches")
-    val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
-    val i = ctx.freshName("i")
-    val size = ctx.freshName("size")
-    val row = ctx.freshName("row")
+    val matched = ctx.freshName("matched")
 
     // create variables for output
     ctx.currentVars = null
-    ctx.INPUT_ROW = row
+    ctx.INPUT_ROW = matched
     val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
       BoundReference(i, a.dataType, a.nullable).gen(ctx)
     }
@@ -172,7 +175,7 @@ case class BroadcastHashJoin(
       case BuildRight => input ++ buildColumns
     }
 
-    val ouputCode = if (condition.isDefined) {
+    val outputCode = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
@@ -186,20 +189,39 @@ case class BroadcastHashJoin(
       consume(ctx, resultVars)
     }
 
-    s"""
-       | // generate join key
-       | ${keyVal.code}
-       | // find matches from HashRelation
-       | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
-       | if ($matches != null) {
-       |   int $size = $matches.size();
-       |   for (int $i = 0; $i < $size; $i++) {
-       |     UnsafeRow $row = (UnsafeRow) $matches.apply($i);
-       |     ${buildColumns.map(_.code).mkString("\n")}
-       |     $ouputCode
-       |   }
-       | }
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+       | // generate join key
+       | ${keyVal.code}
+       | // find matches from HashedRelation
+       | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
+       | if ($matched != null) {
+       |   ${buildColumns.map(_.code).mkString("\n")}
+       |   $outputCode
+       | }
      """.stripMargin
+
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      s"""
+       | // generate join key
+       | ${keyVal.code}
+       | // find matches from HashRelation
+       | $bufferType $matches = ${anyNull} ? null :
+       |   ($bufferType) $relationTerm.get(${keyVal.value});
+       | if ($matches != null) {
+       |   int $size = $matches.size();
+       |   for (int $i = 0; $i < $size; $i++) {
+       |     UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+       |     ${buildColumns.map(_.code).mkString("\n")}
+       |     $outputCode
+       |   }
+       | }
+      """.stripMargin
+    }
   }
 }
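The two generated shapes above (a single-match fast path versus a buffered inner loop) have a direct collection-level analogue. A minimal sketch in plain Scala — Map stands in for the hashed relations, and every name here is invented for illustration, not taken from the patch:

  object JoinPathsDemo {
    // Build side, keyed by a long join key.
    val unique: Map[Long, String] = Map(1L -> "a", 2L -> "b")        // one row per key
    val general: Map[Long, Seq[String]] = Map(1L -> Seq("a", "a2"))  // multiple rows per key

    def main(args: Array[String]): Unit = {
      val streamed = Seq(1L, 2L, 3L)
      // Unique path: one lookup per streamed row, no match buffer and no inner loop.
      val fast = streamed.flatMap(k => unique.get(k).map(v => (k, v)))
      // General path: loop over the buffer of matched build rows for each key.
      val slow = streamed.flatMap(k => general.getOrElse(k, Nil).map(v => (k, v)))
      assert(fast == Seq((1L, "a"), (2L, "b")))
      assert(slow == Seq((1L, "a"), (1L, "a2")))
    }
  }

The point of the UniqueHashedRelation branch is exactly this difference: one getValue() call per streamed row, with no per-row CompactBuffer or loop bookkeeping.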
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index f48fc3b84864d9872eca4b26d4a312357da99525..ad3275696e637237ba5201d1a30096f0bdaabe89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -116,12 +116,7 @@ case class BroadcastHashOuterJoin(
       val joinedRow = new JoinedRow()
       val hashTable = broadcastRelation.value
       val keyGenerator = streamedKeyGenerator
-
-      hashTable match {
-        case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-        case _ =>
-      }
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
 
       val resultProj = resultProjection
       joinType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 8929dc3af19121497f4c916ff3236e222ad31ae4..d0e18dfcf3d9003e268925becbded24f178bd5d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -64,11 +64,7 @@ case class BroadcastLeftSemiJoinHash(
 
       left.execute().mapPartitionsInternal { streamIter =>
        val hashedRelation = broadcastedRelation.value
-        hashedRelation match {
-          case unsafe: UnsafeHashedRelation =>
-            TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-          case _ =>
-        }
+        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
         hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ef854001f4de063afd4038649fcbd165061451c..ecbb1ac64b7c08e2574571b7f83a3116671d5ee9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-
+import org.apache.spark.sql.types.{IntegralType, LongType}
 
 trait HashJoin {
   self: SparkPlan =>
@@ -47,11 +47,49 @@ trait HashJoin {
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
+  /**
+   * Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
+   *
+   * If not, returns the original expressions.
+   */
+  def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+    var keyExpr: Expression = null
+    var width = 0
+    keys.foreach { e =>
+      e.dataType match {
+        case dt: IntegralType if dt.defaultSize <= 8 - width =>
+          if (width == 0) {
+            if (e.dataType != LongType) {
+              keyExpr = Cast(e, LongType)
+            } else {
+              keyExpr = e
+            }
+            width = dt.defaultSize
+          } else {
+            val bits = dt.defaultSize * 8
+            keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+            // track the bytes consumed so far, so the 8-byte budget check above stays correct
+            width += dt.defaultSize
+          }
+        // TODO: support BooleanType, DateType and TimestampType
+        case other =>
+          return keys
+      }
+    }
+    keyExpr :: Nil
+  }
+
+  protected val canJoinKeyFitWithinLong: Boolean = {
+    val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
+    val key = rewriteKeyExpr(buildKeys)
+    sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
+  }
+
   protected def buildSideKeyGenerator: Projection =
-    UnsafeProjection.create(buildKeys, buildPlan.output)
+    UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
 
   protected def streamSideKeyGenerator: Projection =
-    UnsafeProjection.create(streamedKeys, streamedPlan.output)
+    UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
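The rewritten key above packs several integral join keys into a single long by shifting the accumulated value left and OR-ing in the next key's low bits (mirroring the ShiftLeft/BitwiseAnd/BitwiseOr expressions). A standalone sketch in plain Scala — the object and method names are invented for illustration, not part of the patch:

  object PackKeysDemo {
    // Pack two ints into one long: k1 in the high 32 bits, k2 in the low 32 bits.
    // The mask on k2 matches Literal((1L << bits) - 1) and prevents sign extension.
    def pack(k1: Int, k2: Int): Long =
      (k1.toLong << 32) | (k2.toLong & 0xFFFFFFFFL)

    def main(args: Array[String]): Unit = {
      val packed = pack(60000, -5)
      assert((packed >>> 32).toInt == 60000)  // recover k1 from the high half
      assert(packed.toInt == -5)              // recover k2 from the low half
    }
  }

This is also why canJoinKeyFitWithinLong requires the build-side and stream-side key types to match exactly: both sides must pack equal keys to bit-identical longs.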
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ee7a1bdc343c02dd6e4f312ba9605fff29ac8d27..c94d6c195b1d8b050bbbeab70fd555eef6bb112c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -39,8 +39,23 @@ import org.apache.spark.util.collection.CompactBuffer
  * object.
  */
 private[execution] sealed trait HashedRelation {
+  /**
+   * Returns matched rows.
+   */
   def get(key: InternalRow): Seq[InternalRow]
 
+  /**
+   * Returns matched rows for a key that has only one column with LongType.
+   */
+  def get(key: Long): Seq[InternalRow] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns the size of used memory.
+   */
+  def getMemorySize: Long = 1L  // to make the test happy
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -58,11 +73,48 @@ private[execution] sealed trait HashedRelation {
   }
 }
 
+/**
+ * Interface for a hashed relation that has only one row per key.
+ *
+ * Callers should use getValue() for better performance.
+ */
+private[execution] trait UniqueHashedRelation extends HashedRelation {
+
+  /**
+   * Returns the matched single row.
+   */
+  def getValue(key: InternalRow): InternalRow
+
+  /**
+   * Returns the matched single row for a key that has only one column of LongType.
+   */
+  def getValue(key: Long): InternalRow = {
+    throw new UnsupportedOperationException
+  }
+
+  override def get(key: InternalRow): Seq[InternalRow] = {
+    val row = getValue(key)
+    if (row != null) {
+      CompactBuffer[InternalRow](row)
+    } else {
+      null
+    }
+  }
+
+  override def get(key: Long): Seq[InternalRow] = {
+    val row = getValue(key)
+    if (row != null) {
+      CompactBuffer[InternalRow](row)
+    } else {
+      null
+    }
+  }
+}
+
 /**
  * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
  */
-private[joins] final class GeneralHashedRelation(
+private[joins] class GeneralHashedRelation(
   private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
   extends HashedRelation with Externalizable {
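UniqueHashedRelation derives the generic get() from getValue(), so a CompactBuffer wrapper is only allocated when a caller insists on the Seq interface. A toy model of that contract in plain Scala — String stands in for InternalRow, and all names are illustrative:

  // Toy stand-in for UniqueHashedRelation.
  trait ToyUniqueRelation {
    def getValue(key: Long): String          // null when the key has no match
    def get(key: Long): Seq[String] = {      // derived, mirroring the trait above
      val row = getValue(key)
      if (row != null) Seq(row) else null    // null (not Nil), as in the patch
    }
  }

  object ToyUniqueDemo extends ToyUniqueRelation {
    private val table = Map(7L -> "row7")
    override def getValue(key: Long): String = table.getOrElse(key, null)

    def main(args: Array[String]): Unit = {
      assert(getValue(7L) == "row7")  // fast path: no Seq wrapper allocated
      assert(get(7L) == Seq("row7"))  // generic path wraps the single match
      assert(get(8L) == null)
    }
  }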
@@ -85,19 +137,14 @@ private[joins] class GeneralHashedRelation(
  * A specialized [[HashedRelation]] that maps key into a single value. This implementation
  * assumes the key is unique.
  */
-private[joins]
-final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
-  extends HashedRelation with Externalizable {
+private[joins] class UniqueKeyHashedRelation(
+  private var hashTable: JavaHashMap[InternalRow, InternalRow])
+  extends UniqueHashedRelation with Externalizable {
 
   // Needed for serialization (it is public to make Java serialization work)
   def this() = this(null)
 
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    val v = hashTable.get(key)
-    if (v eq null) null else CompactBuffer(v)
-  }
-
-  def getValue(key: InternalRow): InternalRow = hashTable.get(key)
+  override def getValue(key: InternalRow): InternalRow = hashTable.get(key)
 
   override def writeExternal(out: ObjectOutput): Unit = {
     writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -108,8 +155,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
   }
 }
 
-// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-
 private[execution] object HashedRelation {
 
@@ -208,7 +253,7 @@
    *
    * For non-broadcast joins or in local mode, return 0.
    */
-  def getUnsafeSize: Long = {
+  override def getMemorySize: Long = {
     if (binaryMap != null) {
       binaryMap.getTotalMemoryConsumption
     } else {
@@ -408,6 +453,232 @@
       }
     }
 
+    // TODO: create UniqueUnsafeRelation
     new UnsafeHashedRelation(hashTable)
   }
 }
+
+/**
+ * An interface for a hashed relation whose key is a Long.
+ */
+private[joins] trait LongHashedRelation extends HashedRelation {
+  override def get(key: InternalRow): Seq[InternalRow] = {
+    get(key.getLong(0))
+  }
+}
+
+private[joins] final class GeneralLongHashedRelation(
+  private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
+  extends LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(null)
+
+  override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+private[joins] final class UniqueLongHashedRelation(
+  private var hashTable: JavaHashMap[Long, UnsafeRow])
+  extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(null)
+
+  override def getValue(key: InternalRow): InternalRow = {
+    getValue(key.getLong(0))
+  }
+
+  override def getValue(key: Long): InternalRow = {
+    hashTable.get(key)
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+/**
+ * A relation that packs all the rows into a byte array, together with offsets and sizes.
+ *
+ * All the bytes of UnsafeRow are packed together as `bytes`:
+ *
+ *   [ Row0 ][ Row1 ][] ... [ RowN ]
+ *
+ * With keys:
+ *
+ *   start   start+1   ...   start+N
+ *
+ * `offsets` are offsets of UnsafeRows in the `bytes`
+ * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
+ *
+ * For example, two UnsafeRows (24 bytes and 32 bytes), with keys 3 and 5, will be stored as:
+ *
+ *   start   = 3
+ *   offsets = [0, 0, 24]
+ *   sizes   = [24, 0, 32]
+ *   bytes   = [0 - 24][][24 - 56]
+ */
+private[joins] final class LongArrayRelation(
+    private var numFields: Int,
+    private var start: Long,
+    private var offsets: Array[Int],
+    private var sizes: Array[Int],
+    private var bytes: Array[Byte]
+  ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(0, 0L, null, null, null)
+
+  override def getValue(key: InternalRow): InternalRow = {
+    getValue(key.getLong(0))
+  }
+
+  override def getMemorySize: Long = {
+    offsets.length * 4 + sizes.length * 4 + bytes.length
+  }
+
+  override def getValue(key: Long): InternalRow = {
+    val idx = (key - start).toInt
+    if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
+      val result = new UnsafeRow(numFields)
+      result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+      result
+    } else {
+      null
+    }
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    out.writeInt(numFields)
+    out.writeLong(start)
+    out.writeInt(sizes.length)
+    var i = 0
+    while (i < sizes.length) {
+      out.writeInt(sizes(i))
+      i += 1
+    }
+    out.writeInt(bytes.length)
+    out.write(bytes)
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    numFields = in.readInt()
+    start = in.readLong()
+    val length = in.readInt()
+    // read sizes of rows
+    sizes = new Array[Int](length)
+    offsets = new Array[Int](length)
+    var i = 0
+    var offset = 0
+    while (i < length) {
+      offsets(i) = offset
+      sizes(i) = in.readInt()
+      offset += sizes(i)
+      i += 1
+    }
+    // read all the bytes
+    val total = in.readInt()
+    assert(total == offset)
+    bytes = new Array[Byte](total)
+    in.readFully(bytes)
+  }
+}
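The dense layout in the comment above can be checked by hand. A small standalone sketch of the index arithmetic in getValue, in plain Scala with illustrative names, using the 24/32-byte example from the doc comment:

  object LongArrayLayoutDemo {
    val start   = 3L
    val offsets = Array(0, 0, 24)
    val sizes   = Array(24, 0, 32)   // sizes(1) == 0: no row for key 4

    // Mirrors LongArrayRelation.getValue: a lookup is one subtraction and two
    // array reads, no hashing. Returns the (offset, size) slice of `bytes`.
    def lookup(key: Long): Option[(Int, Int)] = {
      val idx = (key - start).toInt
      if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
        Some((offsets(idx), sizes(idx)))
      } else {
        None
      }
    }

    def main(args: Array[String]): Unit = {
      assert(lookup(3L) == Some((0, 24)))    // Row0 at bytes [0, 24)
      assert(lookup(4L) == None)             // hole in the key range
      assert(lookup(5L) == Some((24, 32)))   // Row1 at bytes [24, 56)
      assert(lookup(6L) == None)             // out of range
    }
  }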
+
+/**
+ * Creates a hashed relation whose key is a Long.
+ */
+private[joins] object LongHashedRelation {
+
+  val DENSE_FACTOR = 0.2
+
+  def apply(
+      input: Iterator[InternalRow],
+      numInputRows: LongSQLMetric,
+      keyGenerator: Projection,
+      sizeEstimate: Int): HashedRelation = {
+
+    // Use a Java hash table here because unsafe maps expect fixed size records
+    val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+
+    // Create a mapping of key -> rows
+    var numFields = 0
+    var keyIsUnique = true
+    var minKey = Long.MaxValue
+    var maxKey = Long.MinValue
+    while (input.hasNext) {
+      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+      numFields = unsafeRow.numFields()
+      numInputRows += 1
+      val rowKey = keyGenerator(unsafeRow)
+      if (!rowKey.anyNull) {
+        val key = rowKey.getLong(0)
+        minKey = math.min(minKey, key)
+        maxKey = math.max(maxKey, key)
+        val existingMatchList = hashTable.get(key)
+        val matchList = if (existingMatchList == null) {
+          val newMatchList = new CompactBuffer[UnsafeRow]()
+          hashTable.put(key, newMatchList)
+          newMatchList
+        } else {
+          keyIsUnique = false
+          existingMatchList
+        }
+        matchList += unsafeRow.copy()
+      }
+    }
+
+    if (keyIsUnique) {
+      if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+        // The keys are dense enough, so use LongArrayRelation
+        val length = (maxKey - minKey).toInt + 1
+        val sizes = new Array[Int](length)
+        val offsets = new Array[Int](length)
+        var offset = 0
+        var i = 0
+        while (i < length) {
+          val rows = hashTable.get(i + minKey)
+          if (rows != null) {
+            offsets(i) = offset
+            sizes(i) = rows(0).getSizeInBytes
+            offset += sizes(i)
+          }
+          i += 1
+        }
+        val bytes = new Array[Byte](offset)
+        i = 0
+        while (i < length) {
+          val rows = hashTable.get(i + minKey)
+          if (rows != null) {
+            rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
+          }
+          i += 1
+        }
+        new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
+      } else {
+        // The keys are unique but sparse, so use a hash map of one row per key
+        val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
+        val iter = hashTable.entrySet().iterator()
+        while (iter.hasNext) {
+          val entry = iter.next()
+          uniqHashTable.put(entry.getKey, entry.getValue()(0))
+        }
+        new UniqueLongHashedRelation(uniqHashTable)
+      }
+    } else {
+      new GeneralLongHashedRelation(hashTable)
+    }
+  }
+}
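The DENSE_FACTOR check above is what separates the two unique-key representations: if the number of distinct keys exceeds 20% of the key range, the array-backed LongArrayRelation is chosen; otherwise the map-backed UniqueLongHashedRelation is used. A quick standalone illustration of the arithmetic in plain Scala — the numbers are made up:

  object DenseFactorDemo {
    val DENSE_FACTOR = 0.2

    // Same comparison as in LongHashedRelation.apply.
    def isDense(numKeys: Long, minKey: Long, maxKey: Long): Boolean =
      numKeys > (maxKey - minKey) * DENSE_FACTOR

    def main(args: Array[String]): Unit = {
      // 1000 unique keys over [0, 2000): 1000 > 400, dense -> LongArrayRelation.
      assert(isDense(1000, 0, 2000))
      // 1000 unique keys over [0, 1000000): 1000 <= 200000, sparse -> hash map.
      assert(!isDense(1000, 0, 1000000))
    }
  }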
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 33d4976403d9ae22b5880d5e6be5b00146716d9d..f015d297048a33e4a8bd38ec4eb972d4db7bfd5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -22,6 +22,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -122,10 +123,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("broadcast hash join") {
-    val N = 20 << 20
+    val N = 100 << 20
     val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
 
-    runBenchmark("BroadcastHashJoin", N) {
+    runBenchmark("Join w long", N) {
       sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
     }
 
@@ -133,9 +134,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    BroadcastHashJoin codegen=false          4405 / 6147          4.0         250.0       1.0X
-    BroadcastHashJoin codegen=true           1857 / 1878         11.0          90.9       2.4X
+    Join w long codegen=false               10174 / 10317         10.0         100.0       1.0X
+    Join w long codegen=true                 1069 /  1107         98.0          10.2       9.5X
     */
+
+    val dim2 = broadcast(sqlContext.range(1 << 16)
+      .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
+
+    runBenchmark("Join w 2 ints", N) {
+      sqlContext.range(N).join(dim2,
+        (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
+          && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 ints codegen=false             11435 / 11530          9.0         111.1       1.0X
+    Join w 2 ints codegen=true               1265 /  1424         82.0          12.2       9.0X
+    */
+
   }
 
   ignore("hash and BytesToBytesMap") {
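The Rate(M/s) and Per Row(ns) columns follow directly from N and the best time, which makes an easy sanity check on the tables above. A plain Scala sketch reproducing the codegen=true row of the long-key join:

  object BenchmarkMathDemo {
    def main(args: Array[String]): Unit = {
      val n = 100L << 20                       // 104,857,600 rows, the N in the benchmark
      val bestMs = 1069.0                      // best time of the codegen=true run
      val perRowNs = bestMs * 1e6 / n          // ~10.2 ns per row
      val rateMps = n / (bestMs / 1000) / 1e6  // ~98 million rows per second
      println(f"$perRowNs%.1f ns/row, $rateMps%.1f M rows/s")
    }
  }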
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e5fd9e277fc61e288f35521637bf1a2b21a726ff..f985dfbd8ade9fdacade3f0467f3d09d35063714 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
-
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
 
   // Key is simply the record itself
@@ -134,4 +133,32 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     out2.flush()
     assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
+
+  test("LongArrayRelation") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+    val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+    val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
+    val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100)
+    assert(longRelation.isInstanceOf[LongArrayRelation])
+    val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+    (0 until 100).foreach { i =>
+      val row = longArrayRelation.getValue(i)
+      assert(row.getInt(0) === i)
+      assert(row.getInt(1) === i + 1)
+    }
+
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    longArrayRelation.writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val relation = new LongArrayRelation()
+    relation.readExternal(in)
+    // verify the deserialized copy, not the original relation
+    (0 until 100).foreach { i =>
+      val row = relation.getValue(i)
+      assert(row.getInt(0) === i)
+      assert(row.getInt(1) === i + 1)
+    }
+  }
 }