diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 470307bd940ad9f54aeb2fc937e120f827ab3953..bc7e73ae1ba8773d8b985413aa378adfe37cc1b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -53,219 +53,288 @@ private[columnar] sealed trait ColumnStats extends Serializable {
   /**
    * Gathers statistics information from `row(ordinal)`.
    */
-  def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    if (row.isNullAt(ordinal)) {
-      nullCount += 1
-      // 4 bytes for null position
-      sizeInBytes += 4
-    }
+  def gatherStats(row: InternalRow, ordinal: Int): Unit
+
+  /**
+   * Gathers statistics information on `null`.
+   */
+  def gatherNullStats(): Unit = {
+    nullCount += 1
+    // 4 bytes for null position
+    sizeInBytes += 4
     count += 1
   }
 
   /**
-   * Column statistics represented as a single row, currently including closed lower bound, closed
+   * Column statistics represented as an array, currently including closed lower bound, closed
    * upper bound and null count.
    */
-  def collectedStatistics: GenericInternalRow
+  def collectedStatistics: Array[Any]
 }
 
 /**
  * A no-op ColumnStats only used for testing purposes.
  */
-private[columnar] class NoopColumnStats extends ColumnStats {
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
+private[columnar] final class NoopColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    if (!row.isNullAt(ordinal)) {
+      count += 1
+    } else {
+      gatherNullStats
+    }
+  }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
+  override def collectedStatistics: Array[Any] = Array[Any](null, null, nullCount, count, 0L)
 }
 
-private[columnar] class BooleanColumnStats extends ColumnStats {
+private[columnar] final class BooleanColumnStats extends ColumnStats {
   protected var upper = false
   protected var lower = true
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getBoolean(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += BOOLEAN.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Boolean): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += BOOLEAN.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class ByteColumnStats extends ColumnStats {
+private[columnar] final class ByteColumnStats extends ColumnStats {
   protected var upper = Byte.MinValue
   protected var lower = Byte.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getByte(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += BYTE.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Byte): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += BYTE.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class ShortColumnStats extends ColumnStats {
+private[columnar] final class ShortColumnStats extends ColumnStats {
   protected var upper = Short.MinValue
   protected var lower = Short.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getShort(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += SHORT.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Short): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += SHORT.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class IntColumnStats extends ColumnStats {
+private[columnar] final class IntColumnStats extends ColumnStats {
   protected var upper = Int.MinValue
   protected var lower = Int.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getInt(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += INT.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Int): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += INT.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class LongColumnStats extends ColumnStats {
+private[columnar] final class LongColumnStats extends ColumnStats {
   protected var upper = Long.MinValue
   protected var lower = Long.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getLong(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += LONG.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Long): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += LONG.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class FloatColumnStats extends ColumnStats {
+private[columnar] final class FloatColumnStats extends ColumnStats {
   protected var upper = Float.MinValue
   protected var lower = Float.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getFloat(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += FLOAT.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Float): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += FLOAT.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class DoubleColumnStats extends ColumnStats {
+private[columnar] final class DoubleColumnStats extends ColumnStats {
   protected var upper = Double.MinValue
   protected var lower = Double.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getDouble(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += DOUBLE.defaultSize
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Double): Unit = {
+    if (value > upper) upper = value
+    if (value < lower) lower = value
+    sizeInBytes += DOUBLE.defaultSize
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats extends ColumnStats {
   protected var upper: UTF8String = null
   protected var lower: UTF8String = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getUTF8String(ordinal)
-      if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
-      if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
-      sizeInBytes += STRING.actualSize(row, ordinal)
+      val size = STRING.actualSize(row, ordinal)
+      gatherValueStats(value, size)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: UTF8String, size: Int): Unit = {
+    if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
+    if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
+    sizeInBytes += size
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class BinaryColumnStats extends ColumnStats {
+private[columnar] final class BinaryColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      sizeInBytes += BINARY.actualSize(row, ordinal)
+      val size = BINARY.actualSize(row, ordinal)
+      sizeInBytes += size
+      count += 1
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
+private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
   def this(dt: DecimalType) = this(dt.precision, dt.scale)
 
   protected var upper: Decimal = null
   protected var lower: Decimal = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
       val value = row.getDecimal(ordinal, precision, scale)
-      if (upper == null || value.compareTo(upper) > 0) upper = value
-      if (lower == null || value.compareTo(lower) < 0) lower = value
       // TODO: this is not right for DecimalType with precision > 18
-      sizeInBytes += 8
+      gatherValueStats(value)
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+  def gatherValueStats(value: Decimal): Unit = {
+    if (upper == null || value.compareTo(upper) > 0) upper = value
+    if (lower == null || value.compareTo(lower) < 0) lower = value
+    sizeInBytes += 8
+    count += 1
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
+private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
   val columnType = ColumnType(dataType)
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      sizeInBytes += columnType.actualSize(row, ordinal)
+      val size = columnType.actualSize(row, ordinal)
+      sizeInBytes += size
+      count += 1
+    } else {
+      gatherNullStats
     }
   }
 
-  override def collectedStatistics: GenericInternalRow =
-    new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 0a9f3e799990f16df9afca8bed60dc8749418c62..3486a6bce81800e92409ffadb55299f20bde690a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -123,8 +123,8 @@ case class InMemoryRelation(
 
           batchStats.add(totalSize)
 
-          val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
-            .flatMap(_.values))
+          val stats = InternalRow.fromSeq(
+            columnBuilders.flatMap(_.columnStats.collectedStatistics))
           CachedBatch(rowCount, columnBuilders.map { builder =>
             JavaUtils.bufferToArray(builder.build())
           }, stats)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index b2d04f7c5a6e36ee93b6766a02771dba745a72b9..d4e7e362c6c8c7e7d23016f6905f8b486974f4f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -18,33 +18,29 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.types._
 
 class ColumnStatsSuite extends SparkFunSuite {
-  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
-  testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
-  testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
-  testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
-  testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
-  testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
-  testColumnStats(classOf[DoubleColumnStats], DOUBLE,
-    createRow(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
-  testDecimalColumnStats(createRow(null, null, 0))
-
-  def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)
+  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
+  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
+  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
+  testDecimalColumnStats(Array(null, null, 0))
 
   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],
       columnType: NativeColumnType[T],
-      initialStatistics: GenericInternalRow): Unit = {
+      initialStatistics: Array[Any]): Unit = {
 
     val columnStatsName = columnStatsClass.getSimpleName
 
     test(s"$columnStatsName: empty") {
       val columnStats = columnStatsClass.newInstance()
-      columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+      columnStats.collectedStatistics.zip(initialStatistics).foreach {
         case (actual, expected) => assert(actual === expected)
       }
     }
@@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
-      assertResult(10, "Wrong null count")(stats.values(2))
-      assertResult(20, "Wrong row count")(stats.values(3))
-      assertResult(stats.values(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
@@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite {
   }
 
   def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
-      initialStatistics: GenericInternalRow): Unit = {
+      initialStatistics: Array[Any]): Unit = {
 
     val columnStatsName = classOf[DecimalColumnStats].getSimpleName
     val columnType = COMPACT_DECIMAL(15, 10)
 
     test(s"$columnStatsName: empty") {
       val columnStats = new DecimalColumnStats(15, 10)
-      columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
+      columnStats.collectedStatistics.zip(initialStatistics).foreach {
         case (actual, expected) => assert(actual === expected)
       }
     }
@@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
-      assertResult(10, "Wrong null count")(stats.values(2))
-      assertResult(20, "Wrong row count")(stats.values(3))
-      assertResult(stats.values(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum