diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8364379644c9063fd77c2b39a79450ae07676a7c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values
+ * to be looked up even when the attributes used differ cosmetically (i.e., the capitalization
+ * of the name, or the expected nullability).
+ */
+object AttributeMap {
+  def apply[A](kvs: Seq[(Attribute, A)]) =
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+}
+
+class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+  extends Map[Attribute, A] with Serializable {
+
+  override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+
+  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
+    (baseMap.map(_._2) + kv).toMap
+
+  override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+
+  override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 54c6baf1af3bf1bd896023438597802ba18f4bda..fa80b07f8e6be4178bbea5883484f6ed62830aaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 }
 
 object BindReferences extends Logging {
-  def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
+
+  def bindReference[A <: Expression](
+      expression: A,
+      input: Seq[Attribute],
+      allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
         val ordinal = input.indexWhere(_.exprId == a.exprId)
         if (ordinal == -1) {
-          sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+          if (allowFailures) {
+            a
+          } else {
+            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+          }
         } else {
           BoundReference(ordinal, a.dataType, a.nullable)
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 64d49354dadcd70f189bd3f03c0ef4a22db80fc7..4137ac76637391541f13b43494399affbe4abc0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -26,6 +26,7 @@ import java.util.Properties
 private[spark] object SQLConf {
   val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
   val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
+  val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
@@ -124,6 +125,12 @@ trait SQLConf {
   private[spark] def isParquetBinaryAsString: Boolean =
     getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean
 
+  /**
+   * When set to true, partition pruning for in-memory columnar tables is enabled.
+   */
+  private[spark] def inMemoryPartitionPruning: Boolean =
+    getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 247337a875c754bf290e8a4365d0c235fed56499..b3ec5ded22422c9c591e0ffc3a94d91fd0d36940 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder {
   /**
    * Column statistics information
    */
-  def columnStats: ColumnStats[_, _]
+  def columnStats: ColumnStats
 
   /**
    * Returns the final columnar byte buffer.
@@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder {
 }
 
 private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
-    val columnStats: ColumnStats[T, JvmType],
+    val columnStats: ColumnStats,
     val columnType: ColumnType[T, JvmType])
   extends ColumnBuilder {
 
@@ -81,18 +81,18 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
 
 private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
     columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
   with NullableColumnBuilder
 
 private[sql] abstract class NativeColumnBuilder[T <: NativeType](
-    override val columnStats: NativeColumnStats[T],
+    override val columnStats: ColumnStats,
     override val columnType: NativeColumnType[T])
   extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
   with NullableColumnBuilder
   with AllCompressionSchemes
   with CompressibleColumnBuilder[T]
 
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN)
 
 private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 6502110e903fe683d0f84cb79c72c9874747d3dc..fc343ccb995c2cb64c2d2501a31ae1d8a0e9457f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,381 +17,193 @@
 
 package org.apache.spark.sql.columnar
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.types._
 
+private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
+  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
+  val nullCount =  AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+
+  val schema = Seq(lowerBound, upperBound, nullCount)
+}
+
+private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
+  val (forAttribute, schema) = {
+    val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
+    (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _))
+  }
+}
+
 /**
  * Used to collect statistical information when building in-memory columns.
  *
  * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]`
  * brings significant performance penalty.
  */
-private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
-  /**
-   * Closed lower bound of this column.
-   */
-  def lowerBound: JvmType
-
-  /**
-   * Closed upper bound of this column.
-   */
-  def upperBound: JvmType
-
+private[sql] sealed trait ColumnStats extends Serializable {
   /**
    * Gathers statistics information from `row(ordinal)`.
    */
-  def gatherStats(row: Row, ordinal: Int)
-
-  /**
-   * Returns `true` if `lower <= row(ordinal) <= upper`.
-   */
-  def contains(row: Row, ordinal: Int): Boolean
+  def gatherStats(row: Row, ordinal: Int): Unit
 
   /**
-   * Returns `true` if `row(ordinal) < upper` holds.
+   * Column statistics represented as a single row, currently including closed lower bound, closed
+   * upper bound and null count.
    */
-  def isAbove(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `lower < row(ordinal)` holds.
-   */
-  def isBelow(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `row(ordinal) <= upper` holds.
-   */
-  def isAtOrAbove(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `lower <= row(ordinal)` holds.
-   */
-  def isAtOrBelow(row: Row, ordinal: Int): Boolean
-}
-
-private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
-  extends ColumnStats[T, T#JvmType] {
-
-  type JvmType = T#JvmType
-
-  protected var (_lower, _upper) = initialBounds
-
-  def initialBounds: (JvmType, JvmType)
-
-  protected def columnType: NativeColumnType[T]
-
-  override def lowerBound: T#JvmType = _lower
-
-  override def upperBound: T#JvmType = _upper
-
-  override def isAtOrAbove(row: Row, ordinal: Int) = {
-    contains(row, ordinal) || isAbove(row, ordinal)
-  }
-
-  override def isAtOrBelow(row: Row, ordinal: Int) = {
-    contains(row, ordinal) || isBelow(row, ordinal)
-  }
+  def collectedStatistics: Row
 }
 
-private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
-  override def isAtOrBelow(row: Row, ordinal: Int) = true
-
-  override def isAtOrAbove(row: Row, ordinal: Int) = true
-
-  override def isBelow(row: Row, ordinal: Int) = true
-
-  override def isAbove(row: Row, ordinal: Int) = true
+private[sql] class NoopColumnStats extends ColumnStats {
 
-  override def contains(row: Row, ordinal: Int) = true
+  override def gatherStats(row: Row, ordinal: Int): Unit = {}
 
-  override def gatherStats(row: Row, ordinal: Int) {}
-
-  override def upperBound = null.asInstanceOf[JvmType]
-
-  override def lowerBound = null.asInstanceOf[JvmType]
+  override def collectedStatistics = Row()
 }
 
-private[sql] abstract class BasicColumnStats[T <: NativeType](
-    protected val columnType: NativeColumnType[T])
-  extends NativeColumnStats[T]
-
-private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
-  override def initialBounds = (true, false)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class ByteColumnStats extends ColumnStats {
+  var upper = Byte.MinValue
+  var lower = Byte.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
-  override def initialBounds = (Byte.MaxValue, Byte.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getByte(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
-  override def initialBounds = (Short.MaxValue, Short.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class ShortColumnStats extends ColumnStats {
+  var upper = Short.MinValue
+  var lower = Short.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
-  override def initialBounds = (Long.MaxValue, Long.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getShort(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
-  override def initialBounds = (Double.MaxValue, Double.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class LongColumnStats extends ColumnStats {
+  var upper = Long.MinValue
+  var lower = Long.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
-  override def initialBounds = (Float.MaxValue, Float.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getLong(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
+}
 
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class DoubleColumnStats extends ColumnStats {
+  var upper = Double.MinValue
+  var lower = Double.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getDouble(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
-}
 
-private[sql] object IntColumnStats {
-  val UNINITIALIZED = 0
-  val INITIALIZED = 1
-  val ASCENDING = 2
-  val DESCENDING = 3
-  val UNORDERED = 4
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-/**
- * Statistical information for `Int` columns. More information is collected since `Int` is
- * frequently used. Extra information include:
- *
- * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
- *   is applicable when searching elements.
- * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
- *   scheme.
- *
- * (This two kinds of information are not used anywhere yet and might be removed later.)
- */
-private[sql] class IntColumnStats extends BasicColumnStats(INT) {
-  import IntColumnStats._
-
-  private var orderedState = UNINITIALIZED
-  private var lastValue: Int = _
-  private var _maxDelta: Int = _
-
-  def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
-  def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
-  def isOrdered = isAscending || isDescending
-  def maxDelta = _maxDelta
-
-  override def initialBounds = (Int.MaxValue, Int.MinValue)
+private[sql] class FloatColumnStats extends ColumnStats {
+  var upper = Float.MinValue
+  var lower = Float.MaxValue
+  var nullCount = 0
 
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
+  override def gatherStats(row: Row, ordinal: Int) {
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getFloat(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
+}
 
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class IntColumnStats extends ColumnStats {
+  var upper = Int.MinValue
+  var lower = Int.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-
-    orderedState = orderedState match {
-      case UNINITIALIZED =>
-        lastValue = field
-        INITIALIZED
-
-      case INITIALIZED =>
-        // If all the integers in the column are the same, ordered state is set to Ascending.
-        // TODO (lian) Confirm whether this is the standard behaviour.
-        val nextState = if (field >= lastValue) ASCENDING else DESCENDING
-        _maxDelta = math.abs(field - lastValue)
-        lastValue = field
-        nextState
-
-      case ASCENDING if field < lastValue =>
-        UNORDERED
-
-      case DESCENDING if field > lastValue =>
-        UNORDERED
-
-      case state @ (ASCENDING | DESCENDING) =>
-        _maxDelta = _maxDelta.max(field - lastValue)
-        lastValue = field
-        state
-
-      case _ =>
-        orderedState
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getInt(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
     }
   }
+
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
-  override def initialBounds = (null, null)
+private[sql] class StringColumnStats extends ColumnStats {
+  var upper: String = null
+  var lower: String = null
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
-    if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
-    }
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      field.compareTo(upperBound) < 0
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getString(ordinal)
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+    } else {
+      nullCount += 1
     }
   }
 
-  override def isBelow(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) < 0
-    }
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) {
-  override def initialBounds = (null, null)
+private[sql] class TimestampColumnStats extends ColumnStats {
+  var upper: Timestamp = null
+  var lower: Timestamp = null
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
-    if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Timestamp]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+    } else {
+      nullCount += 1
     }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      field.compareTo(upperBound) < 0
-    }
-  }
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) < 0
-    }
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index cb055cd74a5e5d638fd00485d203f2e83f4c2156..dc668e7dc934ca67e95fe53cf8dff962b2b8b563 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.columnar
 
 import java.nio.ByteBuffer
 
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 
@@ -31,23 +33,27 @@ object InMemoryRelation {
     new InMemoryRelation(child.output, useCompression, batchSize, child)()
 }
 
+private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+
 private[sql] case class InMemoryRelation(
     output: Seq[Attribute],
     useCompression: Boolean,
     batchSize: Int,
     child: SparkPlan)
-    (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
+    (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
   override lazy val statistics =
     Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
 
+  val partitionStatistics = new PartitionStatistics(output)
+
   // If the cached column buffers were not passed in, we calculate them in the constructor.
   // As in Spark, the actual work of caching is lazy.
   if (_cachedColumnBuffers == null) {
     val output = child.output
     val cached = child.execute().mapPartitions { baseIterator =>
-      new Iterator[Array[ByteBuffer]] {
+      new Iterator[CachedBatch] {
         def next() = {
           val columnBuilders = output.map { attribute =>
             val columnType = ColumnType(attribute.dataType)
@@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation(
             rowCount += 1
           }
 
-          columnBuilders.map(_.build())
+          val stats = Row.fromSeq(
+            columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+
+          CachedBatch(columnBuilders.map(_.build()), stats)
         }
 
         def hasNext = baseIterator.hasNext
@@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation(
     _cachedColumnBuffers = cached
   }
 
-
   override def children = Seq.empty
 
   override def newInstance() = {
@@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation(
 
 private[sql] case class InMemoryColumnarTableScan(
     attributes: Seq[Attribute],
+    predicates: Seq[Expression],
     relation: InMemoryRelation)
   extends LeafNode {
 
+  @transient override val sqlContext = relation.child.sqlContext
+
   override def output: Seq[Attribute] = attributes
 
+  // Returned filter predicate should return false iff it is impossible for the input expression
+  // to evaluate to `true' based on statistics collected about this partition batch.
+  val buildFilter: PartialFunction[Expression, Expression] = {
+    case And(lhs: Expression, rhs: Expression)
+      if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+      buildFilter(lhs) && buildFilter(rhs)
+
+    case Or(lhs: Expression, rhs: Expression)
+      if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+      buildFilter(lhs) || buildFilter(rhs)
+
+    case EqualTo(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l && l <= aStats.upperBound
+
+    case EqualTo(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l && l <= aStats.upperBound
+
+    case LessThan(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound < l
+
+    case LessThan(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l < aStats.upperBound
+
+    case LessThanOrEqual(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l
+
+    case LessThanOrEqual(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l <= aStats.upperBound
+
+    case GreaterThan(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l < aStats.upperBound
+
+    case GreaterThan(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound < l
+
+    case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l <= aStats.upperBound
+
+    case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l
+  }
+
+  val partitionFilters = {
+    predicates.flatMap { p =>
+      val filter = buildFilter.lift(p)
+      val boundFilter =
+        filter.map(
+          BindReferences.bindReference(
+            _,
+            relation.partitionStatistics.schema,
+            allowFailures = true))
+
+      boundFilter.foreach(_ =>
+        filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
+
+      // If the filter can't be resolved then we are missing required statistics.
+      boundFilter.filter(_.resolved)
+    }
+  }
+
+  val readPartitions = sparkContext.accumulator(0)
+  val readBatches = sparkContext.accumulator(0)
+
+  private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning
+
   override def execute() = {
+    readPartitions.setValue(0)
+    readBatches.setValue(0)
+
     relation.cachedColumnBuffers.mapPartitions { iterator =>
+      val partitionFilter = newPredicate(
+        partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+        relation.partitionStatistics.schema)
+
       // Find the ordinals of the requested columns.  If none are requested, use the first.
       val requestedColumns = if (attributes.isEmpty) {
         Seq(0)
@@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan(
         attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
       }
 
-      iterator
-        .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+      val rows = iterator
+        // Skip pruned batches
+        .filter { cachedBatch =>
+          if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
+            def statsString = relation.partitionStatistics.schema
+              .zip(cachedBatch.stats)
+              .map { case (a, s) => s"${a.name}: $s" }
+              .mkString(", ")
+            logInfo(s"Skipping partition based on stats $statsString")
+            false
+          } else {
+            readBatches += 1
+            true
+          }
+        }
+        // Build column accessors
+        .map { cachedBatch =>
+          requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+        }
+        // Extract rows via column accessors
         .flatMap { columnAccessors =>
           val nextRow = new GenericMutableRow(columnAccessors.length)
           new Iterator[Row] {
@@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan(
             override def hasNext = columnAccessors.head.hasNext
           }
         }
+
+      if (rows.hasNext) {
+        readPartitions += 1
+      }
+
+      rows
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index f631ee76fcd78362abb8bcb8e887a6e746d0ac4a..a72970eef7aa4192781d6fe33ccc9fa611ceaeb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
   }
 
   abstract override def appendFrom(row: Row, ordinal: Int) {
+    columnStats.gatherStats(row, ordinal)
     if (row.isNullAt(ordinal)) {
       nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
       nulls.putInt(pos)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8dacb84c8a17e0754986b686ab4b15322df09380..7943d6e1b6fb5b31562aa320f5356b366f307e42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         pruneFilterProject(
           projectList,
           filters,
-          identity[Seq[Expression]], // No filters are pushed down.
-          InMemoryColumnarTableScan(_, mem)) :: Nil
+          identity[Seq[Expression]], // All filters still need to be evaluated.
+          InMemoryColumnarTableScan(_,  filters, mem)) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 5f61fb5e16ea3265dbd259e7bd8bbe6e943e5a1d..cde91ceb68c9848c76ad83a06ef54dd37594bcdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.types._
 
 class ColumnStatsSuite extends FunSuite {
-  testColumnStats(classOf[BooleanColumnStats],   BOOLEAN)
-  testColumnStats(classOf[ByteColumnStats],      BYTE)
-  testColumnStats(classOf[ShortColumnStats],     SHORT)
-  testColumnStats(classOf[IntColumnStats],       INT)
-  testColumnStats(classOf[LongColumnStats],      LONG)
-  testColumnStats(classOf[FloatColumnStats],     FLOAT)
-  testColumnStats(classOf[DoubleColumnStats],    DOUBLE)
-  testColumnStats(classOf[StringColumnStats],    STRING)
-  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP)
-
-  def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+  testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
+  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
+  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
+
+  def testColumnStats[T <: NativeType, U <: ColumnStats](
       columnStatsClass: Class[U],
-      columnType: NativeColumnType[T]) {
+      columnType: NativeColumnType[T],
+      initialStatistics: Row) {
 
     val columnStatsName = columnStatsClass.getSimpleName
 
     test(s"$columnStatsName: empty") {
       val columnStats = columnStatsClass.newInstance()
-      assertResult(columnStats.initialBounds, "Wrong initial bounds") {
-        (columnStats.lowerBound, columnStats.upperBound)
+      columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
+        assert(actual === expected)
       }
     }
 
@@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite {
       import ColumnarTestUtils._
 
       val columnStats = columnStatsClass.newInstance()
-      val rows = Seq.fill(10)(makeRandomRow(columnType))
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(columnStats.gatherStats(_, 0))
 
-      val values = rows.map(_.head.asInstanceOf[T#JvmType])
+      val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+      val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
-      assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index dc813fe146c4736a97a06d10ca28d88efc588645..a77262534a35259974e9d15c4f540b6a4159b106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.execution.SparkSqlSerializer
 
 class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
   with NullableColumnBuilder
 
 object TestNullableColumnBuilder {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5d2fd4959197c87819d0079a1ed75a7e20a995a3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class IntegerData(i: Int)
+
+class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
+  val originalColumnBatchSize = columnBatchSize
+  val originalInMemoryPartitionPruning = inMemoryPartitionPruning
+
+  override protected def beforeAll() {
+    // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
+    setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+    val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
+    rawData.registerTempTable("intData")
+
+    // Enable in-memory partition pruning
+    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+  }
+
+  override protected def afterAll() {
+    setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+  }
+
+  before {
+    cacheTable("intData")
+  }
+
+  after {
+    uncacheTable("intData")
+  }
+
+  // Comparisons
+  checkBatchPruning("i = 1", Seq(1), 1, 1)
+  checkBatchPruning("1 = i", Seq(1), 1, 1)
+  checkBatchPruning("i < 12", 1 to 11, 1, 2)
+  checkBatchPruning("i <= 11", 1 to 11, 1, 2)
+  checkBatchPruning("i > 88", 89 to 100, 1, 2)
+  checkBatchPruning("i >= 89", 89 to 100, 1, 2)
+  checkBatchPruning("12 > i", 1 to 11, 1, 2)
+  checkBatchPruning("11 >= i", 1 to 11, 1, 2)
+  checkBatchPruning("88 < i", 89 to 100, 1, 2)
+  checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+
+  // Conjunction and disjunction
+  checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
+  checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
+  checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
+
+  // With unsupported predicate
+  checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
+  checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
+
+  def checkBatchPruning(
+      filter: String,
+      expectedQueryResult: Seq[Int],
+      expectedReadPartitions: Int,
+      expectedReadBatches: Int) {
+
+    test(filter) {
+      val query = sql(s"SELECT * FROM intData WHERE $filter")
+      assertResult(expectedQueryResult.toArray, "Wrong query result") {
+        query.collect().map(_.head).toArray
+      }
+
+      val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+        case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
+      }.head
+
+      assert(readBatches === expectedReadBatches, "Wrong number of read batches")
+      assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 5fba00480967c0edeaff885d0f1f3156f38e0be3..e01cc8b4d20f2bfded92569efe66e6a1e015663b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats}
+import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 
 class BooleanBitSetSuite extends FunSuite {
@@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite {
     // Tests encoder
     // -------------
 
-    val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet)
+    val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
     val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
     val values = rows.map(_.head)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d8ae2a26778c945d59152f48dbec812d63b60b23..d2969d906c943ceab50cfdbf553bfb5ae31d5dc7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite {
   testDictionaryEncoding(new StringColumnStats, STRING)
 
   def testDictionaryEncoding[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T]) {
 
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 17619dcf974e37e2a30d676d628219a77756b479..322f447c2484032980cd7eddf3f2e60e3309cbd8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite {
   testIntegralDelta(new LongColumnStats, LONG, LongDelta)
 
   def testIntegralDelta[I <: IntegralType](
-      columnStats: NativeColumnStats[I],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[I],
       scheme: IntegralDelta[I]) {
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 40115beb98899fc434dd6c46d2a14910cb5f1c77..218c09ac26362a9e22daa4a3d97699dfbb753cfa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 
 class RunLengthEncodingSuite extends FunSuite {
-  testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+  testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
   testRunLengthEncoding(new ByteColumnStats,    BYTE)
   testRunLengthEncoding(new ShortColumnStats,   SHORT)
   testRunLengthEncoding(new IntColumnStats,     INT)
@@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite {
   testRunLengthEncoding(new StringColumnStats,  STRING)
 
   def testRunLengthEncoding[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T]) {
 
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 72c19fa31d980f4010c79e8e437323e770048ba8..7db723d648d80327fecfc205e1a280af99b0ff19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType
 import org.apache.spark.sql.columnar._
 
 class TestCompressibleColumnBuilder[T <: NativeType](
-    override val columnStats: NativeColumnStats[T],
+    override val columnStats: ColumnStats,
     override val columnType: NativeColumnType[T],
     override val schemes: Seq[CompressionScheme])
   extends NativeColumnBuilder(columnStats, columnType)
@@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType](
 
 object TestCompressibleColumnBuilder {
   def apply[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T],
       scheme: CompressionScheme) = {
 
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b589994bd25fa2247722308ba315d9a21825f8fd..ab487d673e8138c7f764b5f5d398c4a5311a0829 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -35,26 +35,29 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
 
   private val originalTimeZone = TimeZone.getDefault
   private val originalLocale = Locale.getDefault
-  private val originalUseCompression = TestHive.useCompression
+  private val originalColumnBatchSize = TestHive.columnBatchSize
+  private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning
 
   def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
 
   override def beforeAll() {
-    // Enable in-memory columnar caching
     TestHive.cacheTables = true
     // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
     TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
-    // Enable in-memory columnar compression
-    TestHive.setConf(SQLConf.COMPRESS_CACHED, "true")
+    // Set a relatively small column batch size for testing purposes
+    TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5")
+    // Enable in-memory partition pruning for testing purposes
+    TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
   }
 
   override def afterAll() {
     TestHive.cacheTables = false
     TimeZone.setDefault(originalTimeZone)
     Locale.setDefault(originalLocale)
-    TestHive.setConf(SQLConf.COMPRESS_CACHED, originalUseCompression.toString)
+    TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+    TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
   }
 
   /** A list of tests deemed out of scope currently and thus completely disregarded. */