diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 9dfd84cbc99410e49a560d393bec173795a70ac5..86c788aaa828a24f44104a42e37ec295d588d87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -57,9 +57,9 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(plan, isBroadcastable = Option(true)) + ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - ResolvedHint(plan, isBroadcastable = Option(true)) + ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. @@ -88,7 +88,7 @@ object ResolveHints { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - ResolvedHint(h.child, isBroadcastable = Option(true)) + ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true))) } else { // Otherwise, find within the subtree query plans that should be broadcasted. applyBroadcastHint(h.child, h.parameters.toSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6bdcf490ca5c84602150058b7e11f9b4175f0cb5..2ebb2ff323c6bb78200404833159d686e546059b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -347,7 +347,7 @@ abstract class UnaryNode extends LogicalPlan { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable) + Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 81bb374cb0500271184d6c0e153b8c77ee4c869c..a64562b5dbd93ae6c779b91478e0e8b7751b8491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -46,13 +46,13 @@ import org.apache.spark.util.Utils * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. * @param attributeStats Statistics for Attributes. - * @param isBroadcastable If true, output is small enough to be used in a broadcast join. + * @param hints Query hints. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil), - isBroadcastable: Boolean = false) { + hints: HintInfo = HintInfo()) { override def toString: String = "Statistics(" + simpleString + ")" @@ -65,14 +65,9 @@ case class Statistics( } else { "" }, - s"isBroadcastable=$isBroadcastable" + s"hints=$hints" ).filter(_.nonEmpty).mkString(", ") } - - /** Must be called when computing stats for a join operator to reset hints. */ - def resetHintsForJoin(): Statistics = copy( - isBroadcastable = false - ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9f34b371740bd98231f4fab54fe9ff2555d19c19..6878b6b179c3a47d6d086808584a3f2004ef861e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -195,9 +195,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation val leftSize = left.stats(conf).sizeInBytes val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable - - Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable) + Statistics( + sizeInBytes = sizeInBytes, + hints = left.stats(conf).hints.resetForJoin()) } } @@ -364,7 +364,8 @@ case class Join( case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - super.computeStats(conf).resetHintsForJoin() + val stats = super.computeStats(conf) + stats.copy(hints = stats.hints.resetForJoin()) } if (conf.cboEnabled) { @@ -560,7 +561,7 @@ case class Aggregate( Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), rowCount = Some(1), - isBroadcastable = child.stats(conf).isBroadcastable) + hints = child.stats(conf).hints) } else { super.computeStats(conf) } @@ -749,7 +750,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), rowCount = Some(rowCount), - isBroadcastable = childStats.isBroadcastable) + hints = childStats.hints) } } @@ -770,7 +771,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo Statistics( sizeInBytes = 1, rowCount = Some(0), - isBroadcastable = childStats.isBroadcastable) + hints = childStats.hints) } else { // The output row count of LocalLimit should be the sum of row counts from each partition. // However, since the number of partitions is not available here, we just use statistics of @@ -827,7 +828,7 @@ case class Sample( } val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) + Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 9bcbfbb4d1397898337a80a5db0dc01aaee41f73..b96d7bc9cfdb6470fac0b8937e168d8edfd9b9c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -35,15 +35,31 @@ case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalP /** * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]]. */ -case class ResolvedHint( - child: LogicalPlan, - isBroadcastable: Option[Boolean] = None) +case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) extends UnaryNode { override def output: Seq[Attribute] = child.output override def computeStats(conf: SQLConf): Statistics = { val stats = child.stats(conf) - isBroadcastable.map(x => stats.copy(isBroadcastable = x)).getOrElse(stats) + stats.copy(hints = hints) + } +} + + +case class HintInfo( + isBroadcastable: Option[Boolean] = None) { + + /** Must be called when computing stats for a join operator to reset hints. */ + def resetForJoin(): HintInfo = copy( + isBroadcastable = None + ) + + override def toString: String = { + if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) { + "none" + } else { + isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 48b5fbb03ef1e01472794dd4e83671746e21dcfe..a0c23198451a84f900e99139b259630e3a176536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -56,7 +56,7 @@ object AggregateEstimation { sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, - isBroadcastable = childStats.isBroadcastable)) + hints = childStats.hints)) } else { None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index bb914e11a139a0859d38b322a163e7afef5b1a3c..3d5148008c628a3c97d1a685d4fc3229777da9ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -36,17 +36,17 @@ class ResolveHintsSuite extends AnalysisTest { test("case-sensitive or insensitive parameters") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), - ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = true) checkAnalysis( @@ -58,28 +58,28 @@ class ResolveHintsSuite extends AnalysisTest { test("multiple broadcast hint aliases") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, isBroadcastable = Option(true)), - ResolvedHint(testRelation2, isBroadcastable = Option(true)), Inner, None), + Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), isBroadcastable = Option(true))), - ResolvedHint(testRelation.where('a > 1), isBroadcastable = Option(true)).analyze, + ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))), + ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. @@ -104,7 +104,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - ResolvedHint(testRelation.where('a > 1).select('a), isBroadcastable = Option(true)) + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true))) .select('a).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 81b91e63b8f675e8b778077f7e32f529d90fd86c..2afea6dd3d37c1aee6ec8137361fbdc5cd91c309 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -37,19 +37,20 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { test("BroadcastHint estimation") { val filter = Filter(Literal(true), plan) - val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) - val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4)) checkStats( filter, expectedStatsCboOn = filterStatsCboOn, expectedStatsCboOff = filterStatsCboOff) - val broadcastHint = ResolvedHint(filter, isBroadcastable = Option(true)) + val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true))) checkStats( broadcastHint, - expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), - expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))), + expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true))) + ) } test("limit estimation: limit < child's rowCount") { @@ -94,15 +95,13 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5981b49da277edd73c4a1c16aa74eb790e728ae3..843ce63161220d9d1e9a5f578a2d549c0d8e0d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).isBroadcastable || + plan.stats(conf).hints.isBroadcastable.getOrElse(false) || (plan.stats(conf).sizeInBytes >= 0 && plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 563eae0b6483fa86c3af2a52cdfec25ae2f74e7a..36c0f18b6e2e358b7a33776e920c2fa442c413a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.ResolvedHint +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.SQLConf @@ -1020,7 +1020,7 @@ object functions { */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, isBroadcastable = Option(true)))(df.exprEnc) + ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ddc393c8da053060082e21cc6f385872d53fd577..601324f2c017267dc771d0d9ce993ccc1f3fb909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -164,7 +164,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared numbers.foreach { case (input, (expectedSize, expectedRows)) => val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + - s" isBroadcastable=${stats.isBroadcastable}" + s" hints=none" assert(stats.simpleString == expectedString) } }