Skip to content
Snippets Groups Projects
Commit 39e2bad6 authored by Marcelo Vanzin's avatar Marcelo Vanzin Committed by Yin Huai
Browse files

[SPARK-17549][SQL] Only collect table size stat in driver for cached relation.

The existing code caches all stats for all columns for each partition
in the driver; for a large relation, this causes extreme memory usage,
which leads to gc hell and application failures.

It seems that only the size in bytes of the data is actually used in the
driver, so instead just colllect that. In executors, the full stats are
still kept, but that's not a big problem; we expect the data to be distributed
and thus not really incur in too much memory pressure in each individual
executor.

There are also potential improvements on the executor side, since the data
being stored currently is very wasteful (e.g. storing boxed types vs.
primitive types for stats). But that's a separate issue.

On a mildly related change, I'm also adding code to catch exceptions in the
code generator since Janino was breaking with the test data I tried this
patch on.

Tested with unit tests and by doing a count a very wide table (20k columns)
with many partitions.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #15112 from vanzin/SPARK-17549.
parent b9323fc9
No related branches found
No related tags found
No related merge requests found
...@@ -23,6 +23,7 @@ import java.util.{Map => JavaMap} ...@@ -23,6 +23,7 @@ import java.util.{Map => JavaMap}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.collection.mutable import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
...@@ -910,14 +911,19 @@ object CodeGenerator extends Logging { ...@@ -910,14 +911,19 @@ object CodeGenerator extends Logging {
codeAttrField.setAccessible(true) codeAttrField.setAccessible(true)
classes.foreach { case (_, classBytes) => classes.foreach { case (_, classBytes) =>
CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length)
val cf = new ClassFile(new ByteArrayInputStream(classBytes)) try {
cf.methodInfos.asScala.foreach { method => val cf = new ClassFile(new ByteArrayInputStream(classBytes))
method.getAttributes().foreach { a => cf.methodInfos.asScala.foreach { method =>
if (a.getClass.getName == codeAttr.getName) { method.getAttributes().foreach { a =>
CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( if (a.getClass.getName == codeAttr.getName) {
codeAttrField.get(a).asInstanceOf[Array[Byte]].length) CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(
codeAttrField.get(a).asInstanceOf[Array[Byte]].length)
}
} }
} }
} catch {
case NonFatal(e) =>
logWarning("Error calculating stats of compiled class.", e)
} }
} }
} }
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.columnar package org.apache.spark.sql.execution.columnar
import scala.collection.JavaConverters._
import org.apache.commons.lang3.StringUtils import org.apache.commons.lang3.StringUtils
import org.apache.spark.network.util.JavaUtils import org.apache.spark.network.util.JavaUtils
...@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical ...@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.CollectionAccumulator import org.apache.spark.util.LongAccumulator
object InMemoryRelation { object InMemoryRelation {
...@@ -63,8 +61,7 @@ case class InMemoryRelation( ...@@ -63,8 +61,7 @@ case class InMemoryRelation(
@transient child: SparkPlan, @transient child: SparkPlan,
tableName: Option[String])( tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null, @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val batchStats: CollectionAccumulator[InternalRow] = val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
child.sqlContext.sparkContext.collectionAccumulator[InternalRow])
extends logical.LeafNode with MultiInstanceRelation { extends logical.LeafNode with MultiInstanceRelation {
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
...@@ -74,21 +71,12 @@ case class InMemoryRelation( ...@@ -74,21 +71,12 @@ case class InMemoryRelation(
@transient val partitionStatistics = new PartitionStatistics(output) @transient val partitionStatistics = new PartitionStatistics(output)
override lazy val statistics: Statistics = { override lazy val statistics: Statistics = {
if (batchStats.value.isEmpty) { if (batchStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, no useful statistics information // Underlying columnar RDD hasn't been materialized, no useful statistics information
// available, return the default statistics. // available, return the default statistics.
Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
} else { } else {
// Underlying columnar RDD has been materialized, required information has also been Statistics(sizeInBytes = batchStats.value.longValue)
// collected via the `batchStats` accumulator.
val sizeOfRow: Expression =
BindReferences.bindReference(
output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
partitionStatistics.schema)
val sizeInBytes =
batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
Statistics(sizeInBytes = sizeInBytes)
} }
} }
...@@ -139,10 +127,10 @@ case class InMemoryRelation( ...@@ -139,10 +127,10 @@ case class InMemoryRelation(
rowCount += 1 rowCount += 1
} }
batchStats.add(totalSize)
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
.flatMap(_.values)) .flatMap(_.values))
batchStats.add(stats)
CachedBatch(rowCount, columnBuilders.map { builder => CachedBatch(rowCount, columnBuilders.map { builder =>
JavaUtils.bufferToArray(builder.build()) JavaUtils.bufferToArray(builder.build())
}, stats) }, stats)
......
...@@ -232,4 +232,18 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { ...@@ -232,4 +232,18 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
val columnTypes2 = List.fill(length2)(IntegerType) val columnTypes2 = List.fill(length2)(IntegerType)
val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2)
} }
test("SPARK-17549: cached table size should be correctly calculated") {
val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)
// Materialize the data.
val expectedAnswer = data.collect()
checkAnswer(cached, expectedAnswer)
// Check that the right size was calculated.
assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment