diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index 2e7abac1f1bdb1c8cf2d06a4ebfa1a8a520f093f..3c9439b2e9a5225876ccc5f7b405fd7c1fde7e40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -46,7 +46,7 @@ private[sql] trait CacheManager {
   def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty
 
   /** Caches the specified table in-memory. */
-  def cacheTable(tableName: String): Unit = cacheQuery(table(tableName))
+  def cacheTable(tableName: String): Unit = cacheQuery(table(tableName), Some(tableName))
 
   /** Removes the specified table from the in-memory cache. */
   def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName))
@@ -81,6 +81,7 @@ private[sql] trait CacheManager {
    */
   private[sql] def cacheQuery(
       query: SchemaRDD,
+      tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     if (lookupCachedData(planToCache).nonEmpty) {
@@ -90,7 +91,11 @@ private[sql] trait CacheManager {
       CachedData(
         planToCache,
         InMemoryRelation(
-          useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan))
+          useCompression,
+          columnBatchSize,
+          storageLevel,
+          query.queryExecution.executedPlan,
+          tableName))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index fbec2f9f4b2c137298a0d3ead800e15e870e7ad3..904a276ef3ffba8ae28cac4381ee80f2ef4e95f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -475,7 +475,7 @@ class SchemaRDD(
   }
 
   override def persist(newLevel: StorageLevel): this.type = {
-    sqlContext.cacheQuery(this, newLevel)
+    sqlContext.cacheQuery(this, None, newLevel)
     this
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 881d32b105c5f14c7ba2499c1877332b133c9cf9..0cebe823b27073a19799b0015bdbc13b4f8365d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -36,8 +36,9 @@ private[sql] object InMemoryRelation {
       useCompression: Boolean,
       batchSize: Int,
       storageLevel: StorageLevel,
-      child: SparkPlan): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)()
+      child: SparkPlan,
+      tableName: Option[String]): InMemoryRelation =
+    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
 }
 
 private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row)
@@ -47,7 +48,8 @@ private[sql] case class InMemoryRelation(
     useCompression: Boolean,
     batchSize: Int,
     storageLevel: StorageLevel,
-    child: SparkPlan)(
+    child: SparkPlan,
+    tableName: Option[String])(
     private var _cachedColumnBuffers: RDD[CachedBatch] = null,
     private var _statistics: Statistics = null)
   extends LogicalPlan with MultiInstanceRelation {
@@ -137,13 +139,13 @@ private[sql] case class InMemoryRelation(
       }
     }.persist(storageLevel)
 
-    cached.setName(child.toString)
+    cached.setName(tableName.map(n => s"In-memory table $n").getOrElse(child.toString))
     _cachedColumnBuffers = cached
   }
 
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
-      newOutput, useCompression, batchSize, storageLevel, child)(
+      newOutput, useCompression, batchSize, storageLevel, child, tableName)(
       _cachedColumnBuffers, statisticsToBePropagated)
   }
 
@@ -155,7 +157,8 @@ private[sql] case class InMemoryRelation(
       useCompression,
       batchSize,
       storageLevel,
-      child)(
+      child,
+      tableName)(
       _cachedColumnBuffers,
       statisticsToBePropagated).asInstanceOf[this.type]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index f23b9c48cfb40ee9bfe25fb87386d6632e27ba78..afe3f3f07440cd42a15115cc54518a19b1aea335 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -160,12 +160,11 @@ case class CacheTableCommand(
     import sqlContext._
 
     plan.foreach(_.registerTempTable(tableName))
-    val schemaRDD = table(tableName)
-    schemaRDD.cache()
+    cacheTable(tableName)
 
     if (!isLazy) {
       // Performs eager caching
-      schemaRDD.count()
+      table(tableName).count()
     }
 
     Seq.empty[Row]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 765fa827763410c631e5d0e23bbfe6e678a2e2f9..042210176ad7e64f26ccfaaea6949e82a91c2b44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -123,7 +123,7 @@ class CachedTableSuite extends QueryTest {
     cacheTable("testData")
     assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
       table("testData").queryExecution.withCachedData.collect {
-        case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => r
+        case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r
       }.size
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 15903d07df29a5a2b287d59f3db21ab8f6e59878..fc95dccc74e27e19a3d7520165d99c3c99aff877 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -29,7 +29,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
 
   test("simple columnar query") {
     val plan = executePlan(testData.logicalPlan).executedPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().toSeq)
   }
@@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
 
   test("projection") {
     val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().map {
       case Row(key: Int, value: String) => value -> key
@@ -53,7 +53,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
 
   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
     val plan = executePlan(testData.logicalPlan).executedPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
 
     checkAnswer(scan, testData.collect().toSeq)
     checkAnswer(scan, testData.collect().toSeq)
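
Note (not part of the patch): a minimal sketch of the observable effect, assuming the Spark 1.x SchemaRDD API. The table name "people", the case class Person, and the object name are hypothetical. With tableName threaded through cacheQuery, cacheTable names the cached RDD "In-memory table people", which is what appears in the web UI's Storage tab; caching through SchemaRDD.persist()/cache() still passes None and falls back to the old child.toString name.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical schema used only for this sketch.
case class Person(name: String, age: Int)

object CacheNamingDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CacheNamingDemo").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._  // brings createSchemaRDD, cacheTable, table, ... into scope

    sc.parallelize(Seq(Person("Alice", 30), Person("Bob", 25))).registerTempTable("people")

    // cacheTable() now forwards Some("people") to cacheQuery(), so the cached
    // RDD is named "In-memory table people" in the Storage tab instead of a plan dump.
    cacheTable("people")
    table("people").count()  // materialize eagerly so the entry shows up

    sc.stop()
  }
}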