Commit 767d4807 authored by Sameer Agarwal, committed by Herman van Hovell

[SPARK-17415][SQL] Better error message for driver-side broadcast join OOMs

## What changes were proposed in this pull request?

This is a trivial patch that catches any `OutOfMemoryError` thrown while building the broadcast hash relation on the driver and rethrows it wrapped in a more informative error message.
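
A minimal, self-contained sketch of the wrap-and-rethrow pattern (simplified from the actual change below; `buildRelation` is a hypothetical stand-in for collecting the child's rows and building the hash relation):

```scala
object BroadcastOomWrapSketch {
  // Hypothetical stand-in for the driver-side collect-and-build step.
  def buildRelation(): AnyRef = throw new OutOfMemoryError("Java heap space")

  def main(args: Array[String]): Unit = {
    try {
      buildRelation()
    } catch {
      case oe: OutOfMemoryError =>
        // Rethrow with an actionable message, preserving the original cause
        // chain so the underlying allocation failure stays visible.
        throw new OutOfMemoryError(
          "Not enough memory to build and broadcast the table to all worker nodes. " +
            "As a workaround, you can either disable broadcast by setting " +
            "spark.sql.autoBroadcastJoinThreshold to -1 or increase the driver " +
            "memory by setting spark.driver.memory to a higher value")
          .initCause(oe.getCause)
    }
  }
}
```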

## How was this patch tested?

Existing Tests

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #14979 from sameeragarwal/broadcast-join-error.
parent 883c7631
```diff
@@ -21,6 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 
 import org.apache.spark.{broadcast, SparkException}
+import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -28,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPar
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -70,38 +72,47 @@ case class BroadcastExchangeExec(
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sparkContext, executionId) {
-        val beforeCollect = System.nanoTime()
-        // Note that we use .executeCollect() because we don't want to convert data to Scala types
-        val input: Array[InternalRow] = child.executeCollect()
-        if (input.length >= 512000000) {
-          throw new SparkException(
-            s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
-        }
-        val beforeBuild = System.nanoTime()
-        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-        longMetric("dataSize") += dataSize
-        if (dataSize >= (8L << 30)) {
-          throw new SparkException(
-            s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-        }
+        try {
+          val beforeCollect = System.nanoTime()
+          // Note that we use .executeCollect() because we don't want to convert data to Scala types
+          val input: Array[InternalRow] = child.executeCollect()
+          if (input.length >= 512000000) {
+            throw new SparkException(
+              s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+          }
+          val beforeBuild = System.nanoTime()
+          longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+          val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+          longMetric("dataSize") += dataSize
+          if (dataSize >= (8L << 30)) {
+            throw new SparkException(
+              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+          }
 
-        // Construct and broadcast the relation.
-        val relation = mode.transform(input)
-        val beforeBroadcast = System.nanoTime()
-        longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+          // Construct and broadcast the relation.
+          val relation = mode.transform(input)
+          val beforeBroadcast = System.nanoTime()
+          longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
 
-        val broadcasted = sparkContext.broadcast(relation)
-        longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+          val broadcasted = sparkContext.broadcast(relation)
+          longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
 
-        // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
-        // directly without setting an execution id. We should be tolerant to it.
-        if (executionId != null) {
-          sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
-            executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
-        }
+          // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
+          // directly without setting an execution id. We should be tolerant to it.
+          if (executionId != null) {
+            sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
+              executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
+          }
 
-        broadcasted
+          broadcasted
+        } catch {
+          case oe: OutOfMemoryError =>
+            throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
+              s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
+              s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
+              s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
+              .initCause(oe.getCause)
+        }
       }
     }(BroadcastExchangeExec.executionContext)
   }
```
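
For reference, the workarounds suggested by the new error message can be applied as follows (a sketch: `spark` is an existing `SparkSession`, and the `8g` value is illustrative):

```scala
// Workaround 1: disable broadcast joins so the driver never has to build and
// broadcast the relation. This key is runtime-settable.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Workaround 2: give the driver more memory. spark.driver.memory must be set
// before the driver JVM starts, e.g. on the command line:
//   spark-submit --driver-memory 8g ...
```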