diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c62d5ead869250901a7e820c4f74567101a8cd3a..371d72ef5af086e24c56229d1efb74be42542fea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => def output: Seq[Attribute] @@ -237,4 +237,65 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanTy } override def innerChildren: Seq[PlanType] = subqueries + + /** + * Canonicalized copy of this query plan. + */ + protected lazy val canonicalized: PlanType = this + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since it's likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and/or expression id differences. Operators that + * can do better should override this function. + */ + def sameResult(plan: PlanType): Boolean = { + val canonicalizedLeft = this.canonicalized + val canonicalizedRight = plan.canonicalized + canonicalizedLeft.getClass == canonicalizedRight.getClass && + canonicalizedLeft.children.size == canonicalizedRight.children.size && + canonicalizedLeft.cleanArgs == canonicalizedRight.cleanArgs && + (canonicalizedLeft.children, canonicalizedRight.children).zipped.forall(_ sameResult _) + } + + /** + * All the attributes that are used for this plan. + */ + lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output) + + private def cleanExpression(e: Expression): Expression = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, so we need + // to erase that for equality testing. + val cleanedExprId = + Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) + BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) + case other => + BindReferences.bindReference(other, allAttributes, allowFailures = true) + } + + /** Args that have been cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + def cleanArg(arg: Any): Any = arg match { + case e: Expression => cleanExpression(e).canonicalized + case other => other + } + + productIterator.map { + // Children are checked using sameResult above.
+ case tn: TreeNode[_] if containsChild(tn) => null + case e: Expression => cleanArg(e) + case s: Option[_] => s.map(cleanArg) + case s: Seq[_] => s.map(cleanArg) + case m: Map[_, _] => m.mapValues(cleanArg) + case other => other + }.toSeq + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 31e775d60f950b0363cd1263b448288b96c7d6e8..b32c7d0fcbaa4b2f7c1c582ff0806199bcbf25f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -114,60 +114,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns true when the given logical plan will return the same results as this logical plan. - * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually - * the same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Logical operators that - * can do better should override this function. - */ - def sameResult(plan: LogicalPlan): Boolean = { - val cleanLeft = EliminateSubqueryAliases(this) - val cleanRight = EliminateSubqueryAliases(plan) - - cleanLeft.getClass == cleanRight.getClass && - cleanLeft.children.size == cleanRight.children.size && { - logDebug( - s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") - cleanRight.cleanArgs == cleanLeft.cleanArgs - } && - (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - val input = children.flatMap(_.output) - def cleanExpression(e: Expression) = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, input, allowFailures = true) - case other => BindReferences.bindReference(other, input, allowFailures = true) - } - - productIterator.map { - // Children are checked using sameResult above. 
- case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e) - case s: Option[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other - } - case s: Seq[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other - } - case other => other - }.toSeq - } + override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index e01f69f81359ee42f8e98f9f343f3e1a2a77a77a..9dfdf4da78ff605e2fa0a85fca6a6840c2c8458f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.InternalRow */ trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + + /** + * Returns true iff this [[BroadcastMode]] generates the same result as `other`. + */ + def compatibleWith(other: BroadcastMode): Boolean } /** @@ -33,4 +38,8 @@ trait BroadcastMode { case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + + override def compatibleWith(other: BroadcastMode): Boolean = { + this eq other + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 9019e5dfd66c61a31d4e3ba3f6484efd6b20ef69..247f55da1d2a0b1d98d969efe720d2cce3c7e1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.exchange.ReusedExchange import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.util.Utils @@ -31,13 +32,28 @@ class SparkPlanInfo( val simpleString: String, val children: Seq[SparkPlanInfo], val metadata: Map[String, String], - val metrics: Seq[SQLMetricInfo]) + val metrics: Seq[SQLMetricInfo]) { + + override def hashCode(): Int = { + // hashCode of simpleString should be good enough to distinguish the plans from each other + // within a plan + simpleString.hashCode + } + + override def equals(other: Any): Boolean = other match { + case o: SparkPlanInfo => + nodeName == o.nodeName && simpleString == o.simpleString && children == o.children + case _ => false + } +} private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - - val children = plan.children ++ plan.subqueries + val children = plan match { + case ReusedExchange(_, child) => child :: Nil + case _ => plan.children ++ plan.subqueries + } val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 
f07add83d5849b0069f107ba9b39b3a93b4e0350..f856634cf7b66d4d61112bd1cc8cc633f9669354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -46,6 +46,10 @@ case class TungstenAggregate( require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) + override lazy val allAttributes: Seq[Attribute] = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4a9e736f7abdb4d73125402f7b44ab21ea081e1d..49012982273af271a2a5d6933c34bd7131462b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -166,6 +166,9 @@ case class Range( private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + // output attributes should not affect the results + override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) + override def upstreams(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) .map(i => InternalRow(i)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala index 40cad4b1a7645b91af37514f2ff4a415232e3015..1a5c6a66c484e93e635236dd739cb5f5daff6aa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -34,12 +34,16 @@ import org.apache.spark.util.ThreadUtils */ case class BroadcastExchange( mode: BroadcastMode, - child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output + child: SparkPlan) extends Exchange { override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + override def sameResult(plan: SparkPlan): Boolean = plan match { + case p: BroadcastExchange => + mode.compatibleWith(p.mode) && child.sameResult(p.child) + case _ => false + } + @transient private val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala new file mode 100644 index 0000000000000000000000000000000000000000..12513e9106707fe66f1dcff2056b9aa37464fa1e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +/** + * An interface for exchanges. + */ +abstract class Exchange extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +/** + * A wrapper that reuses an existing exchange while exposing a different set of output attributes. + * Two exchanges that produce logically identical output still have distinct output attribute ids, + * so we preserve the original ids because they are what downstream operators expect. + */ +case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode { + + override def sameResult(plan: SparkPlan): Boolean = { + // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. + plan.sameResult(child) + } + + def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.executeBroadcast() + } + + // Do not repeat the same tree in explain. + override def treeChildren: Seq[SparkPlan] = Nil +} + +/** + * Finds duplicated exchanges in the Spark plan and reuses the same exchange for all the + * references. + */ +private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + if (!sqlContext.conf.exchangeReuseEnabled) { + return plan + } + // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]() + plan.transformUp { + case exchange: Exchange => + // Exchanges that produce the same results usually also have the same schema (same column names). + val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]()) + val samePlan = sameSchema.find { e => + exchange.sameResult(e) + } + if (samePlan.isDefined) { + // Keep the output of this exchange; the following plans require it to resolve + // attributes.
+ ReusedExchange(exchange.output, samePlan.get) + } else { + sameSchema += exchange + exchange + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index de21d7705e1378a4c37c01f99df29cc1f4951c38..4eb4d9adbddc8d3e39308f39eca88a8336b58dc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.MutablePair case class ShuffleExchange( var newPartitioning: Partitioning, child: SparkPlan, - @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { + @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { override def nodeName: String = { val extraInfo = coordinator match { @@ -55,8 +55,6 @@ case class ShuffleExchange( override def outputPartitioning: Partitioning = newPartitioning - override def output: Seq[Attribute] = child.output - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { @@ -103,16 +101,25 @@ case class ShuffleExchange( new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } + /** + * Caches the created ShuffledRowRDD so we can reuse it. + */ + private var cachedShuffleRDD: ShuffledRowRDD = null + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case None => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) + // Returns the same ShuffledRowRDD if this plan is used by multiple plans.
+ if (cachedShuffleRDD == null) { + cachedShuffleRDD = coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } } + cachedShuffleRDD } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 9a3cdaf697e2d8198c72ae685f51ec9d3f8a3bd3..99f8841c8737bd0d3cbfb895ea2b16788bbe1e2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -681,7 +681,7 @@ private[execution] case class HashedRelationBroadcastMode( keys: Seq[Expression], attributes: Seq[Attribute]) extends BroadcastMode { - def transform(rows: Array[InternalRow]): HashedRelation = { + override def transform(rows: Array[InternalRow]): HashedRelation = { val generator = UnsafeProjection.create(keys, attributes) if (canJoinKeyFitWithinLong) { LongHashedRelation(rows.iterator, generator, rows.length) @@ -689,5 +689,18 @@ private[execution] case class HashedRelationBroadcastMode( HashedRelation(rows.iterator, generator, rows.length) } } + + private lazy val canonicalizedKeys: Seq[Expression] = { + keys.map { e => + BindReferences.bindReference(e.canonicalized, attributes) + } + } + + override def compatibleWith(other: BroadcastMode): Boolean = other match { + case m: HashedRelationBroadcastMode => + canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && + canonicalizedKeys == m.canonicalizedKeys + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 83372aa2e930c1af0893d907fe76e482c3bb88cd..94d318e7027894feed6c2a3c456d5f36acbb1d96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -64,7 +64,8 @@ private[sql] object SparkPlanGraph { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null) + val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]() + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges) new SparkPlanGraph(nodes, edges) } @@ -74,7 +75,8 @@ private[sql] object SparkPlanGraph { nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge], parent: SparkPlanGraphNode, - subgraph: SparkPlanGraphCluster): Unit = { + subgraph: SparkPlanGraphCluster, + exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = { planInfo.nodeName match { case "WholeStageCodegen" => val cluster = new SparkPlanGraphCluster( @@ -84,13 +86,14 @@ private[sql] object SparkPlanGraph { mutable.ArrayBuffer[SparkPlanGraphNode]()) nodes += cluster buildSparkPlanGraphNode( - planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges) case "InputAdapter" => - buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, 
edges, parent, null) + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) case "Subquery" if subgraph != null => // Subquery should not be included in WholeStageCodegen - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null) - case _ => + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case name => val metrics = planInfo.metrics.map { metric => SQLPlanMetric(metric.name, metric.accumulatorId, SQLMetrics.getMetricParam(metric.metricParam)) @@ -103,12 +106,15 @@ private[sql] object SparkPlanGraph { } else { subgraph.nodes += node } + if (name == "ShuffleExchange" || name == "BroadcastExchange") { + exchanges += planInfo -> node + } if (parent != null) { edges += SparkPlanGraphEdge(node.id, parent.id) } planInfo.children.foreach( - buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d1e2884414d8f734c38c31e263e96b5f15dc216..384102e5eaa5baf964a78d0b5dcb953353a213a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -504,6 +504,10 @@ object SQLConf { " method", isPublic = false) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", + defaultValue = Some(true), + doc = "When true, the planner will try to find out duplicated exchanges and re-use them", + isPublic = false) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -564,6 +568,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) + def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 6f81794b2949b103cb97528c99453ffc06093706..98ada4d58af7ec950cb72db9bcb8ba8bdd83945a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource} -import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.util.ExecutionListenerManager - /** * A class that holds all session-specific state in a given [[SQLContext]]. 
*/ @@ -94,7 +93,8 @@ private[sql] class SessionState(ctx: SQLContext) { override val batches: Seq[Batch] = Seq( Batch("Subquery", Once, PlanSubqueries(ctx)), Batch("Add exchange", Once, EnsureRequirements(ctx)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)) + Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)), + Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 55153cda31e0a55e6bde1117f62926b76686c21a..26775c3700e2334b3dca9ffbd4a254a1768b77fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,9 +25,9 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, OneRowRelation, Union} import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -1316,6 +1316,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } + test("reuse exchange") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { + val df = sqlContext.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + checkAnswer(join, df) + assert( + join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert(join.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 1) + val broadcasted = broadcast(join) + val join2 = join.join(broadcasted, "id").join(broadcasted, "id") + checkAnswer(join2, df) + assert( + join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: BroadcastExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 4) + } + } + + test("sameResult() on aggregate") { + val df = sqlContext.range(100) + val agg1 = df.groupBy().count() + val agg2 = df.groupBy().count() + // two aggregates with different ExprId within them should have same result + assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) + val agg3 = df.groupBy().sum() + assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) + val df2 = sqlContext.range(101) + val agg4 = df2.groupBy().count() + assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) + } + test("SPARK-12512: support `.` in column name for withColumn()") { val df = Seq("a" -> "b").toDF("col.a", "col.b") checkAnswer(df.select(df("*")), Row("a", "b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index d4f22de90c523b4f3b67ea3f8a24bfb46484c9c7..9f159d1e1e8a8b212d3f5e215a5cc7755edba407 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -33,4 +35,70 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { input.map(Row.fromTuple) ) } + + test("compatible BroadcastMode") { + val mode1 = IdentityBroadcastMode + val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) + val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + + assert(mode1.compatibleWith(mode1)) + assert(!mode1.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode1)) + assert(mode2.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode3)) + assert(mode3.compatibleWith(mode3)) + } + + test("BroadcastExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) + val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val exchange2 = BroadcastExchange(hashMode, plan) + val hashMode2 = + HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + val exchange3 = BroadcastExchange(hashMode2, plan) + val exchange4 = ReusedExchange(output, exchange3) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + + assert(!exchange1.sameResult(exchange2)) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(exchange4 sameResult exchange3) + } + + test("ShuffleExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val part1 = HashPartitioning(output, 1) + val exchange1 = ShuffleExchange(part1, plan) + val exchange2 = ShuffleExchange(part1, plan) + val part2 = HashPartitioning(output, 2) + val exchange3 = ShuffleExchange(part2, plan) + val part3 = HashPartitioning(output ++ output, 2) + val exchange4 = ShuffleExchange(part3, plan) + val exchange5 = ReusedExchange(output, exchange4) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + assert(exchange5 sameResult exchange5) + + assert(exchange1 sameResult exchange2) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(!exchange4.sameResult(exchange5)) + assert(exchange5 sameResult exchange4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a733237a5e717724486a4d23536a7b8ae03f6c71..ab0a7ff62896281ab43048e02ee76433b22a987b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - class PlannerSuite extends SharedSQLContext { import testImplicits._ @@ -472,6 +471,50 @@ class PlannerSuite extends SharedSQLContext { } // --------------------------------------------------------------------------------------------- + + test("Reuse exchanges") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val shuffle = ShuffleExchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val inputPlan = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + None, + shuffle, + shuffle) + + val outputPlan = ReuseExchange(sqlContext).apply(inputPlan) + if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) { + fail(s"Should re-use the shuffle:\n$outputPlan") + } + if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + fail(s"Should have only one shuffle:\n$outputPlan") + } + + // nested exchanges + val inputPlan2 = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + None, + ShuffleExchange(finalPartitioning, inputPlan), + ShuffleExchange(finalPartitioning, inputPlan)) + + val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2) + if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { + fail(s"Should re-use the two shuffles:\n$outputPlan2") + } + if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + fail(s"Should have only two shuffles:\n$outputPlan2") + } + } } // Used for unit-testing EnsureRequirements
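For context on how the new rule shows up at the API level, here is a minimal sketch (not part of the patch) of observing exchange reuse from a test, assuming the sqlContext, the withSQLConf test helper, and the pre-2.0 DataFrame API used in the suites above. It mirrors the "reuse exchange" test in DataFrameSuite: broadcast joins are disabled so the self-join plans a shuffle on each side, and with spark.sql.exchange.reuse enabled one side should be replaced by a ReusedExchange.

import org.apache.spark.sql.execution.exchange.{ReusedExchange, ShuffleExchange}

// Sketch only: counts the physical exchanges in a self-join plan.
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2",
    "spark.sql.exchange.reuse" -> "true") {
  val df = sqlContext.range(100)
  val plan = df.join(df, "id").queryExecution.executedPlan
  // One side keeps its ShuffleExchange, the other becomes a ReusedExchange pointing at it.
  assert(plan.collect { case e: ShuffleExchange => e }.size == 1)
  assert(plan.collect { case e: ReusedExchange => e }.size == 1)
}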