diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 56bd5c1891e8d03df7d328cc83063240dfa58cda..479934a7afc75fd840baea4b0ad59799ef79cab6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.columnar
 
+import scala.collection.JavaConverters._
+
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.network.util.JavaUtils
@@ -29,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.LongAccumulator
+import org.apache.spark.util.CollectionAccumulator
 
 
 object InMemoryRelation {
@@ -61,7 +63,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
+    val batchStats: CollectionAccumulator[InternalRow] =
+      child.sqlContext.sparkContext.collectionAccumulator[InternalRow])
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
@@ -71,12 +74,21 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)
 
   override lazy val statistics: Statistics = {
-    if (batchStats.value == 0L) {
+    if (batchStats.value.isEmpty) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
       // available, return the default statistics.
       Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
     } else {
-      Statistics(sizeInBytes = batchStats.value.longValue)
+      // Underlying columnar RDD has been materialized, required information has also been
+      // collected via the `batchStats` accumulator.
+      val sizeOfRow: Expression =
+        BindReferences.bindReference(
+          output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
+          partitionStatistics.schema)
+
+      val sizeInBytes =
+        batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
+      Statistics(sizeInBytes = sizeInBytes)
     }
   }
 
@@ -127,10 +139,10 @@ case class InMemoryRelation(
           rowCount += 1
         }
 
-        batchStats.add(totalSize)
-
         val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
           .flatMap(_.values))
+
+        batchStats.add(stats)
         CachedBatch(rowCount, columnBuilders.map { builder =>
           JavaUtils.bufferToArray(builder.build())
         }, stats)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 0daa29b666f62158c3a4c8f654eaa31b6e18f8e6..937839644ad5f9a3086582b141948d9bb7217b6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -232,18 +232,4 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     val columnTypes2 = List.fill(length2)(IntegerType)
     val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2)
   }
-
-  test("SPARK-17549: cached table size should be correctly calculated") {
-    val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
-    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
-    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)
-
-    // Materialize the data.
-    val expectedAnswer = data.collect()
-    checkAnswer(cached, expectedAnswer)
-
-    // Check that the right size was calculated.
-    assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
-  }
-
 }
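For context on the accumulator swap above: a CollectionAccumulator merges one element per add() call into a list that is only readable on the driver, which is what lets each cached batch ship its stats row back for the size estimate. Below is a minimal, self-contained sketch of that mechanism, not part of the patch; the local SparkContext, the accumulator name, and the 4-bytes-per-Int cost model are illustrative assumptions only.

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}

object CollectionAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("acc-sketch"))
    try {
      // Analogous to `batchStats`: executors add one entry per partition
      // (per cached batch in InMemoryRelation), and the driver sees the merged list.
      val perPartitionSizes = sc.collectionAccumulator[Long]("perPartitionSizes")

      sc.parallelize(1 to 100, numSlices = 4).foreachPartition { iter =>
        // Illustrative cost model only: pretend each Int occupies 4 bytes.
        perPartitionSizes.add(iter.size.toLong * 4L)
      }

      // The driver-side value is a java.util.List; convert and aggregate, mirroring how
      // InMemoryRelation.statistics sums a per-batch size estimate over batchStats.value.
      val estimatedSize = perPartitionSizes.value.asScala.sum
      println(s"estimated size in bytes: $estimatedSize")
    } finally {
      sc.stop()
    }
  }
}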
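The test removed above asserted on batchStats.value directly, which no longer type-checks once the accumulator holds per-batch stats rows instead of a byte count. A hypothetical replacement check, not part of this change and only a sketch, could instead verify that materialization populates the accumulator and that the derived estimate is positive:

  test("cached table size is derived from collected batch statistics") {
    val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)

    // Materialize the cached data; building each CachedBatch adds its stats row to batchStats.
    checkAnswer(cached, data.collect())

    // One InternalRow of column stats is collected per cached batch (a java.util.List on the
    // driver), and the size estimate evaluated from those rows should be positive.
    assert(!cached.batchStats.value.isEmpty)
    assert(cached.statistics.sizeInBytes > 0)
  }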