diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index edbde5d10b47c130ef6c8955f50fe1038633461b..0407cf6a1edbc77b3e056cb6f1ae7812accecb4b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive
 
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.util.concurrent.Striped
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -32,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
 import org.apache.spark.sql.hive.orc.OrcFileFormat
 import org.apache.spark.sql.types._
 
-
 /**
  * Legacy catalog for interacting with the Hive metastore.
  *
@@ -53,6 +53,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       tableIdent.table.toLowerCase)
   }
 
+  /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
+  private val tableCreationLocks = Striped.lazyWeakLock(100)
+
+  /** Acquires a lock on the table cache for the duration of `f`. */
+  private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = {
+    val lock = tableCreationLocks.get(tableName)
+    lock.lock()
+    try f finally {
+      lock.unlock()
+    }
+  }
+
   /** A cache of Spark SQL data source tables that have been accessed. */
   protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = {
     val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
@@ -209,72 +221,76 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         }
       }
 
-      val cached = getCached(
-        tableIdentifier,
-        rootPaths,
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        Some(partitionSchema))
-
-      val logicalRelation = cached.getOrElse {
-        val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
-        val fileCatalog = {
-          val catalog = new CatalogFileIndex(
-            sparkSession, metastoreRelation.catalogTable, sizeInBytes)
-          if (lazyPruningEnabled) {
-            catalog
-          } else {
-            catalog.filterPartitions(Nil) // materialize all the partitions in memory
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(
+          tableIdentifier,
+          rootPaths,
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          Some(partitionSchema))
+
+        val logicalRelation = cached.getOrElse {
+          val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong
+          val fileCatalog = {
+            val catalog = new CatalogFileIndex(
+              sparkSession, metastoreRelation.catalogTable, sizeInBytes)
+            if (lazyPruningEnabled) {
+              catalog
+            } else {
+              catalog.filterPartitions(Nil) // materialize all the partitions in memory
+            }
           }
+          val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
+          val dataSchema =
+            StructType(metastoreSchema
+              .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
+
+          val relation = HadoopFsRelation(
+            location = fileCatalog,
+            partitionSchema = partitionSchema,
+            dataSchema = dataSchema,
+            bucketSpec = bucketSpec,
+            fileFormat = defaultSource,
+            options = options)(sparkSession = sparkSession)
+
+          val created = LogicalRelation(relation,
+            catalogTable = Some(metastoreRelation.catalogTable))
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
         }
-        val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
-        val dataSchema =
-          StructType(metastoreSchema
-            .filterNot(field => partitionSchemaColumnNames.contains(field.name.toLowerCase)))
-
-        val relation = HadoopFsRelation(
-          location = fileCatalog,
-          partitionSchema = partitionSchema,
-          dataSchema = dataSchema,
-          bucketSpec = bucketSpec,
-          fileFormat = defaultSource,
-          options = options)(sparkSession = sparkSession)
-
-        val created = LogicalRelation(relation, catalogTable = Some(metastoreRelation.catalogTable))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
 
-      logicalRelation
+        logicalRelation
+      })
     } else {
       val rootPath = metastoreRelation.hiveQlTable.getDataLocation
-
-      val cached = getCached(tableIdentifier,
-        Seq(rootPath),
-        metastoreRelation,
-        metastoreSchema,
-        fileFormatClass,
-        bucketSpec,
-        None)
-      val logicalRelation = cached.getOrElse {
-        val created =
-          LogicalRelation(
-            DataSource(
-              sparkSession = sparkSession,
-              paths = rootPath.toString :: Nil,
-              userSpecifiedSchema = Some(metastoreRelation.schema),
-              bucketSpec = bucketSpec,
-              options = options,
-              className = fileType).resolveRelation(),
+      withTableCreationLock(tableIdentifier, {
+        val cached = getCached(tableIdentifier,
+          Seq(rootPath),
+          metastoreRelation,
+          metastoreSchema,
+          fileFormatClass,
+          bucketSpec,
+          None)
+        val logicalRelation = cached.getOrElse {
+          val created =
+            LogicalRelation(
+              DataSource(
+                sparkSession = sparkSession,
+                paths = rootPath.toString :: Nil,
+                userSpecifiedSchema = Some(metastoreRelation.schema),
+                bucketSpec = bucketSpec,
+                options = options,
+                className = fileType).resolveRelation(),
             catalogTable = Some(metastoreRelation.catalogTable))
-        cachedDataSourceTables.put(tableIdentifier, created)
-        created
-      }
+          cachedDataSourceTables.put(tableIdentifier, created)
+          created
+        }
 
-      logicalRelation
+        logicalRelation
+      })
     }
     result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
index 65c02d473b79d2dad1d35171dbd7ff045ccccb7e..55b72c625db41886b54a8f9aaa9b3bc5d7b9c02c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive
 
 import java.io.File
+import java.util.concurrent.{Executors, TimeUnit}
 
 import org.scalatest.BeforeAndAfterEach
 
@@ -395,4 +396,34 @@ class PartitionedTablePerfStatsSuite
       }
     }
   }
+
+  test("SPARK-18700: table loaded only once even when resolved concurrently") {
+    withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
+      withTable("test") {
+        withTempDir { dir =>
+          HiveCatalogMetrics.reset()
+          setupPartitionedHiveTable("test", dir, 50)
+          // select the table from multiple threads
+          val executorPool = Executors.newFixedThreadPool(10)
+          (1 to 10).map(threadId => {
+            val runnable = new Runnable {
+              override def run(): Unit = {
+                spark.sql("select * from test where partCol1 = 999").count()
+              }
+            }
+            executorPool.execute(runnable)
+            None
+          })
+          executorPool.shutdown()
+          executorPool.awaitTermination(30, TimeUnit.SECONDS)
+          // Check the cache hit via METRIC_FILES_DISCOVERED and METRIC_PARALLEL_LISTING_JOB_COUNT:
+          // with the lock in effect, only one thread actually builds the relation, so the
+          // parallel listing job count is 2 (one of them is the cache.load call) and
+          // METRIC_FILES_DISCOVERED is partition_num * 2.
+          assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 100)
+          assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 2)
+        }
+      }
+    }
+  }
 }
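For readers following the change, below is a minimal, self-contained sketch of the pattern the patch applies, outside of Spark: a Guava Striped lock keyed by the table name serializes the "check cache, build, put" sequence, so concurrent resolutions of the same key build the value only once and later threads simply hit the cache. This is an illustrative example, not Spark code; the object name GuardedCache, the buildCount counter, and the String key/value types are assumptions standing in for QualifiedTableName and the cached LogicalPlan.

// Sketch only: mirrors the striped-lock-around-cache pattern from the patch above.
// GuardedCache, buildCount, and the String key/value types are illustrative, not Spark APIs.
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.Lock

import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.util.concurrent.Striped

object GuardedCache {
  // Counts how many times the expensive build actually runs.
  private val buildCount = new AtomicInteger(0)

  // One lock per key, hashed onto 100 stripes; lazily created and weakly referenced.
  private val creationLocks: Striped[Lock] = Striped.lazyWeakLock(100)

  private val cache: Cache[String, String] =
    CacheBuilder.newBuilder().maximumSize(1000).build[String, String]()

  /** Runs `f` while holding the per-key lock, mirroring `withTableCreationLock`. */
  private def withCreationLock[A](key: String)(f: => A): A = {
    val lock = creationLocks.get(key)
    lock.lock()
    try f finally lock.unlock()
  }

  /** Returns the cached value, building it at most once even under concurrent calls. */
  def getOrBuild(key: String): String = withCreationLock(key) {
    Option(cache.getIfPresent(key)).getOrElse {
      buildCount.incrementAndGet() // stands in for the expensive relation build / file listing
      val created = s"relation-for-$key"
      cache.put(key, created)
      created
    }
  }

  def main(args: Array[String]): Unit = {
    // Resolve the same key from ten threads, as the new regression test does for a table.
    val pool = Executors.newFixedThreadPool(10)
    (1 to 10).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = getOrBuild("db.table")
      })
    }
    pool.shutdown()
    pool.awaitTermination(30, TimeUnit.SECONDS)
    // With the striped lock in place, the value is built exactly once.
    assert(buildCount.get() == 1)
  }
}

Without the per-key lock, several threads can miss the cache at the same time and each rebuild the value before putting it, which is exactly the wasted work (duplicate file listings and memory) that the patch avoids.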