diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index b6ecd3cb065ae21378ca6adcab159ab30043f9f1..d3081ba7accd2f595818c01dc535b0d12dba80ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -72,9 +72,18 @@ case class BroadcastExchangeExec(
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
         val input: Array[InternalRow] = child.executeCollect()
+        if (input.length >= 512000000) {
+          throw new SparkException(
+            s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+        }
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+        if (dataSize >= (8L << 30)) {
+          throw new SparkException(
+            s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+        }
 
         // Construct and broadcast the relation.
         val relation = mode.transform(input)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cb41457b6653fa096ab16c3c6e63f2ed61463474..cd6b97a855412fbfd4df9c709ea8868657442c91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -410,9 +410,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
 
   private def init(): Unit = {
     if (mm != null) {
+      require(capacity < 512000000, "Cannot broadcast more than 512 millions rows")
       var n = 1
       while (n < capacity) n *= 2
-      ensureAcquireMemory(n * 2 * 8 + (1 << 20))
+      ensureAcquireMemory(n * 2L * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
       page = new Array[Long](1 << 17) // 1M bytes
@@ -788,7 +789,7 @@ private[joins] object LongHashedRelation {
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
 
-    val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+    val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
     val keyGenerator = UnsafeProjection.create(key)
 
     // Create a mapping of key -> rows
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b7b08dc4b126f239b9d75b64020269defde6c349..a5b56541c90f75ce2faa5771de1905a0ba6dee4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -212,4 +212,19 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     assert(longRelation.estimatedSize > (2L << 30))
     longRelation.close()
   }
+
+  test("build HashedRelation with more than 100 millions rows") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 10)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val m = LongHashedRelation(rows, key, 100 << 20, mm)
+    m.close()
+  }
 }
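
Besides the two new size checks, the patch also widens the arithmetic in LongToUnsafeRowMap.init() from Int to Long (n * 2L * 8). Below is a standalone sketch, not part of the patch, showing why that matters; the capacity value is only illustrative of a table near the new 512 million row cap.

// Standalone sketch (not from the patch): with a capacity just under the new
// 512 million row cap, init() rounds n up to 2^29, and the old Int expression
// n * 2 * 8 silently wraps around, asking the memory manager for far less
// memory than the map actually needs; the Long form computes ~8.6 GB correctly.
object AcquireSizeSketch {
  def main(args: Array[String]): Unit = {
    val capacity = 511999999                      // illustrative value, just under the cap
    var n = 1
    while (n < capacity) n *= 2                   // n == 536870912 (2^29)
    val intBytes: Int = n * 2 * 8 + (1 << 20)     // wraps around: 1048576
    val longBytes: Long = n * 2L * 8 + (1 << 20)  // correct: 8590983168
    println(s"Int arithmetic:  $intBytes")
    println(s"Long arithmetic: $longBytes")
  }
}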