diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d2626440b9434d6d84cb5b54d317596d191ce07d..b43b7ee71e7aaa7e3916b0b63ff4ccb3a4763058 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -43,16 +43,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   def inputSet: AttributeSet =
     AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
 
+  /**
+   * The set of all attributes that are produced by this node.
+   */
+  def producedAttributes: AttributeSet = AttributeSet.empty
+
   /**
    * Attributes that are referenced by expressions but not provided by this nodes children.
    * Subclasses should override this method if they produce attributes internally as it is used by
    * assertions designed to prevent the construction of invalid plans.
-   *
-   * Note that virtual columns should be excluded. Currently, we only support the grouping ID
-   * virtual column.
    */
-  def missingInput: AttributeSet =
-    (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
+  def missingInput: AttributeSet = references -- inputSet -- producedAttributes
 
   /**
    * Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index e3e7a11dba97313b246d2c67826c97c431a72da1..572d7d2f0b537aae0e089deb7e4ecdd7ae463196 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.types.{StructField, StructType}
 
 object LocalRelation {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8f8747e105932d5dd80591f67fae672bfa19d035..6d859551f8c52e45fb6a468073da53a75b6a5b9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
  */
 abstract class LeafNode extends LogicalPlan {
   override def children: Seq[LogicalPlan] = Nil
+  override def producedAttributes: AttributeSet = outputSet
 }
 
 /**
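The hunk above is the heart of the patch: instead of hard-coding an exception for the grouping ID virtual column, missingInput now subtracts whatever a node declares through producedAttributes, and leaf nodes declare their entire output as produced. A minimal sketch of how an operator opts in, assuming the Catalyst API shown in this diff (DummyProducer and its fields are hypothetical, purely for illustration):

    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

    // A node that manufactures the `extra` attributes internally rather than
    // reading them from its child.
    case class DummyProducer(extra: Seq[Attribute], child: LogicalPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output ++ extra

      // Declaring the self-generated attributes keeps them out of
      // missingInput (= references -- inputSet -- producedAttributes),
      // so validity assertions do not flag them as unresolved inputs.
      override def producedAttributes: AttributeSet = AttributeSet(extra)
    }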
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64ef4d799659f88c141a1098346bb5b4c5290a45..5f34d4a4eb73c43e6f75e02bb0ce2f2c7e77d487 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -526,7 +526,7 @@ case class MapPartitions[T, U](
     uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
-  override def missingInput: AttributeSet = AttributeSet.empty
+  override def producedAttributes: AttributeSet = outputSet
 }
 
 /** Factory for constructing new `AppendColumn` nodes. */
@@ -552,7 +552,7 @@ case class AppendColumns[T, U](
     newColumns: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output ++ newColumns
-  override def missingInput: AttributeSet = super.missingInput -- newColumns
+  override def producedAttributes: AttributeSet = AttributeSet(newColumns)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -587,7 +587,7 @@ case class MapGroups[K, T, U](
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
-  override def missingInput: AttributeSet = AttributeSet.empty
+  override def producedAttributes: AttributeSet = outputSet
 }
 
 /** Factory for constructing new `CoGroup` nodes. */
@@ -630,5 +630,5 @@ case class CoGroup[Key, Left, Right, Result](
     rightGroup: Seq[Attribute],
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode {
-  override def missingInput: AttributeSet = AttributeSet.empty
+  override def producedAttributes: AttributeSet = outputSet
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index ea5a9afe03b00797ffb56578cdc1679d81f0a692..5c01af011d3064818079d46f4a3c3a2cb96dcc30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,11 +18,11 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
+import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.{Row, SQLContext}
 
@@ -84,6 +84,8 @@ private[sql] case class LogicalRDD(
     case _ => false
   }
 
+  override def producedAttributes: AttributeSet = outputSet
+
   @transient override lazy val statistics: Statistics = Statistics(
     // TODO: Instead of returning a default value here, find a way to return a meaningful size
     // estimate for RDDs. See PR 1238 for more discussions.
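AppendColumns is the subtle case among the logical operators above: it forwards its child's output and appends newColumns, so only the appended columns count as produced. A small self-contained check of that arithmetic (the attributes a and b are made up for the example; only the AttributeSet algebra mirrors the patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
    import org.apache.spark.sql.types.IntegerType

    val childOutput = Seq(AttributeReference("a", IntegerType)())  // supplied by the child
    val newColumns  = Seq(AttributeReference("b", IntegerType)())  // generated by AppendColumns

    // missingInput = references -- inputSet -- producedAttributes: everything
    // referenced is either supplied or produced, so nothing is reported missing.
    val missing = AttributeSet(childOutput ++ newColumns) --
      AttributeSet(childOutput) -- AttributeSet(newColumns)
    assert(missing.isEmpty)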
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 54b8cb58285c23e7ad2c5b19461b438259fd7a0c..0c613e91b979f514bea8748073c92104956bd46e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -54,6 +54,8 @@ case class Generate(
     child: SparkPlan)
   extends UnaryNode {
 
+  override def expressions: Seq[Expression] = generator :: Nil
+
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ec98f81041343d730a81a0cec93a3cea994109bc..fe9b2ad4a0bc3994cf792ec55e36d3a03f074a55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -279,6 +279,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
 private[sql] trait LeafNode extends SparkPlan {
   override def children: Seq[SparkPlan] = Nil
+  override def producedAttributes: AttributeSet = outputSet
 }
 
 private[sql] trait UnaryNode extends SparkPlan {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index c5470a6989de7b3a42c8c388aa333941c53238d1..c4587ba677b2fa8e0162f3c6428009a46b8efcc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -36,6 +36,15 @@ case class SortBasedAggregate(
     child: SparkPlan)
   extends UnaryNode {
 
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index b8849c827048af8068ba105fc1ebbc5d555fe24d..9d758eb3b7c32124b46c373ef6b47e2021a1e5ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -55,6 +55,11 @@ case class TungstenAggregate(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
       case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
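Both aggregate operators declare the same three groups of self-produced attributes: the aggregate result attributes, the non-grouping result expressions, and the internal aggregation buffer attributes. For a query in the spirit of the one below (sqlContext and a registered table t with columns key and value are assumed), the child scan only supplies key and value; the attribute standing for sum(value) and its buffer slot first come into existence inside the aggregate node, so they must be declared as produced or the new missingInput assertion would reject a perfectly valid plan:

    // Illustrative only: `sum(value)` and its aggregation buffer attribute are
    // produced by SortBasedAggregate/TungstenAggregate, not by the scan of `t`.
    sqlContext.sql("SELECT key, sum(value) FROM t GROUP BY key").explain(true)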
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6b7b3bbbf60585a3666d801ac29cd5c27f32dd77..f19d72f06721808217d61b9b726385697432bd63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -369,6 +369,7 @@ case class MapPartitions[T, U](
     uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = outputSet
 
   override def canProcessSafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = true
@@ -391,6 +392,7 @@ case class AppendColumns[T, U](
     uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(newColumns)
 
   // We are using an unsafe combiner.
   override def canProcessSafeRows: Boolean = false
@@ -424,6 +426,7 @@ case class MapGroups[K, T, U](
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = outputSet
 
   override def canProcessSafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = true
@@ -467,6 +470,7 @@ case class CoGroup[Key, Left, Right, Result](
     rightGroup: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
+  override def producedAttributes: AttributeSet = outputSet
 
   override def canProcessSafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4afa5f8ec1035952e3c84e6c4df95830a390918a..aa7a668e0e938be8b2b8d398f9964f7b3f0a16dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -66,6 +66,8 @@ private[sql] case class InMemoryRelation(
     private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
+  override def producedAttributes: AttributeSet = outputSet
+
   private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
     if (_batchStats == null) {
       child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 78a98798eff648adc16e18e3b84835324ebac0de..359a1e7f8424ac6e91cc50f13453e61878d4638c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -15,16 +15,14 @@
  * limitations under the License.
  */
 
-package test.org.apache.spark.sql
+package org.apache.spark.sql
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.{Row, Strategy, QueryTest}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.types.UTF8String
 
 case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
 
@@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
     sparkContext.parallelize(Seq(row))
   }
 
+  override def producedAttributes: AttributeSet = outputSet
   override def children: Seq[SparkPlan] = Nil
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 442ae79f4f86f3062f3d4b31e842d599d0e89f05..815372f19233bb2b152a011ac666c99bbc0493b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -130,6 +130,8 @@ abstract class QueryTest extends PlanTest {
 
     checkJsonFormat(analyzedDF)
 
+    assertEmptyMissingInput(df)
+
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
@@ -275,6 +277,18 @@ abstract class QueryTest extends PlanTest {
       """.stripMargin)
     }
   }
+
+  /**
+   * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
+   */
+  def assertEmptyMissingInput(query: Queryable): Unit = {
+    assert(query.queryExecution.analyzed.missingInput.isEmpty,
+      s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
+    assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
+      s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
+    assert(query.queryExecution.executedPlan.missingInput.isEmpty,
+      s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
+  }
 }
 
 object QueryTest {
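With assertEmptyMissingInput wired into checkAnswer, every existing test now also verifies that the analyzed, optimized, and physical plans are closed over their inputs. The helper can be invoked directly as well; a hedged usage sketch (the DataFrame and its expressions are made up for the example):

    // Fails with a message naming the offending plan if some operator references
    // attributes that no child supplies and no node declares as produced.
    val df = sqlContext.range(10).selectExpr("id", "id * 2 AS doubled")
    assertEmptyMissingInput(df)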
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 806d2b9b0b7d405c2b7cd6befea0b63dde79ab79..8141136de5311ae223ae3fbd35778e94069bbf1b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -51,6 +51,9 @@ case class HiveTableScan(
   require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
     "Partition pruning predicates only supported for partitioned tables.")
 
+  override def producedAttributes: AttributeSet = outputSet ++
+    AttributeSet(partitionPruningPred.flatMap(_.references))
+
   // Retrieve the original attributes based on expression ID so that capitalization matches.
   val attributes = requestedAttributes.map(relation.attributeMap)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index d9b9ba4bfdfed5717b19752ad48aa449d675e9ab..a61e162f48f1bebcabc3920be27539059f201793 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -60,6 +60,8 @@ case class ScriptTransformation(
 
   override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
+  override def producedAttributes: AttributeSet = outputSet -- inputSet
+
   private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
 
   protected override def doExecute(): RDD[InternalRow] = {
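The two Hive operators read slightly differently from the rest of the patch. HiveTableScan claims its output plus the references of its partition-pruning predicates, since those refer to partition columns of the Hive relation rather than to any child operator. ScriptTransformation claims outputSet -- inputSet, i.e. exactly those output attributes that are not taken verbatim from its input. A hedged sketch of the kind of query this covers (assumes a HiveContext, here hiveContext, and the usual Hive test table src):

    // Illustrative only: the TRANSFORM script's declared output schema is parsed
    // from the process's stdout; attributes it adds beyond the forwarded input
    // are the ones ScriptTransformation reports as produced.
    hiveContext.sql(
      "SELECT TRANSFORM (key, value) USING 'cat' AS (key, echoed) FROM src").explain(true)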