Commit fc1b2566 authored by xuanyuanking, committed by Herman van Hovell

[SPARK-18700][SQL] Add StripedLock for each table's relation in cache

## What changes were proposed in this pull request?

As described in [SPARK-18700](https://issues.apache.org/jira/browse/SPARK-18700), when the entry in `cachedDataSourceTables` is invalidated, the next few queries each fetch all `FileStatus` objects in the `listLeafFiles` function. For a table with many partitions, these concurrent listing jobs can occupy a large amount of driver memory and may eventually cause a driver OOM.

This patch adds a striped lock keyed by table, rather than a single lock over the whole `cachedDataSourceTables`, so that each table's cache-load operation is protected independently: concurrent lookups of the same table serialize on one lock, while lookups of different tables proceed in parallel.
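
For readers unfamiliar with the pattern, here is a minimal standalone sketch of per-key striped locking around a check-then-build-then-put cache, assuming Guava is on the classpath. The table name, the `ConcurrentHashMap` cache, and the `expensiveBuild` helper are hypothetical stand-ins for illustration, not the actual `HiveMetastoreCatalog` code:

```scala
import java.util.concurrent.ConcurrentHashMap

import com.google.common.util.concurrent.Striped

object StripedLockSketch {
  // Bounded pool of lazily created, weakly referenced locks. Lookups of the same
  // key always map to the same lock, so concurrent builds of one table serialize;
  // distinct tables may share a stripe, which is harmless for correctness.
  private val tableCreationLocks = Striped.lazyWeakLock(100)

  private def withTableCreationLock[A](table: String)(f: => A): A = {
    val lock = tableCreationLocks.get(table)
    lock.lock()
    try f finally lock.unlock()
  }

  // Hypothetical cache plus an "expensive" build standing in for listing all
  // partition files of a table.
  private val cache = new ConcurrentHashMap[String, Seq[String]]()

  private def expensiveBuild(table: String): Seq[String] = {
    println(s"building relation for $table") // should print once per table
    Seq(s"$table/part=0", s"$table/part=1")
  }

  // check-then-build-then-put: without the per-table lock, every concurrent
  // caller that sees a cache miss would run expensiveBuild.
  def lookup(table: String): Seq[String] = withTableCreationLock(table) {
    Option(cache.get(table)).getOrElse {
      val built = expensiveBuild(table)
      cache.put(table, built)
      built
    }
  }

  def main(args: Array[String]): Unit = {
    val threads = (1 to 10).map { _ =>
      new Thread(new Runnable { override def run(): Unit = lookup("db.sales") })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```

Running `main` should print the build message only once even though ten threads race on the same key; `Striped.lazyWeakLock(100)` keeps the memory footprint bounded because the locks are created lazily and held weakly.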

## How was this patch tested?

Added a test in `PartitionedTablePerfStatsSuite` that accesses a table from multiple threads and uses the metrics in `HiveCatalogMetrics` to check that the table is loaded only once.

Author: xuanyuanking <xyliyuanjian@gmail.com>

Closes #16135 from xuanyuanking/SPARK-18700.

(cherry picked from commit 24482858)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
parent 3080f995
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import com.google.common.util.concurrent.Striped
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
@@ -32,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
import org.apache.spark.sql.hive.orc.OrcFileFormat
import org.apache.spark.sql.types._
/**
 * Legacy catalog for interacting with the Hive metastore.
 *
@@ -53,6 +53,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
      tableIdent.table.toLowerCase)
  }
  /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
  private val tableCreationLocks = Striped.lazyWeakLock(100)

  /** Acquires a lock on the table cache for the duration of `f`. */
  private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = {
    val lock = tableCreationLocks.get(tableName)
    lock.lock()
    try f finally {
      lock.unlock()
    }
  }
  /** A cache of Spark SQL data source tables that have been accessed. */
  protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = {
    val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
@@ -209,72 +221,76 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
        }
      }
      withTableCreationLock(tableIdentifier, {
        val cached = getCached(
          tableIdentifier,
          rootPaths,
          metastoreRelation,
          metastoreSchema,
          fileFormatClass,
          bucketSpec,
          Some(partitionSchema))

        val logicalRelation = cached.getOrElse {
          val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
          val fileCatalog = {
            val catalog = new CatalogFileIndex(
              sparkSession, metastoreRelation.catalogTable, sizeInBytes)
            if (lazyPruningEnabled) {
              catalog
            } else {
              catalog.filterPartitions(Nil) // materialize all the partitions in memory
            }
          }
          val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
          val dataSchema =
            StructType(metastoreSchema
              .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))

          val relation = HadoopFsRelation(
            location = fileCatalog,
            partitionSchema = partitionSchema,
            dataSchema = dataSchema,
            bucketSpec = bucketSpec,
            fileFormat = defaultSource,
            options = options)(sparkSession = sparkSession)

          val created = LogicalRelation(relation,
            catalogTable = Some(metastoreRelation.catalogTable))
          cachedDataSourceTables.put(tableIdentifier, created)
          created
        }

        logicalRelation
      })
    } else {
      val rootPath = metastoreRelation.hiveQlTable.getDataLocation
      withTableCreationLock(tableIdentifier, {
        val cached = getCached(tableIdentifier,
          Seq(rootPath),
          metastoreRelation,
          metastoreSchema,
          fileFormatClass,
          bucketSpec,
          None)
        val logicalRelation = cached.getOrElse {
          val created =
            LogicalRelation(
              DataSource(
                sparkSession = sparkSession,
                paths = rootPath.toString :: Nil,
                userSpecifiedSchema = Some(metastoreRelation.schema),
                bucketSpec = bucketSpec,
                options = options,
                className = fileType).resolveRelation(),
              catalogTable = Some(metastoreRelation.catalogTable))
          cachedDataSourceTables.put(tableIdentifier, created)
          created
        }
        logicalRelation
      })
    }
    result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
  }
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import java.io.File
import java.util.concurrent.{Executors, TimeUnit}

import org.scalatest.BeforeAndAfterEach
@@ -395,4 +396,34 @@ class PartitionedTablePerfStatsSuite
      }
    }
  }
test("SPARK-18700: table loaded only once even when resolved concurrently") {
withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
withTable("test") {
withTempDir { dir =>
HiveCatalogMetrics.reset()
setupPartitionedHiveTable("test", dir, 50)
// select the table in multi-threads
val executorPool = Executors.newFixedThreadPool(10)
(1 to 10).map(threadId => {
val runnable = new Runnable {
override def run(): Unit = {
spark.sql("select * from test where partCol1 = 999").count()
}
}
executorPool.execute(runnable)
None
})
executorPool.shutdown()
executorPool.awaitTermination(30, TimeUnit.SECONDS)
// check the cache hit, we use the metric of METRIC_FILES_DISCOVERED and
// METRIC_PARALLEL_LISTING_JOB_COUNT to check this, while the lock take effect,
// only one thread can really do the build, so the listing job count is 2, the other
// one is cache.load func. Also METRIC_FILES_DISCOVERED is $partition_num * 2
assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 100)
assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 2)
}
}
}
}
} }