Commit 5cb5edaf authored by Davies Liu, committed by Davies Liu

[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long

## What changes were proposed in this pull request?

Currently, we use a java HashMap for HashedRelation when the key fits within a Long. The java HashMap and CompactBuffer are not memory efficient, and the memory they use is not accounted for accurately.

This PR introduces a LongToUnsafeRowMap (similar to BytesToBytesMap) for better memory efficiency and performance.

This PR reopens #12190 to fix bugs.
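
For context, here is a minimal sketch of how the new map is driven, mirroring the updated `BenchmarkWholeStageCodegen` code further down in this diff. The `LongToUnsafeRowMap` constructor and the `append` / `optimize` / `getValue` calls are taken from that benchmark; the driver object, loop bounds and key pattern are illustrative only.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.unsafe.Platform

// Illustrative driver only; the map / TaskMemoryManager usage mirrors the benchmark below.
object LongToUnsafeRowMapSketch {
  def main(args: Array[String]): Unit = {
    val taskMemoryManager = new TaskMemoryManager(
      new StaticMemoryManager(
        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
        Long.MaxValue,
        Long.MaxValue,
        1),
      0)

    // A 16-byte UnsafeRow with a single int column, as in the benchmark.
    val valueBytes = new Array[Byte](16)
    val value = new UnsafeRow(1)
    value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)

    val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
    var i = 0
    while (i < 1000) {
      value.setInt(0, i)
      map.append(i, value)  // unique keys here; duplicate keys are exercised in HashedRelationSuite
      i += 1
    }
    map.optimize()  // the benchmark's "opt=true" case; roughly 2.6x faster lookups there

    // getValue fills the passed row and returns null when the key is absent,
    // which is how the benchmark counts hits.
    if (map.getValue(42L, value) != null) {
      println(s"key 42 -> ${value.getInt(0)}")
    }
  }
}
```

Because the map allocates its pages through the TaskMemoryManager, the memory it uses is tracked by Spark's memory management; that is what lets ShuffledHashJoin (later in this diff) drop its hand-rolled memory estimation and simply pass the task memory manager to `HashedRelation(...)`.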

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #12278 from davies/long_map3.
parent dfce9665
Showing changed files with 602 additions and 361 deletions
@@ -716,7 +716,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
       offset += klen;
       Platform.copyMemory(vbase, voff, base, offset, vlen);
       offset += vlen;
-      Platform.putLong(base, offset, 0);
+      // put this value at the beginning of the list
+      Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);

       // --- Update bookkeeping data structures ----------------------------------------------------
       offset = currentPage.getBaseOffset();
@@ -724,17 +725,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
       pageCursor += recordLength;
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
         currentPage, recordOffset);
-      longArray.set(pos * 2, storedKeyAddress);
-      updateAddressesAndSizes(storedKeyAddress);
       numValues++;
-      if (isDefined) {
-        // put this pair at the end of chain
-        while (nextValue()) { /* do nothing */ }
-        Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress);
-        nextValue();  // point to new added value
-      } else {
+      if (!isDefined) {
         numKeys++;
+        longArray.set(pos * 2, storedKeyAddress);
         longArray.set(pos * 2 + 1, keyHashcode);
+        updateAddressesAndSizes(storedKeyAddress);
         isDefined = true;
         if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
......
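The comment added above states the new layout: a freshly appended record stores the address of the key's current head at its end, so the new value sits at the beginning of the key's list instead of being linked in after a walk to the end of the chain (the removed `while (nextValue())` loop). A rough stand-alone sketch of that prepend idea, in plain Scala with hypothetical names and an array-backed arena standing in for the map's pages:

```scala
import scala.collection.mutable

// Conceptual sketch only, not Spark's BytesToBytesMap code.
object ChainedAppendSketch {
  // Address 0 acts as the "end of chain" sentinel, so real records start at index 1.
  private final case class Record(value: Int, prevAddress: Long)
  private val arena = mutable.ArrayBuffer[Record](Record(0, 0L))
  private val bucketHead = mutable.Map.empty[Long, Long]  // key -> address of newest record

  def append(key: Long, value: Int): Unit = {
    val prev = bucketHead.getOrElse(key, 0L)  // 0 when the key has no values yet
    arena += Record(value, prev)              // the new record remembers the previous head
    bucketHead(key) = arena.size - 1L         // the bucket now points at the new record
  }

  def get(key: Long): List[Int] =
    Iterator.iterate(bucketHead.getOrElse(key, 0L))(addr => arena(addr.toInt).prevAddress)
      .takeWhile(_ != 0L)
      .map(addr => arena(addr.toInt).value)
      .toList

  def main(args: Array[String]): Unit = {
    append(7L, 1); append(7L, 2); append(8L, 3)
    println(get(7L))  // List(2, 1): the most recently appended value comes first
    println(get(8L))  // List(3)
  }
}
```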
@@ -454,7 +454,7 @@ case class TungstenAggregate(
     val thisPlan = ctx.addReferenceObj("plan", this)
     hashMapTerm = ctx.freshName("hashMap")
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
-    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"")
     sorterTerm = ctx.freshName("sorter")
     ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
@@ -467,6 +467,7 @@ case class TungstenAggregate(
     s"""
      ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
      private void $doAgg() throws java.io.IOException {
+       $hashMapTerm = $thisPlan.createHashMap();
        ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
        $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
......
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.LongType

 /**
  * Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -50,10 +51,7 @@ case class BroadcastHashJoin(
   override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(
-      canJoinKeyFitWithinLong,
-      rewriteKeyExpr(buildKeys),
-      buildPlan.output)
+    val mode = HashedRelationBroadcastMode(buildKeys)
     buildSide match {
       case BuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -68,7 +66,7 @@ case class BroadcastHashJoin(
     val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashed = broadcastRelation.value.asReadOnlyCopy()
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
       join(streamedIter, hashed, numOutputRows)
     }
   }
@@ -105,7 +103,7 @@ case class BroadcastHashJoin(
     ctx.addMutableState(clsName, relationTerm,
       s"""
        | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
-       | incPeakExecutionMemory($relationTerm.getMemorySize());
+       | incPeakExecutionMemory($relationTerm.estimatedSize());
       """.stripMargin)
     (broadcastRelation, relationTerm)
   }
@@ -118,15 +116,13 @@ case class BroadcastHashJoin(
       ctx: CodegenContext,
       input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    if (canJoinKeyFitWithinLong) {
+    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val expr = rewriteKeyExpr(streamedKeys).head
-      val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+      val ev = streamedKeys.head.gen(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
-      val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
-      val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
       (ev, s"${ev.value}.anyNull()")
     }
   }
......
@@ -17,16 +17,12 @@
 package org.apache.spark.sql.execution.joins

-import java.util.NoSuchElementException
-
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.types.{IntegralType, LongType}

 trait HashJoin {
   self: SparkPlan =>
@@ -59,9 +55,15 @@ trait HashJoin {
     case BuildRight => (right, left)
   }

-  protected lazy val (buildKeys, streamedKeys) = buildSide match {
-    case BuildLeft => (leftKeys, rightKeys)
-    case BuildRight => (rightKeys, leftKeys)
+  protected lazy val (buildKeys, streamedKeys) = {
+    require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
+      "Join keys from two sides should have same types")
+    val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
+    val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
+    buildSide match {
+      case BuildLeft => (lkeys, rkeys)
+      case BuildRight => (rkeys, lkeys)
+    }
   }

   /**
@@ -69,7 +71,7 @@ trait HashJoin {
    *
    * If not, returns the original expressions.
    */
-  def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+  private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
     var keyExpr: Expression = null
     var width = 0
     keys.foreach { e =>
@@ -84,17 +86,8 @@ trait HashJoin {
           width = dt.defaultSize
         } else {
           val bits = dt.defaultSize * 8
-          // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same
-          // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys
-          // with two same ints have hash code 0, we rotate the bits of second one.
-          val rotated = if (e.dataType == IntegerType) {
-            // (e >>> 15) | (e << 17)
-            BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
-          } else {
-            e
-          }
           keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-            BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
+            BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
           width -= bits
         }
       // TODO: support BooleanType, DateType and TimestampType
@@ -105,17 +98,11 @@ trait HashJoin {
     keyExpr :: Nil
   }

-  protected lazy val canJoinKeyFitWithinLong: Boolean = {
-    val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
-    val key = rewriteKeyExpr(buildKeys)
-    sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
-  }
-
   protected def buildSideKeyGenerator(): Projection =
-    UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
+    UnsafeProjection.create(buildKeys)

   protected def streamSideKeyGenerator(): UnsafeProjection =
-    UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
+    UnsafeProjection.create(streamedKeys)

   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
......
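For readers unfamiliar with `rewriteKeyExpr`: the Catalyst expressions it builds (`ShiftLeft`, `Cast`, `BitwiseAnd`, `BitwiseOr`) amount to packing several narrow integral keys into one Long, shifting the accumulated key left by the next column's width in bits and OR-ing in that column's low bits. A hypothetical two-int illustration in plain Scala (not the PR's code):

```scala
// Packing sketch only; rewriteKeyExpr builds the equivalent Catalyst expression tree.
object PackJoinKeySketch {
  def packTwoInts(a: Int, b: Int): Long =
    (a.toLong << 32) | (b.toLong & ((1L << 32) - 1))  // mask keeps only b's low 32 bits

  def main(args: Array[String]): Unit = {
    val packed = packTwoInts(1, -1)
    println(java.lang.Long.toHexString(packed))  // 1ffffffff: a in the high word, b in the low word
    // Two rows match on (a, b) exactly when their packed longs are equal, so the
    // join can use the Long-keyed HashedRelation instead of UnsafeRow keys.
  }
}
```

The int-key rotation removed above was a workaround for `java.lang.Long.hashCode` in the old java-HashMap-based relation (as the deleted comment explains); with that HashMap gone, the packed value no longer needs it.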
@@ -17,11 +17,10 @@
 package org.apache.spark.sql.execution.joins

-import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.memory.MemoryMode
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -57,54 +56,20 @@ case class ShuffledHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

-  private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val context = TaskContext.get()
-    if (!canJoinKeyFitWithinLong) {
-      // build BytesToBytesMap
-      val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
-      // This relation is usually used until the end of task.
-      context.addTaskCompletionListener((t: TaskContext) =>
-        relation.close()
-      )
-      return relation
-    }
-
-    // try to acquire some memory for the hash table, it could trigger other operator to free some
-    // memory. The memory acquired here will mostly be used until the end of task.
-    val memoryManager = context.taskMemoryManager()
-    var acquired = 0L
-    var used = 0L
+    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+    // This relation is usually used until the end of task.
     context.addTaskCompletionListener((t: TaskContext) =>
-      memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
+      relation.close()
     )
+    relation
-
-    val copiedIter = iter.map { row =>
-      // It's hard to guess what's exactly memory will be used, we have a rough guess here.
-      // TODO: use LongToBytesMap instead of HashMap for memory efficiency
-      // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
-      val needed = 150 + row.getSizeInBytes
-      if (needed > acquired - used) {
-        val got = memoryManager.acquireExecutionMemory(
-          Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
-        acquired += got
-        if (got < needed) {
-          throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
-            "hash join, please use sort merge join by setting " +
-            "spark.sql.join.preferSortMergeJoin=true")
-        }
-      }
-      used += needed
-      // HashedRelation requires that the UnsafeRow should be separate objects.
-      row.copy()
-    }
-    HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
   }

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
-      val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
+      val hashed = buildHashedRelation(buildIter)
       join(streamIter, hashed, numOutputRows)
     }
   }
......
@@ -21,6 +21,7 @@ import java.util.HashMap
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.vectorized.AggregateHashMap
@@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w long:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long codegen=false                5351 / 5531          3.9         255.1       1.0X
-    Join w long codegen=true                  275 /  352         76.2          13.1      19.4X
+    Join w long codegen=false                3002 / 3262          7.0         143.2       1.0X
+    Join w long codegen=true                  321 /  371         65.3          15.3       9.3X
     */

     runBenchmark("Join w long duplicated", N) {
@@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w long duplicated:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long duplicated codegen=false     4752 / 4906          4.4         226.6       1.0X
-    Join w long duplicated codegen=true       722 /  760         29.0          34.4       6.6X
+    Join w long duplicated codegen=false     3446 / 3478          6.1         164.3       1.0X
+    Join w long duplicated codegen=true       322 /  351         65.2          15.3      10.7X
     */

     val dim2 = broadcast(sqlContext.range(M)
@@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w 2 ints:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 ints codegen=false              9011 / 9121          2.3         429.7       1.0X
-    Join w 2 ints codegen=true               2565 / 2816          8.2         122.3       3.5X
+    Join w 2 ints codegen=false              4426 / 4501          4.7         211.1       1.0X
+    Join w 2 ints codegen=true                791 /  818         26.5          37.7       5.6X
     */

     val dim3 = broadcast(sqlContext.range(M)
@@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     outer join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    outer join w long codegen=false          5667 / 5780          3.7         270.2       1.0X
-    outer join w long codegen=true            216 /  226         97.2          10.3      26.3X
+    outer join w long codegen=false          3055 / 3189          6.9         145.7       1.0X
+    outer join w long codegen=true            261 /  276         80.5          12.4      11.7X
     */

     runBenchmark("semi join w long", N) {
@@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     semi join w long:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    semi join w long codegen=false           4690 / 4953          4.5         223.7       1.0X
-    semi join w long codegen=true             211 /  229         99.2          10.1      22.2X
+    semi join w long codegen=false           1912 / 1990         11.0          91.2       1.0X
+    semi join w long codegen=true             237 /  244         88.3          11.3       8.1X
     */
   }
@@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     shuffle hash join:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    shuffle hash join codegen=false          1538 / 1742          2.7         366.7       1.0X
-    shuffle hash join codegen=true            892 / 1329          4.7         212.6       1.7X
+    shuffle hash join codegen=false          1101 / 1391          3.8         262.6       1.0X
+    shuffle hash join codegen=true            528 /  578          7.9         125.8       2.1X
     */
   }
@@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }

   ignore("hash and BytesToBytesMap") {
-    val N = 10 << 20
+    val N = 20 << 20

     val benchmark = new Benchmark("BytesToBytesMap", N)

-    benchmark.addCase("hash") { iter =>
+    benchmark.addCase("UnsafeRow hash") { iter =>
       var i = 0
       val keyBytes = new Array[Byte](16)
       val key = new UnsafeRow(1)
@@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       }
     }

+    benchmark.addCase("murmur3 hash") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var p = 524283
+      var s = 0
+      while (i < N) {
+        var h = Murmur3_x86_32.hashLong(i, 42)
+        key.setInt(0, h)
+        s += h
+        i += 1
+      }
+    }
+
     benchmark.addCase("fast hash") { iter =>
       var i = 0
       val keyBytes = new Array[Byte](16)
       val key = new UnsafeRow(1)
       key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var p = 524283
       var s = 0
       while (i < N) {
-        key.setInt(0, i % 1000)
-        val h = Murmur3_x86_32.hashLong(i % 1000, 42)
+        var h = i % p
+        if (h < 0) {
+          h += p
+        }
+        key.setInt(0, h)
         s += h
         i += 1
       }
@@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       }
     }

+    Seq(false, true).foreach { optimized =>
+      benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter =>
+        var i = 0
+        val valueBytes = new Array[Byte](16)
+        val value = new UnsafeRow(1)
+        value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+        value.setInt(0, 555)
+        val taskMemoryManager = new TaskMemoryManager(
+          new StaticMemoryManager(
+            new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+            Long.MaxValue,
+            Long.MaxValue,
+            1),
+          0)
+        val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
+        while (i < 65536) {
+          value.setInt(0, i)
+          val key = i % 100000
+          map.append(key, value)
+          i += 1
+        }
+        if (optimized) {
+          map.optimize()
+        }
+        var s = 0
+        i = 0
+        while (i < N) {
+          val key = i % 100000
+          if (map.getValue(key, value) != null) {
+            s += 1
+          }
+          i += 1
+        }
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
         val value = new UnsafeRow(1)
         value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
         var i = 0
-        while (i < N) {
+        val numKeys = 65536
+        while (i < numKeys) {
           key.setInt(0, i % 65536)
           val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
             Murmur3_x86_32.hashLong(i % 65536, 42))
-          if (loc.isDefined) {
-            value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
-            value.setInt(0, value.getInt(0) + 1)
-            i += 1
-          } else {
+          if (!loc.isDefined) {
             loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
               value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
           }
+          i += 1
+        }
+        i = 0
+        var s = 0
+        while (i < N) {
+          key.setInt(0, i % 100000)
+          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            Murmur3_x86_32.hashLong(i % 100000, 42))
+          if (loc.isDefined) {
+            s += 1
+          }
+          i += 1
         }
       }
     }
@@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
    BytesToBytesMap:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
-   hash                                      112 /  116         93.2          10.7       1.0X
-   fast hash                                  65 /   69        160.9           6.2       1.7X
-   arrayEqual                                 66 /   69        159.1           6.3       1.7X
-   Java HashMap (Long)                       137 /  182         76.3          13.1       0.8X
-   Java HashMap (two ints)                   182 /  230         57.8          17.3       0.6X
-   Java HashMap (UnsafeRow)                  511 /  565         20.5          48.8       0.2X
-   BytesToBytesMap (off Heap)                481 /  515         21.8          45.9       0.2X
-   BytesToBytesMap (on Heap)                 529 /  600         19.8          50.5       0.2X
-   Aggregate HashMap                          56 /   62        187.9           5.3       2.0X
+   UnsafeRow hash                            267 /  284         78.4          12.8       1.0X
+   murmur3 hash                              102 /  129        205.5           4.9       2.6X
+   fast hash                                  79 /   96        263.8           3.8       3.4X
+   arrayEqual                                164 /  172        128.2           7.8       1.6X
+   Java HashMap (Long)                       321 /  399         65.4          15.3       0.8X
+   Java HashMap (two ints)                   328 /  363         63.9          15.7       0.8X
+   Java HashMap (UnsafeRow)                 1140 / 1200         18.4          54.3       0.2X
+   LongToUnsafeRowMap (opt=false)            378 /  400         55.5          18.0       0.7X
+   LongToUnsafeRowMap (opt=true)             144 /  152        145.2           6.9       1.9X
+   BytesToBytesMap (off Heap)               1300 / 1616         16.1          62.0       0.2X
+   BytesToBytesMap (on Heap)                1165 / 1202         18.0          55.5       0.2X
+   Aggregate HashMap                         121 /  131        173.3           5.8       2.2X
    */
     benchmark.run()
   }
......
@@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
   test("compatible BroadcastMode") {
     val mode1 = IdentityBroadcastMode
-    val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
-    val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
+    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
+    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

     assert(mode1.compatibleWith(mode1))
     assert(!mode1.compatibleWith(mode2))
@@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
     assert(plan sameResult plan)

     val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
-    val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
+    val hashMode = HashedRelationBroadcastMode(output)
     val exchange2 = BroadcastExchange(hashMode, plan)
     val hashMode2 =
-      HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
+      HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
     val exchange3 = BroadcastExchange(hashMode2, plan)
     val exchange4 = ReusedExchange(output, exchange3)
......
@@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {

+  val mm = new TaskMemoryManager(
+    new StaticMemoryManager(
+      new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+      Long.MaxValue,
+      Long.MaxValue,
+      1),
+    0)
+
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
     val toUnsafe = UnsafeProjection.create(schema)
     val unsafeData = data.map(toUnsafe(_).copy())

     val buildKey = Seq(BoundReference(0, IntegerType, false))
-    val keyGenerator = UnsafeProjection.create(buildKey)
-    val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
+    val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm)
     assert(hashed.isInstanceOf[UnsafeHashedRelation])

     assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
@@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }

-  test("LongArrayRelation") {
+  test("LongToUnsafeRowMap") {
     val unsafeProj = UnsafeProjection.create(
       Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
     val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
-    val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
-    val longRelation = LongHashedRelation(rows.iterator, keyProj, 100)
-    assert(longRelation.isInstanceOf[LongArrayRelation])
-    val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
+    assert(longRelation.keyIsUnique)
     (0 until 100).foreach { i =>
-      val row = longArrayRelation.getValue(i)
+      val row = longRelation.getValue(i)
       assert(row.getInt(0) === i)
       assert(row.getInt(1) === i + 1)
     }

+    val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
+    assert(!longRelation2.keyIsUnique)
+    (0 until 100).foreach { i =>
+      val rows = longRelation2.get(i).toArray
+      assert(rows.length === 2)
+      assert(rows(0).getInt(0) === i)
+      assert(rows(0).getInt(1) === i + 1)
+      assert(rows(1).getInt(0) === i)
+      assert(rows(1).getInt(1) === i + 1)
+    }
+
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
-    longArrayRelation.writeExternal(out)
+    longRelation2.writeExternal(out)
     out.flush()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
-    val relation = new LongArrayRelation()
+    val relation = new LongHashedRelation()
     relation.readExternal(in)
+    assert(!relation.keyIsUnique)
     (0 until 100).foreach { i =>
-      val row = longArrayRelation.getValue(i)
-      assert(row.getInt(0) === i)
-      assert(row.getInt(1) === i + 1)
+      val rows = relation.get(i).toArray
+      assert(rows.length === 2)
+      assert(rows(0).getInt(0) === i)
+      assert(rows(0).getInt(1) === i + 1)
+      assert(rows(1).getInt(0) === i)
+      assert(rows(1).getInt(1) === i + 1)
     }
   }
 }