Commit fc1b2566 authored by xuanyuanking, committed by Herman van Hovell

[SPARK-18700][SQL] Add StripedLock for each table's relation in cache

## What changes were proposed in this pull request?

As described in [SPARK-18700](https://issues.apache.org/jira/browse/SPARK-18700), when `cachedDataSourceTables` is invalidated, the next few queries each fetch every `FileStatus` in the `listLeafFiles` function. For a table with many partitions, these concurrent listing jobs occupy a large amount of driver memory and can eventually cause a driver OOM.

This patch adds a striped lock per table relation in the cache, rather than a single lock over the whole `cachedDataSourceTables`; each table's cache-load operation is protected by its own lock.
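For readers unfamiliar with Guava's `Striped`, here is a minimal standalone sketch of the mechanism (the key type and the `withKeyLock` helper are illustrative, not part of the patch): `Striped.lazyWeakLock(n)` hashes each key onto one of at most `n` lazily created, weakly referenced locks, so per-table mutual exclusion costs only bounded memory.

```scala
import com.google.common.util.concurrent.Striped

object StripedLockSketch {
  // At most 100 locks exist, created lazily and weakly referenced;
  // every key hashes onto one of them.
  private val locks = Striped.lazyWeakLock(100)

  // Illustrative helper mirroring the shape of withTableCreationLock below.
  def withKeyLock[A](key: String)(f: => A): A = {
    val lock = locks.get(key)
    lock.lock()
    try f finally lock.unlock()
  }

  def main(args: Array[String]): Unit = {
    // Threads contending on the same key run f serially; distinct keys
    // usually proceed in parallel (unless their hashes collide on a stripe).
    val threads = (1 to 4).map { i =>
      new Thread(() => withKeyLock("db.table") { println(s"loader $i ran") })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```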

## How was this patch tested?

Added a multi-threaded table-access test in `PartitionedTablePerfStatsSuite`, using the metrics in `HiveCatalogMetrics` to check that the table is loaded only once.
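The shape of that test, distilled into a self-contained sketch: `loads` and `lookup` here are illustrative stand-ins for the `HiveCatalogMetrics` counters and the guarded cache load, not code from this patch.

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object LoadOnceCheck {
  // Stand-in for HiveCatalogMetrics: counts how many expensive loads ran.
  private val loads = new AtomicInteger(0)
  private var cached: Option[String] = None

  // Stand-in for the guarded cache load: check, then build at most once.
  private def lookup(): String = synchronized {
    cached.getOrElse {
      loads.incrementAndGet() // the expensive listLeafFiles-style work
      val built = "relation"
      cached = Some(built)
      built
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(10)
    (1 to 10).foreach(_ => pool.execute(() => lookup()))
    pool.shutdown()
    pool.awaitTermination(30, TimeUnit.SECONDS)
    assert(loads.get() == 1, s"expected one load, saw ${loads.get()}")
  }
}
```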

Author: xuanyuanking <xyliyuanjian@gmail.com>

Closes #16135 from xuanyuanking/SPARK-18700.

(cherry picked from commit 24482858)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
parent 3080f995
HiveMetastoreCatalog.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive
 
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.util.concurrent.Striped
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -32,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
 import org.apache.spark.sql.hive.orc.OrcFileFormat
 import org.apache.spark.sql.types._
 
-
 /**
  * Legacy catalog for interacting with the Hive metastore.
  *
@@ -53,6 +53,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       tableIdent.table.toLowerCase)
   }
 
+  /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
+  private val tableCreationLocks = Striped.lazyWeakLock(100)
+
+  /** Acquires a lock on the table cache for the duration of `f`. */
+  private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = {
+    val lock = tableCreationLocks.get(tableName)
+    lock.lock()
+    try f finally {
+      lock.unlock()
+    }
+  }
+
   /** A cache of Spark SQL data source tables that have been accessed. */
   protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = {
     val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
@@ -209,72 +221,76 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         }
       }
 
-      val cached = getCached(
-        tableIdentifier,
-        rootPaths,
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        Some(partitionSchema))
-      val logicalRelation = cached.getOrElse {
-        val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
-        val fileCatalog = {
-          val catalog = new CatalogFileIndex(
-            sparkSession, metastoreRelation.catalogTable, sizeInBytes)
-          if (lazyPruningEnabled) {
-            catalog
-          } else {
-            catalog.filterPartitions(Nil) // materialize all the partitions in memory
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(
+          tableIdentifier,
+          rootPaths,
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          Some(partitionSchema))
+        val logicalRelation = cached.getOrElse {
+          val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
+          val fileCatalog = {
+            val catalog = new CatalogFileIndex(
+              sparkSession, metastoreRelation.catalogTable, sizeInBytes)
+            if (lazyPruningEnabled) {
+              catalog
+            } else {
+              catalog.filterPartitions(Nil) // materialize all the partitions in memory
+            }
           }
+          val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
+          val dataSchema =
+            StructType(metastoreSchema
+              .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
+          val relation = HadoopFsRelation(
+            location = fileCatalog,
+            partitionSchema = partitionSchema,
+            dataSchema = dataSchema,
+            bucketSpec = bucketSpec,
+            fileFormat = defaultSource,
+            options = options)(sparkSession = sparkSession)
+          val created = LogicalRelation(relation,
+            catalogTable = Some(metastoreRelation.catalogTable))
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
         }
-        val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
-        val dataSchema =
-          StructType(metastoreSchema
-            .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
-        val relation = HadoopFsRelation(
-          location = fileCatalog,
-          partitionSchema = partitionSchema,
-          dataSchema = dataSchema,
-          bucketSpec = bucketSpec,
-          fileFormat = defaultSource,
-          options = options)(sparkSession = sparkSession)
-        val created = LogicalRelation(relation, catalogTable = Some(metastoreRelation.catalogTable))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
-      logicalRelation
+        logicalRelation
+      })
     } else {
       val rootPath = metastoreRelation.hiveQlTable.getDataLocation
-      val cached = getCached(tableIdentifier,
-        Seq(rootPath),
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        None)
-      val logicalRelation = cached.getOrElse {
-        val created =
-          LogicalRelation(
-            DataSource(
-              sparkSession = sparkSession,
-              paths = rootPath.toString :: Nil,
-              userSpecifiedSchema = Some(metastoreRelation.schema),
-              bucketSpec = bucketSpec,
-              options = options,
-              className = fileType).resolveRelation(),
-            catalogTable = Some(metastoreRelation.catalogTable))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
-      logicalRelation
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(tableIdentifier,
+          Seq(rootPath),
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          None)
+        val logicalRelation = cached.getOrElse {
+          val created =
+            LogicalRelation(
+              DataSource(
+                sparkSession = sparkSession,
+                paths = rootPath.toString :: Nil,
+                userSpecifiedSchema = Some(metastoreRelation.schema),
+                bucketSpec = bucketSpec,
+                options = options,
+                className = fileType).resolveRelation(),
+              catalogTable = Some(metastoreRelation.catalogTable))
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
+        logicalRelation
+      })
     }
     result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
   }
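Putting the hunks above together, the change is essentially a check-then-build pattern under a per-key lock. A minimal sketch of that pattern follows; `RelationCache`, `build`, and the `ConcurrentHashMap` are illustrative stand-ins for `cachedDataSourceTables`, `getCached`, and the Guava `LoadingCache`, not the actual Spark classes:

```scala
import java.util.concurrent.ConcurrentHashMap
import com.google.common.util.concurrent.Striped

// Check-then-build under a per-key lock: the first thread to request a key
// pays for build(key); threads blocked on the same key then hit the cache.
class RelationCache[K <: AnyRef, V <: AnyRef](stripes: Int, build: K => V) {
  private val cache = new ConcurrentHashMap[K, V]()
  private val locks = Striped.lazyWeakLock(stripes)

  def get(key: K): V = {
    val lock = locks.get(key)
    lock.lock()
    try {
      Option(cache.get(key)).getOrElse {
        val created = build(key) // expensive: e.g. listing all partition files
        cache.put(key, created)
        created
      }
    } finally lock.unlock()
  }
}
```

Under this pattern, only the first thread for a given table pays the file-listing cost; the other threads in the test below block briefly and then read the cached relation.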
PartitionedTablePerfStatsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive
 
 import java.io.File
+import java.util.concurrent.{Executors, TimeUnit}
 
 import org.scalatest.BeforeAndAfterEach
 
@@ -395,4 +396,34 @@ class PartitionedTablePerfStatsSuite
       }
     }
   }
+
+  test("SPARK-18700: table loaded only once even when resolved concurrently") {
+    withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
+      withTable("test") {
+        withTempDir { dir =>
+          HiveCatalogMetrics.reset()
+          setupPartitionedHiveTable("test", dir, 50)
+          // Resolve the table from multiple threads at once.
+          val executorPool = Executors.newFixedThreadPool(10)
+          (1 to 10).map(threadId => {
+            val runnable = new Runnable {
+              override def run(): Unit = {
+                spark.sql("select * from test where partCol1 = 999").count()
+              }
+            }
+            executorPool.execute(runnable)
+            None
+          })
+          executorPool.shutdown()
+          executorPool.awaitTermination(30, TimeUnit.SECONDS)
+          // Check the cache hit via METRIC_FILES_DISCOVERED and
+          // METRIC_PARALLEL_LISTING_JOB_COUNT. With the lock in effect, only one
+          // thread actually builds the relation, so only two listing jobs run (the
+          // second is the cache's load function); files discovered is partition_num * 2.
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 100)
+          assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 2)
+        }
+      }
+    }
+  }
 }