Commit 9308bf11 authored by Davies Liu, committed by Davies Liu

[SPARK-15390] fix broadcast with 100 millions rows

## What changes were proposed in this pull request?

When broadcasting a table with more than 100 million rows (which ideally should not happen), the computed size of the required memory overflows.

This PR fixes the overflow by converting the value to Long when calculating the memory size.

It also adds more checks in the broadcast path so that oversized tables fail with reasonable error messages.
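
For illustration, the sketch below (standalone Scala, not Spark code) reproduces the kind of Int overflow being fixed: once a size estimate of roughly 100 million rows is rounded up to the next power of two (2^27), the product used to size the hash map's backing array wraps around, while widening one operand to Long keeps the true value.

```scala
// Standalone illustration of the Int overflow this PR fixes (not Spark code).
object OverflowSketch {
  def main(args: Array[String]): Unit = {
    val capacity = 100 << 20        // 104,857,600 -- roughly the 100-million-row case
    var n = 1
    while (n < capacity) n *= 2     // rounds up to 2^27 = 134,217,728

    val bytesInt = n * 2 * 8        // Int arithmetic: 2^31 wraps to -2147483648
    val bytesLong = n * 2L * 8      // widening to Long keeps the true value, 2 GiB
    println(s"Int:  $bytesInt")     // negative
    println(s"Long: $bytesLong")    // 2147483648
  }
}
```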

## How was this patch tested?

Add test.

Author: Davies Liu <davies@databricks.com>

Closes #13182 from davies/fix_broadcast.
parent 31f63ac2
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -72,9 +72,18 @@ case class BroadcastExchangeExec(
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
         val input: Array[InternalRow] = child.executeCollect()
+        if (input.length >= 512000000) {
+          throw new SparkException(
+            s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+        }
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+        if (dataSize >= (8L << 30)) {
+          throw new SparkException(
+            s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+        }
         // Construct and broadcast the relation.
         val relation = mode.transform(input)
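
As a reading aid, here is a minimal standalone sketch of the two guards added in the hunk above; `rowSizes` is a hypothetical stand-in for the per-row `UnsafeRow.getSizeInBytes` values, and the thresholds are the ones from the patch (512,000,000 rows and `8L << 30` = 8 GiB).

```scala
// Minimal sketch of the guard logic above; not the actual Spark implementation.
def checkBroadcastLimits(rowCount: Int, rowSizes: Iterator[Int]): Long = {
  if (rowCount >= 512000000) {
    sys.error(s"Cannot broadcast the table with more than 512 millions rows: $rowCount rows")
  }
  // Sum as Long (mirroring the .toLong above) so large tables do not overflow Int.
  val dataSize = rowSizes.map(_.toLong).sum
  if (dataSize >= (8L << 30)) {   // 8L << 30 = 8,589,934,592 bytes = 8 GiB
    sys.error(s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
  }
  dataSize
}

// For example, 1,000 rows of 128 bytes each pass both checks:
// checkBroadcastLimits(1000, Iterator.fill(1000)(128)) == 128000L
```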
@@ -410,9 +410,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   private def init(): Unit = {
     if (mm != null) {
+      require(capacity < 512000000, "Cannot broadcast more than 512 millions rows")
       var n = 1
       while (n < capacity) n *= 2
-      ensureAcquireMemory(n * 2 * 8 + (1 << 20))
+      ensureAcquireMemory(n * 2L * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
       page = new Array[Long](1 << 17) // 1M bytes
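
To make the widened arithmetic concrete: `ensureAcquireMemory` asks for room for the map's `array` of 2n Longs plus the initial 1 MB `page`, and even with the new 512-million-row cap, n can round up to 2^29, putting the request well beyond Int range. A rough standalone sketch of that request size follows (the function name is illustrative, not Spark's):

```scala
// Rough sketch of the request made by init() above; `requestedBytes` is illustrative.
def requestedBytes(capacity: Int): Long = {
  var n = 1
  while (n < capacity) n *= 2   // round the capacity up to a power of two
  // 2n Long slots (8 bytes each) for `array`, plus the 1 MB initial `page`.
  n * 2L * 8 + (1 << 20)
}

// With capacity just under the new 512,000,000 cap, n reaches 2^29 and the
// request is 2^33 + 2^20 bytes (just over 8 GiB), which `n * 2 * 8` alone
// could never represent as an Int:
// requestedBytes(511999999) == 8589934592L + 1048576L
```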
@@ -788,7 +789,7 @@ private[joins] object LongHashedRelation {
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
-    val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+    val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
     val keyGenerator = UnsafeProjection.create(key)
     // Create a mapping of key -> rows
@@ -212,4 +212,19 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     assert(longRelation.estimatedSize > (2L << 30))
     longRelation.close()
   }
+
+  test("build HashedRelation with more than 100 millions rows") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 10)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val m = LongHashedRelation(rows, key, 100 << 20, mm)
+    m.close()
+  }
 }