From 780586a9f2400c3fdfdb9a6b954001a3c9663941 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 12 Jul 2017 09:23:54 -0700
Subject: [PATCH] [SPARK-17701][SQL] Refactor RowDataSourceScanExec so its
 sameResult call does not compare strings

## What changes were proposed in this pull request?

Currently, `RowDataSourceScanExec` and `FileSourceScanExec` rely on a "metadata" string map to implement equality comparison, since the RDDs they depend on cannot be directly compared. This has resulted in a number of correctness bugs around exchange reuse, e.g. SPARK-17673 and SPARK-16818.

To make these comparisons less brittle, we should refactor these classes to compare constructor parameters directly instead of relying on the metadata map.

This PR refactors `RowDataSourceScanExec`; `FileSourceScanExec` will be fixed in a follow-up PR.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18600 from cloud-fan/minor.
---
 .../sql/execution/DataSourceScanExec.scala | 65 +++++++++----------
 .../spark/sql/execution/SparkPlan.scala | 5 --
 .../spark/sql/execution/SparkPlanInfo.scala | 4 +-
 .../datasources/DataSourceStrategy.scala | 57 ++++++----------
 .../sql/execution/ui/SparkPlanGraph.scala | 5 +-
 5 files changed, 56 insertions(+), 80 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index a0def68d88..588c937a13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -33,21 +33,23 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils

 trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
-  val metastoreTableIdentifier: Option[TableIdentifier]
+  val tableIdentifier: Option[TableIdentifier]

   protected val nodeNamePrefix: String = ""

   override val nodeName: String = {
-    s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}"
+    s"Scan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}"
   }

+  // Metadata that describes more details of this scan.
+  protected def metadata: Map[String, String]
+
   override def simpleString: String = {
     val metadataEntries = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -73,34 +75,25 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
 /** Physical plan node for scanning data from a relation.
  */
 case class RowDataSourceScanExec(
-    output: Seq[Attribute],
+    fullOutput: Seq[Attribute],
+    requiredColumnsIndex: Seq[Int],
+    filters: Set[Filter],
+    handledFilters: Set[Filter],
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
-    override val outputPartitioning: Partitioning,
-    override val metadata: Map[String, String],
-    override val metastoreTableIdentifier: Option[TableIdentifier])
+    override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec {

+  def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)
+
   override lazy val metrics =
     Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  val outputUnsafeRows = relation match {
-    case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
-      !SparkSession.getActiveSession.get.sessionState.conf.getConf(
-        SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
-    case _: HadoopFsRelation => true
-    case _ => false
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
-    val unsafeRow = if (outputUnsafeRows) {
-      rdd
-    } else {
-      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
-        val proj = UnsafeProjection.create(schema)
-        proj.initialize(index)
-        iter.map(proj)
-      }
+    val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) =>
+      val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
+      iter.map(proj)
     }

     val numOutputRows = longMetric("numOutputRows")
@@ -126,24 +119,31 @@ case class RowDataSourceScanExec(
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columnsRowInput = exprRows.map(_.genCode(ctx))
-    val inputRow = if (outputUnsafeRows) row else null
     s"""
        |while ($input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
-       |  ${consume(ctx, columnsRowInput, inputRow).trim}
+       |  ${consume(ctx, columnsRowInput, null).trim}
        |  if (shouldStop()) return;
        |}
      """.stripMargin
   }

-  // Only care about `relation` and `metadata` when canonicalizing.
+  override val metadata: Map[String, String] = {
+    val markedFilters = for (filter <- filters) yield {
+      if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
+    }
+    Map(
+      "ReadSchema" -> output.toStructType.catalogString,
+      "PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
+  }
+
+  // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
   override lazy val canonicalized: SparkPlan =
     copy(
-      output.map(QueryPlan.normalizeExprId(_, output)),
+      fullOutput.map(QueryPlan.normalizeExprId(_, fullOutput)),
       rdd = null,
-      outputPartitioning = null,
-      metastoreTableIdentifier = None)
+      tableIdentifier = None)
 }

 /**
@@ -154,7 +154,7 @@ case class RowDataSourceScanExec(
  * @param requiredSchema Required schema of the underlying relation, excluding partition columns.
  * @param partitionFilters Predicates to use for partition pruning.
  * @param dataFilters Filters on non-partition columns.
- * @param metastoreTableIdentifier identifier for the table in the metastore.
+ * @param tableIdentifier identifier for the table in the metastore.
  */
 case class FileSourceScanExec(
     @transient relation: HadoopFsRelation,
@@ -162,7 +162,7 @@ case class FileSourceScanExec(
     requiredSchema: StructType,
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Expression],
-    override val metastoreTableIdentifier: Option[TableIdentifier])
+    override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec with ColumnarBatchScan {

   val supportsBatch: Boolean = relation.fileFormat.supportBatch(
@@ -261,7 +261,6 @@ case class FileSourceScanExec(
   private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
   logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")

-  // These metadata values make scan plans uniquely identifiable for equality checking.
   override val metadata: Map[String, String] = {
     def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]")
     val location = relation.location
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index db975614c9..c7277c21ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -71,11 +71,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     super.makeCopy(newArgs)
   }

-  /**
-   * @return Metadata that describes more details of this SparkPlan.
-   */
-  def metadata: Map[String, String] = Map.empty
-
   /**
    * @return All metrics containing metrics of this SparkPlan.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 7aa93126fd..06b69625fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -31,7 +31,6 @@ class SparkPlanInfo(
     val nodeName: String,
     val simpleString: String,
     val children: Seq[SparkPlanInfo],
-    val metadata: Map[String, String],
     val metrics: Seq[SQLMetricInfo]) {

   override def hashCode(): Int = {
@@ -58,7 +57,6 @@ private[execution] object SparkPlanInfo {
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
     }

-    new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
-      plan.metadata, metrics)
+    new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), metrics)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e05a8d5f02..587b9b450e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources

 import java.util.concurrent.Callable

-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
@@ -288,10 +286,11 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
     case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
       RowDataSourceScanExec(
         l.output,
+        l.output.indices,
+        Set.empty,
+        Set.empty,
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
-        UnknownPartitioning(0),
-        Map.empty,
         None) :: Nil

     case _ => Nil
@@ -354,36 +353,10 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
     val (unhandledPredicates, pushedFilters, handledFilters) =
       selectFilters(relation.relation, candidatePredicates)

-    // A set of column attributes that are only referenced by pushed down filters. We can eliminate
-    // them from requested columns.
-    val handledSet = {
-      val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
-      val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
-      AttributeSet(handledPredicates.flatMap(_.references)) --
-        (projectSet ++ unhandledSet).map(relation.attributeMap)
-    }
-
     // Combines all Catalyst filter `Expression`s that are either not convertible to data source
     // `Filter`s or cannot be handled by `relation`.
     val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And)

-    // These metadata values make scan plans uniquely identifiable for equality checking.
-    // TODO(SPARK-17701) using strings for equality checking is brittle
-    val metadata: Map[String, String] = {
-      val pairs = ArrayBuffer.empty[(String, String)]
-
-      // Mark filters which are handled by the underlying DataSource with an Astrisk
-      if (pushedFilters.nonEmpty) {
-        val markedFilters = for (filter <- pushedFilters) yield {
-          if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
-        }
-        pairs += ("PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
-      }
-      pairs += ("ReadSchema" ->
-        StructType.fromAttributes(projects.map(_.toAttribute)).catalogString)
-      pairs.toMap
-    }
-
     if (projects.map(_.toAttribute) == projects &&
         projectSet.size == projects.size &&
         filterSet.subsetOf(projectSet)) {
@@ -395,24 +368,36 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
         .asInstanceOf[Seq[Attribute]]
         // Match original case of attributes.
         .map(relation.attributeMap)
-        // Don't request columns that are only referenced by pushed filters.
-        .filterNot(handledSet.contains)

       val scan = RowDataSourceScanExec(
-        projects.map(_.toAttribute),
+        relation.output,
+        requestedColumns.map(relation.output.indexOf),
+        pushedFilters.toSet,
+        handledFilters,
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-        relation.relation, UnknownPartitioning(0), metadata,
+        relation.relation,
         relation.catalogTable.map(_.identifier))
       filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
     } else {
+      // A set of column attributes that are only referenced by pushed down filters. We can
+      // eliminate them from requested columns.
+      val handledSet = {
+        val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
+        val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
+        AttributeSet(handledPredicates.flatMap(_.references)) --
+          (projectSet ++ unhandledSet).map(relation.attributeMap)
+      }
       // Don't request columns that are only referenced by pushed filters.
       val requestedColumns =
         (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq

       val scan = RowDataSourceScanExec(
-        requestedColumns,
+        relation.output,
+        requestedColumns.map(relation.output.indexOf),
+        pushedFilters.toSet,
+        handledFilters,
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-        relation.relation, UnknownPartitioning(0), metadata,
+        relation.relation,
         relation.catalogTable.map(_.identifier))
       execution.ProjectExec(
         projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 9d4ebcce4d..884f945815 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -113,7 +113,7 @@ object SparkPlanGraph {
       }
       val node = new SparkPlanGraphNode(
         nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-        planInfo.simpleString, planInfo.metadata, metrics)
+        planInfo.simpleString, metrics)
       if (subgraph == null) {
         nodes += node
       } else {
@@ -143,7 +143,6 @@ private[ui] class SparkPlanGraphNode(
     val id: Long,
     val name: String,
     val desc: String,
-    val metadata: Map[String, String],
     val metrics: Seq[SQLPlanMetric]) {

   def makeDotNode(metricsValue: Map[Long, String]): String = {
@@ -177,7 +176,7 @@ private[ui] class SparkPlanGraphCluster(
     desc: String,
     val nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
     metrics: Seq[SQLPlanMetric])
-  extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) {
+  extends SparkPlanGraphNode(id, name, desc, metrics) {

   override def makeDotNode(metricsValue: Map[Long, String]): String = {
     val duration = metrics.filter(_.name.startsWith(WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
--
GitLab
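To see why comparing a rendered metadata map is weaker than comparing canonicalized constructor parameters, here is a minimal, self-contained Scala sketch. It uses no Spark classes; `Scan`, `runtimeHandle`, and the two `sameResultBy*` helpers are illustrative stand-ins that only mirror the shape of the change above (null out the incomparable runtime state, then rely on case-class equality over the remaining parameters).

```scala
// Illustrative sketch only: a toy model of the two equality strategies, not Spark code.
object SameResultSketch {

  // A scan is described by the relation's full output, the indices of the columns it
  // actually returns, and the filters pushed down to the source.
  case class Scan(
      fullOutput: Seq[String],
      requiredColumnsIndex: Seq[Int],
      filters: Set[String],
      handledFilters: Set[String],
      // Opaque runtime state (the RDD in Spark) that can never be compared meaningfully.
      runtimeHandle: AnyRef) {

    def output: Seq[String] = requiredColumnsIndex.map(fullOutput)

    // Old style: render a human-readable string map and compare the strings.
    def metadata: Map[String, String] = Map(
      "ReadSchema" -> output.mkString("struct<", ",", ">"),
      "PushedFilters" -> filters.map(f => if (handledFilters(f)) s"*$f" else f)
        .mkString("[", ", ", "]"))

    // New style: drop the incomparable runtime state, then compare the remaining
    // constructor parameters via ordinary case-class equality.
    def canonicalized: Scan = copy(runtimeHandle = null)
  }

  def sameResultByMetadata(a: Scan, b: Scan): Boolean = a.metadata == b.metadata
  def sameResultByParams(a: Scan, b: Scan): Boolean = a.canonicalized == b.canonicalized

  def main(args: Array[String]): Unit = {
    // Both scans return a column named "a" with the same pushed filter, but they are
    // built over different schemas and column positions.
    val s1 = Scan(Seq("a", "b"), Seq(0), Set("a > 1"), Set("a > 1"), new Object)
    val s2 = Scan(Seq("b", "a"), Seq(1), Set("a > 1"), Set("a > 1"), new Object)

    println(s"metadata strings equal:         ${sameResultByMetadata(s1, s2)}") // true
    println(s"canonicalized parameters equal: ${sameResultByParams(s1, s2)}")   // false
  }
}
```

In this toy setup the two scans render identical `ReadSchema`/`PushedFilters` strings even though they are built from different parameters, so a string-based `sameResult` would treat them as interchangeable, while the parameter-based comparison does not; that is the kind of brittleness the patch removes for `RowDataSourceScanExec`.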