diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3bf7382ac67a6b430ed30794507fca55c8da9f8c..5ab2b5316ab1022d2a167fbcc61262cbce9bd023 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel -import org.apache.spark.storage.StorageLevel.MEMORY_ONLY +import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -74,10 +74,14 @@ private[sql] trait CacheManager { cachedData.clear() } - /** Caches the data produced by the logical representation of the given schema rdd. */ + /** + * Caches the data produced by the logical representation of the given schema rdd. Unlike + * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing + * the in-memory columnar representation of the underlying table is expensive. + */ private[sql] def cacheQuery( query: SchemaRDD, - storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 4f79173a26f886c82138ad2a1555d0cdb1431a1c..22ab0e2613f21913dde28184b2912ef49ce35684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -38,7 +38,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } -private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) - CachedBatch(columnBuilders.map(_.build()), stats) + CachedBatch(columnBuilders.map(_.build().array()), stats) } def hasNext = rowIterator.hasNext @@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan( def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = - requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + val columnAccessors = requestedColumnIndices.map { batch => + ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + } // Extract rows via column accessors new Iterator[Row] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c87ded81fdc2786ded030ff8e9dd5ec7e573b268..444bc95009c3182c70f30a61fd077b0dd13dee7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") - cacheTable("bigData") - assert(table("bigData").count() === 1000000L) - uncacheTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist() } test("calling .cache() should use in-memory columnar caching") {