Commit ad3cc131 authored by Wenchen Fan

[SPARK-20245][SQL][MINOR] pass output to LogicalRelation directly

## What changes were proposed in this pull request?

Currently `LogicalRelation` has an `expectedOutputAttributes` parameter, which makes it hard to reason about what the actual output is. Like other leaf nodes, `LogicalRelation` should also take `output` as a parameter, to simplify the logic.
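A condensed sketch of the resulting API (simplified from the diff below; bodies and unrelated members elided):

```scala
// After this patch, LogicalRelation takes its output attributes explicitly,
// like other leaf nodes, instead of deriving them from an optional
// `expectedOutputAttributes` parameter.
case class LogicalRelation(
    relation: BaseRelation,
    output: Seq[AttributeReference],
    catalogTable: Option[CatalogTable])
  extends LeafNode with MultiInstanceRelation

// Convenience constructors derive the output from the relation's schema.
object LogicalRelation {
  def apply(relation: BaseRelation): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, None)

  def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation =
    LogicalRelation(relation, relation.schema.toAttributes, Some(table))
}
```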

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17552 from cloud-fan/minor.
parent 626b4caf
Showing with 49 additions and 55 deletions
@@ -27,7 +27,7 @@ import com.google.common.base.Objects
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
@@ -403,14 +403,14 @@ object CatalogTypes {
  */
 case class CatalogRelation(
     tableMeta: CatalogTable,
-    dataCols: Seq[Attribute],
-    partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
+    dataCols: Seq[AttributeReference],
+    partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation {
   assert(tableMeta.identifier.database.isDefined)
   assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType))
   assert(tableMeta.dataSchema.sameType(dataCols.toStructType))

   // The partition column should always appear after data columns.
-  override def output: Seq[Attribute] = dataCols ++ partitionCols
+  override def output: Seq[AttributeReference] = dataCols ++ partitionCols

   def isPartitioned: Boolean = partitionCols.nonEmpty
@@ -231,16 +231,17 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
             options = table.storage.properties ++ pathOption,
             catalogTable = Some(table))

-          LogicalRelation(
-            dataSource.resolveRelation(checkFilesExist = false),
-            catalogTable = Some(table))
+          LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
         }
       }).asInstanceOf[LogicalRelation]

-      // It's possible that the table schema is empty and need to be inferred at runtime. We should
-      // not specify expected outputs for this case.
-      val expectedOutputs = if (r.output.isEmpty) None else Some(r.output)
-      plan.copy(expectedOutputAttributes = expectedOutputs)
+      if (r.output.isEmpty) {
+        // It's possible that the table schema is empty and need to be inferred at runtime. For this
+        // case, we don't need to change the output of the cached plan.
+        plan
+      } else {
+        plan.copy(output = r.output)
+      }
     }

     override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
@@ -26,31 +26,13 @@ import org.apache.spark.util.Utils

 /**
  * Used to link a [[BaseRelation]] in to a logical query plan.
- *
- * Note that sometimes we need to use `LogicalRelation` to replace an existing leaf node without
- * changing the output attributes' IDs. The `expectedOutputAttributes` parameter is used for
- * this purpose. See https://issues.apache.org/jira/browse/SPARK-10741 for more details.
  */
 case class LogicalRelation(
     relation: BaseRelation,
-    expectedOutputAttributes: Option[Seq[Attribute]] = None,
-    catalogTable: Option[CatalogTable] = None)
+    output: Seq[AttributeReference],
+    catalogTable: Option[CatalogTable])
   extends LeafNode with MultiInstanceRelation {

-  override val output: Seq[AttributeReference] = {
-    val attrs = relation.schema.toAttributes
-    expectedOutputAttributes.map { expectedAttrs =>
-      assert(expectedAttrs.length == attrs.length)
-      attrs.zip(expectedAttrs).map {
-        // We should respect the attribute names provided by base relation and only use the
-        // exprId in `expectedOutputAttributes`.
-        // The reason is that, some relations(like parquet) will reconcile attribute names to
-        // workaround case insensitivity issue.
-        case (attr, expected) => attr.withExprId(expected.exprId)
-      }
-    }.getOrElse(attrs)
-  }
-
   // Logical Relations are distinct if they have different output for the sake of transformations.
   override def equals(other: Any): Boolean = other match {
     case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output
@@ -87,11 +69,8 @@ case class LogicalRelation(
    * unique expression ids. We respect the `expectedOutputAttributes` and create
    * new instances of attributes in it.
    */
-  override def newInstance(): this.type = {
-    LogicalRelation(
-      relation,
-      expectedOutputAttributes.map(_.map(_.newInstance())),
-      catalogTable).asInstanceOf[this.type]
+  override def newInstance(): LogicalRelation = {
+    this.copy(output = output.map(_.newInstance()))
   }

   override def refresh(): Unit = relation match {
@@ -101,3 +80,11 @@ case class LogicalRelation(
   override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
 }
+
+object LogicalRelation {
+  def apply(relation: BaseRelation): LogicalRelation =
+    LogicalRelation(relation, relation.schema.toAttributes, None)
+
+  def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation =
+    LogicalRelation(relation, relation.schema.toAttributes, Some(table))
+}
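Call sites that used `expectedOutputAttributes` to keep existing expression IDs now rewrite `output` directly, as the HiveMetastoreCatalog change further down does. A minimal sketch of that pattern (hypothetical names; it assumes a `LogicalRelation` named `result` and an existing plan `relation` whose outputs line up positionally):

```scala
// Align the exprIds of `result` with the attributes of `relation`,
// position by position, while keeping the resolved attribute names.
val newOutput = result.output.zip(relation.output).map {
  case (resolved, existing) => resolved.withExprId(existing.exprId)
}
val aligned = result.copy(output = newOutput)
```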
@@ -59,9 +59,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
           fsRelation.copy(location = prunedFileIndex)(sparkSession)
-        val prunedLogicalRelation = logicalRelation.copy(
-          relation = prunedFsRelation,
-          expectedOutputAttributes = Some(logicalRelation.output))
+        val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation)

         // Keep partition-pruning predicates so that they are visible in physical planning
         val filterExpression = filters.reduceLeft(And)
@@ -75,13 +75,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
           |USING ${classOf[TestOptionsSource].getCanonicalName}
           |OPTIONS (PATH '/tmp/path')
         """.stripMargin)
-      assert(getPathOption("src") == Some("file:/tmp/path"))
+      assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path")))
     }

     // should exist even path option is not specified when creating table
     withTable("src") {
       sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}")
-      assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src"))))
+      assert(getPathOption("src").map(makeQualifiedPath) == Some(defaultTablePath("src")))
     }
   }
@@ -95,9 +95,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
           |OPTIONS (PATH '$p')
           |AS SELECT 1
         """.stripMargin)
-      assert(CatalogUtils.stringToURI(
-        spark.table("src").schema.head.metadata.getString("path")) ==
-        makeQualifiedPath(p.getAbsolutePath))
+      assert(
+        spark.table("src").schema.head.metadata.getString("path") ==
+        p.getAbsolutePath)
     }
   }
@@ -109,8 +109,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
           |USING ${classOf[TestOptionsSource].getCanonicalName}
           |AS SELECT 1
         """.stripMargin)
-      assert(spark.table("src").schema.head.metadata.getString("path") ==
-        CatalogUtils.URIToString(defaultTablePath("src")))
+      assert(
+        makeQualifiedPath(spark.table("src").schema.head.metadata.getString("path")) ==
+        defaultTablePath("src"))
     }
   }
@@ -122,13 +123,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
           |USING ${classOf[TestOptionsSource].getCanonicalName}
           |OPTIONS (PATH '/tmp/path')""".stripMargin)
       sql("ALTER TABLE src SET LOCATION '/tmp/path2'")
-      assert(getPathOption("src") == Some("/tmp/path2"))
+      assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path2")))
     }

     withTable("src", "src2") {
       sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}")
       sql("ALTER TABLE src RENAME TO src2")
-      assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2"))))
+      assert(getPathOption("src2").map(makeQualifiedPath) == Some(defaultTablePath("src2")))
     }
   }
@@ -175,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
               bucketSpec = None,
               fileFormat = fileFormat,
               options = options)(sparkSession = sparkSession)
-            val created = LogicalRelation(fsRelation, catalogTable = Some(updatedTable))
+            val created = LogicalRelation(fsRelation, updatedTable)
             tableRelationCache.put(tableIdentifier, created)
             created
           }
@@ -203,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
                 bucketSpec = None,
                 options = options,
                 className = fileType).resolveRelation(),
-              catalogTable = Some(updatedTable))
+              table = updatedTable)
             tableRelationCache.put(tableIdentifier, created)
             created
@@ -212,7 +212,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         logicalRelation
       })
     }
-    result.copy(expectedOutputAttributes = Some(relation.output))
+    // The inferred schema may have different filed names as the table schema, we should respect
+    // it, but also respect the exprId in table relation output.
+    assert(result.output.length == relation.output.length &&
+      result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType })
+    val newOutput = result.output.zip(relation.output).map {
+      case (a1, a2) => a1.withExprId(a2.exprId)
+    }
+    result.copy(output = newOutput)
   }

   private def inferIfNeeded(
@@ -329,7 +329,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
         fileFormat = new ParquetFileFormat(),
         options = Map.empty)(sparkSession = spark)

-      val plan = LogicalRelation(relation, catalogTable = Some(tableMeta))
+      val plan = LogicalRelation(relation, tableMeta)
       spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan))

       assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined)
@@ -342,7 +342,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
         bucketSpec = None,
         fileFormat = new ParquetFileFormat(),
         options = Map.empty)(sparkSession = spark)

-      val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta))
+      val samePlan = LogicalRelation(sameRelation, tableMeta)

       assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined)
     }
@@ -58,7 +58,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
         fileFormat = new ParquetFileFormat(),
         options = Map.empty)(sparkSession = spark)

-      val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta))
+      val logicalRelation = LogicalRelation(relation, tableMeta)
       val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze

       val optimized = Optimize.execute(query)