diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f982c222af5f0f2c22c8f932a0bac3831e9e0d32..33b9b804fc6016f2c1e5ddb2a05ece21835edcda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -23,6 +23,7 @@ import java.util.{Map => JavaMap}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
@@ -910,14 +911,19 @@ object CodeGenerator extends Logging {
     codeAttrField.setAccessible(true)
     classes.foreach { case (_, classBytes) =>
       CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length)
-      val cf = new ClassFile(new ByteArrayInputStream(classBytes))
-      cf.methodInfos.asScala.foreach { method =>
-        method.getAttributes().foreach { a =>
-          if (a.getClass.getName == codeAttr.getName) {
-            CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(
-              codeAttrField.get(a).asInstanceOf[Array[Byte]].length)
+      try {
+        val cf = new ClassFile(new ByteArrayInputStream(classBytes))
+        cf.methodInfos.asScala.foreach { method =>
+          method.getAttributes().foreach { a =>
+            if (a.getClass.getName == codeAttr.getName) {
+              CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(
+                codeAttrField.get(a).asInstanceOf[Array[Byte]].length)
+            }
           }
         }
+      } catch {
+        case NonFatal(e) =>
+          logWarning("Error calculating stats of compiled class.", e)
       }
     }
   }
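
[Annotation, not part of the patch.] The hunk above makes codegen metrics collection best-effort: if Janino's ClassFile parser rejects the compiled bytes, the failure is logged instead of failing the whole compilation, while fatal errors still propagate because scala.util.control.NonFatal deliberately does not match them (VirtualMachineError, InterruptedException, ControlThrowable, and so on). A minimal REPL-style sketch of the pattern; the helper name `safely` and the printed message are illustrative only:

    import scala.util.control.NonFatal

    // Run best-effort work: log and swallow non-fatal failures,
    // let fatal ones (OutOfMemoryError etc.) propagate.
    def safely(work: => Unit): Unit = {
      try {
        work
      } catch {
        case NonFatal(e) => println(s"Error calculating stats of compiled class: $e")
      }
    }

    safely { throw new RuntimeException("bad class bytes") }  // printed, not rethrown
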
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 479934a7afc75fd840baea4b0ad59799ef79cab6..56bd5c1891e8d03df7d328cc83063240dfa58cda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.columnar
 
-import scala.collection.JavaConverters._
-
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.network.util.JavaUtils
@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.CollectionAccumulator
+import org.apache.spark.util.LongAccumulator
 
 
 object InMemoryRelation {
@@ -63,8 +61,7 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: CollectionAccumulator[InternalRow] =
-      child.sqlContext.sparkContext.collectionAccumulator[InternalRow])
+    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
@@ -74,21 +71,12 @@ case class InMemoryRelation(
   @transient val partitionStatistics = new PartitionStatistics(output)
 
   override lazy val statistics: Statistics = {
-    if (batchStats.value.isEmpty) {
+    if (batchStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
       // available, return the default statistics.
       Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
     } else {
-      // Underlying columnar RDD has been materialized, required information has also been
-      // collected via the `batchStats` accumulator.
-      val sizeOfRow: Expression =
-        BindReferences.bindReference(
-          output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
-          partitionStatistics.schema)
-
-      val sizeInBytes =
-        batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
-      Statistics(sizeInBytes = sizeInBytes)
+      Statistics(sizeInBytes = batchStats.value.longValue)
     }
   }
 
@@ -139,10 +127,10 @@ case class InMemoryRelation(
             rowCount += 1
           }
 
+          batchStats.add(totalSize)
+
           val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
             .flatMap(_.values))
-
-          batchStats.add(stats)
           CachedBatch(rowCount, columnBuilders.map { builder =>
             JavaUtils.bufferToArray(builder.build())
           }, stats)
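
[Annotation, not part of the patch.] The change above replaces a CollectionAccumulator[InternalRow], which shipped one stats row per cached batch back to the driver and re-evaluated a bound sizeOfRow expression over all of them, with a LongAccumulator: each task now adds only the batch's totalSize, so the driver keeps a single running sum no matter how many batches are cached. A minimal sketch of that aggregation pattern, assuming a local SparkSession; the names here are illustrative, not from the patch:

    import org.apache.spark.sql.SparkSession

    object AccumulatorSketch extends App {
      val spark = SparkSession.builder().master("local[2]").appName("batch-stats").getOrCreate()
      val sc = spark.sparkContext

      // Each task contributes a single Long; the driver holds one sum,
      // not a per-batch collection of stat rows.
      val batchStats = sc.longAccumulator("batchStats")
      sc.parallelize(1 to 10, 5).foreach(_ => batchStats.add(4L)) // pretend 4 bytes per row

      assert(batchStats.value == 40L)
      spark.stop()
    }
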
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 937839644ad5f9a3086582b141948d9bb7217b6b..0daa29b666f62158c3a4c8f654eaa31b6e18f8e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -232,4 +232,17 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     val columnTypes2 = List.fill(length2)(IntegerType)
     val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2)
   }
+
+  test("SPARK-17549: cached table size should be correctly calculated") {
+    val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
+    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
+    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)
+
+    // Materialize the data.
+    val expectedAnswer = data.collect()
+    checkAnswer(cached, expectedAnswer)
+
+    // Check that the right size was calculated.
+    assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
+  }
 }
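
[Annotation, not part of the patch.] The expected value in the new test works out as expectedAnswer.size * INT.defaultSize = 10 rows × 4 bytes = 40 bytes: the cached DataFrame has a single IntegerType column, and since batchStats sums the totalSize contributions from all batches, the split into 5 partitions does not change the total.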