Commit 5cb5edaf authored by Davies Liu, committed by Davies Liu

[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long

## What changes were proposed in this pull request?

Currently, we use a Java HashMap for HashedRelation if the key can fit within a Long. The Java HashMap and CompactBuffer are not memory efficient, and the memory they use is not accounted accurately.

This PR introduces a LongToUnsafeRowMap (similar to BytesToBytesMap) for better memory efficiency and performance.

This PR reopens #12190 to fix bugs.
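
As a rough illustration (not part of the patch), the new map is driven much like the micro-benchmark added further down in this diff: rows are appended under a Long key through a TaskMemoryManager, the layout can optionally be compacted with `optimize()`, and lookups go through `getValue()`. The memory-manager setup below is copied from that benchmark; the key and value shapes are illustrative only.

```scala
// Sketch only: mirrors the LongToUnsafeRowMap benchmark case added later in this diff.
// Assumes the calling code sits where LongToUnsafeRowMap is visible.
import org.apache.spark.SparkConf
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.unsafe.Platform

val taskMemoryManager = new TaskMemoryManager(
  new StaticMemoryManager(
    new SparkConf().set("spark.memory.offHeap.enabled", "false"),
    Long.MaxValue,
    Long.MaxValue,
    1),
  0)

// A single fixed-width UnsafeRow reused as the value buffer.
val valueBytes = new Array[Byte](16)
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)

val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
var i = 0
while (i < 1000) {
  value.setInt(0, i)
  map.append(i.toLong, value)   // key is a plain Long, value is an UnsafeRow
  i += 1
}
map.optimize()                  // optional compaction pass, as in the benchmark below

// Point lookup; getValue returns null when the key is absent.
val hit = map.getValue(500L, value)
```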

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #12278 from davies/long_map3.
parent dfce9665
Showing 602 additions and 361 deletions
......@@ -716,7 +716,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
offset += klen;
Platform.copyMemory(vbase, voff, base, offset, vlen);
offset += vlen;
Platform.putLong(base, offset, 0);
// put this value at the beginning of the list
Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);
// --- Update bookkeeping data structures ----------------------------------------------------
offset = currentPage.getBaseOffset();
......@@ -724,17 +725,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
pageCursor += recordLength;
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
currentPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
updateAddressesAndSizes(storedKeyAddress);
numValues++;
if (isDefined) {
// put this pair at the end of chain
while (nextValue()) { /* do nothing */ }
Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress);
nextValue(); // point to new added value
} else {
if (!isDefined) {
numKeys++;
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
isDefined = true;
if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
......
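
The BytesToBytesMap hunk above changes append so that, when the key is already defined, the new record is linked in at the head of that key's value chain (its trailing pointer word is set to the previous head) instead of walking to the end of the chain. A loose on-heap analogy of that prepend idea, with hypothetical names and ordinary collections rather than the real off-heap pages:

```scala
// Loose analogy only: models the "put this value at the beginning of the list" change
// with on-heap structures; BytesToBytesMap actually stores records in off-heap pages
// and keeps encoded addresses in a LongArray. All names here are hypothetical.
import scala.collection.mutable

final case class Record(payload: Array[Byte], next: Record) // `next` mirrors the trailing pointer word

final class ChainedMap {
  private val heads = mutable.LongMap.empty[Record]

  def append(key: Long, payload: Array[Byte]): Unit = {
    // The new record points at the current head (null when the key is new), then the
    // slot is redirected to the new record: O(1) per insert, no chain walk.
    heads(key) = Record(payload, heads.getOrElse(key, null))
  }

  def get(key: Long): Iterator[Array[Byte]] =
    Iterator.iterate(heads.getOrElse(key, null))(_.next).takeWhile(_ != null).map(_.payload)
}
```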
......@@ -454,7 +454,7 @@ case class TungstenAggregate(
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
ctx.addMutableState(hashMapClassName, hashMapTerm, s"")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
......@@ -467,6 +467,7 @@ case class TungstenAggregate(
s"""
${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
$hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
......
......@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
......@@ -50,10 +51,7 @@ case class BroadcastHashJoin(
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(
canJoinKeyFitWithinLong,
rewriteKeyExpr(buildKeys),
buildPlan.output)
val mode = HashedRelationBroadcastMode(buildKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
......@@ -68,7 +66,7 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
join(streamedIter, hashed, numOutputRows)
}
}
......@@ -105,7 +103,7 @@ case class BroadcastHashJoin(
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($relationTerm.getMemorySize());
| incPeakExecutionMemory($relationTerm.estimatedSize());
""".stripMargin)
(broadcastRelation, relationTerm)
}
......@@ -118,15 +116,13 @@ case class BroadcastHashJoin(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
if (canJoinKeyFitWithinLong) {
if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
// generate the join key as Long
val expr = rewriteKeyExpr(streamedKeys).head
val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
val ev = streamedKeys.head.gen(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
(ev, s"${ev.value}.anyNull()")
}
}
......
......@@ -17,16 +17,12 @@
package org.apache.spark.sql.execution.joins
import java.util.NoSuchElementException
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
import org.apache.spark.util.collection.CompactBuffer
import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
self: SparkPlan =>
......@@ -59,9 +55,15 @@ trait HashJoin {
case BuildRight => (right, left)
}
protected lazy val (buildKeys, streamedKeys) = buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
}
}
/**
......@@ -69,7 +71,7 @@ trait HashJoin {
*
* If not, returns the original expressions.
*/
def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
var keyExpr: Expression = null
var width = 0
keys.foreach { e =>
......@@ -84,17 +86,8 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
// hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same
// value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys
// with two same ints have hash code 0, we rotate the bits of second one.
val rotated = if (e.dataType == IntegerType) {
// (e >>> 15) | (e << 17)
BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
} else {
e
}
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
......@@ -105,17 +98,11 @@ trait HashJoin {
keyExpr :: Nil
}
protected lazy val canJoinKeyFitWithinLong: Boolean = {
val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
val key = rewriteKeyExpr(buildKeys)
sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
}
protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
UnsafeProjection.create(buildKeys)
protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
UnsafeProjection.create(streamedKeys)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
......
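
For context, the rewriteKeyExpr hunk above packs several narrow integral join keys into one Long purely by shifting and masking (the bit-rotation special case for int keys is removed by this patch). The same arithmetic on plain Scala values, as a worked example; packTwoInts and unpack are illustrative names, not part of the patch:

```scala
// Worked example of the packing that rewriteKeyExpr builds as Catalyst expressions
// (ShiftLeft / BitwiseAnd / Cast to LongType), written out for two Int keys.
def packTwoInts(a: Int, b: Int): Long =
  (a.toLong << 32) | (b.toLong & ((1L << 32) - 1))   // mask keeps only b's low 32 bits

def unpack(key: Long): (Int, Int) =
  ((key >>> 32).toInt, key.toInt)                    // for a sanity check only

assert(unpack(packTwoInts(7, -3)) == (7, -3))
```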
......@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.memory.MemoryMode
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
......@@ -57,54 +56,20 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val context = TaskContext.get()
if (!canJoinKeyFitWithinLong) {
// build BytesToBytesMap
val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
// This relation is usually used until the end of task.
context.addTaskCompletionListener((t: TaskContext) =>
relation.close()
)
return relation
}
// try to acquire some memory for the hash table, it could trigger other operator to free some
// memory. The memory acquired here will mostly be used until the end of task.
val memoryManager = context.taskMemoryManager()
var acquired = 0L
var used = 0L
val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
// This relation is usually used until the end of task.
context.addTaskCompletionListener((t: TaskContext) =>
memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
relation.close()
)
val copiedIter = iter.map { row =>
// It's hard to guess what's exactly memory will be used, we have a rough guess here.
// TODO: use LongToBytesMap instead of HashMap for memory efficiency
// Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
val needed = 150 + row.getSizeInBytes
if (needed > acquired - used) {
val got = memoryManager.acquireExecutionMemory(
Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
acquired += got
if (got < needed) {
throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
"hash join, please use sort merge join by setting " +
"spark.sql.join.preferSortMergeJoin=true")
}
}
used += needed
// HashedRelation requires that the UnsafeRow should be separate objects.
row.copy()
}
HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
relation
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
val hashed = buildHashedRelation(buildIter)
join(streamIter, hashed, numOutputRows)
}
}
......
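
Because the removed and added lines are interleaved above, the new buildHashedRelation body is easier to read assembled in one place. The following is a reconstruction from the added lines; treat it as illustrative rather than the verbatim file.

```scala
// Reconstructed from the added lines in the hunk above; illustrative, not verbatim.
private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
  val context = TaskContext.get()
  // The relation allocates through the task's memory manager, so its memory is
  // accounted for instead of being estimated with a rough per-row guess.
  val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
  // The relation is usually used until the end of the task; release it on completion.
  context.addTaskCompletionListener((t: TaskContext) => relation.close())
  relation
}
```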
......@@ -21,6 +21,7 @@ import java.util.HashMap
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.vectorized.AggregateHashMap
......@@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X
Join w long codegen=true 275 / 352 76.2 13.1 19.4X
Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X
Join w long codegen=true 321 / 371 65.3 15.3 9.3X
*/
runBenchmark("Join w long duplicated", N) {
......@@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X
Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X
Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X
Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X
*/
val dim2 = broadcast(sqlContext.range(M)
......@@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X
Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X
Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X
Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X
*/
val dim3 = broadcast(sqlContext.range(M)
......@@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X
outer join w long codegen=true 216 / 226 97.2 10.3 26.3X
outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X
outer join w long codegen=true 261 / 276 80.5 12.4 11.7X
*/
runBenchmark("semi join w long", N) {
......@@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X
semi join w long codegen=true 211 / 229 99.2 10.1 22.2X
semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X
semi join w long codegen=true 237 / 244 88.3 11.3 8.1X
*/
}
......@@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X
shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X
shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X
shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X
*/
}
......@@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
val N = 10 << 20
val N = 20 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
benchmark.addCase("hash") { iter =>
benchmark.addCase("UnsafeRowhash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
......@@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
benchmark.addCase("murmur3 hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var p = 524283
var s = 0
while (i < N) {
var h = Murmur3_x86_32.hashLong(i, 42)
key.setInt(0, h)
s += h
i += 1
}
}
benchmark.addCase("fast hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var p = 524283
var s = 0
while (i < N) {
key.setInt(0, i % 1000)
val h = Murmur3_x86_32.hashLong(i % 1000, 42)
var h = i % p
if (h < 0) {
h += p
}
key.setInt(0, h)
s += h
i += 1
}
......@@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
Seq(false, true).foreach { optimized =>
benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter =>
var i = 0
val valueBytes = new Array[Byte](16)
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
value.setInt(0, 555)
val taskMemoryManager = new TaskMemoryManager(
new StaticMemoryManager(
new SparkConf().set("spark.memory.offHeap.enabled", "false"),
Long.MaxValue,
Long.MaxValue,
1),
0)
val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
while (i < 65536) {
value.setInt(0, i)
val key = i % 100000
map.append(key, value)
i += 1
}
if (optimized) {
map.optimize()
}
var s = 0
i = 0
while (i < N) {
val key = i % 100000
if (map.getValue(key, value) != null) {
s += 1
}
i += 1
}
}
}
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
......@@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
while (i < N) {
val numKeys = 65536
while (i < numKeys) {
key.setInt(0, i % 65536)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 65536, 42))
if (loc.isDefined) {
value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
value.setInt(0, value.getInt(0) + 1)
i += 1
} else {
if (!loc.isDefined) {
loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
}
i += 1
}
i = 0
var s = 0
while (i < N) {
key.setInt(0, i % 100000)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 100000, 42))
if (loc.isDefined) {
s += 1
}
i += 1
}
}
}
......@@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
hash 112 / 116 93.2 10.7 1.0X
fast hash 65 / 69 160.9 6.2 1.7X
arrayEqual 66 / 69 159.1 6.3 1.7X
Java HashMap (Long) 137 / 182 76.3 13.1 0.8X
Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X
Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X
BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X
BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X
Aggregate HashMap 56 / 62 187.9 5.3 2.0X
*/
UnsafeRow hash 267 / 284 78.4 12.8 1.0X
murmur3 hash 102 / 129 205.5 4.9 2.6X
fast hash 79 / 96 263.8 3.8 3.4X
arrayEqual 164 / 172 128.2 7.8 1.6X
Java HashMap (Long) 321 / 399 65.4 15.3 0.8X
Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X
Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X
LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X
LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X
BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X
BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X
Aggregate HashMap 121 / 131 173.3 5.8 2.2X
*/
benchmark.run()
}
......
......@@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("compatible BroadcastMode") {
val mode1 = IdentityBroadcastMode
val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)
assert(mode1.compatibleWith(mode1))
assert(!mode1.compatibleWith(mode2))
......@@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(plan sameResult plan)
val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
val hashMode = HashedRelationBroadcastMode(output)
val exchange2 = BroadcastExchange(hashMode, plan)
val hashMode2 =
HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
val exchange3 = BroadcastExchange(hashMode2, plan)
val exchange4 = ReusedExchange(output, exchange3)
......
......@@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val mm = new TaskMemoryManager(
new StaticMemoryManager(
new SparkConf().set("spark.memory.offHeap.enabled", "false"),
Long.MaxValue,
Long.MaxValue,
1),
0)
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy())
val buildKey = Seq(BoundReference(0, IntegerType, false))
val keyGenerator = UnsafeProjection.create(buildKey)
val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
......@@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
test("LongArrayRelation") {
test("LongToUnsafeRowMap") {
val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
val longRelation = LongHashedRelation(rows.iterator, keyProj, 100)
assert(longRelation.isInstanceOf[LongArrayRelation])
val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
val key = Seq(BoundReference(0, IntegerType, false))
val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
assert(longRelation.keyIsUnique)
(0 until 100).foreach { i =>
val row = longArrayRelation.getValue(i)
val row = longRelation.getValue(i)
assert(row.getInt(0) === i)
assert(row.getInt(1) === i + 1)
}
val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
assert(!longRelation2.keyIsUnique)
(0 until 100).foreach { i =>
val rows = longRelation2.get(i).toArray
assert(rows.length === 2)
assert(rows(0).getInt(0) === i)
assert(rows(0).getInt(1) === i + 1)
assert(rows(1).getInt(0) === i)
assert(rows(1).getInt(1) === i + 1)
}
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
longArrayRelation.writeExternal(out)
longRelation2.writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val relation = new LongArrayRelation()
val relation = new LongHashedRelation()
relation.readExternal(in)
assert(!relation.keyIsUnique)
(0 until 100).foreach { i =>
val row = longArrayRelation.getValue(i)
assert(row.getInt(0) === i)
assert(row.getInt(1) === i + 1)
val rows = relation.get(i).toArray
assert(rows.length === 2)
assert(rows(0).getInt(0) === i)
assert(rows(0).getInt(1) === i + 1)
assert(rows(1).getInt(0) === i)
assert(rows(1).getInt(1) === i + 1)
}
}
}