diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index f08350878f239f9cfff8ff6b28c56b18db6cf9ee..0357dcc4688beb42073516016ab6cc7f14e9d6a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -164,33 +164,61 @@ private[sql] object ParquetFilters {
 
       case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeEq.lift(dataType).map(_(name, value))
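+      // A `Cast` wrapping the column is typically inserted by the analyzer's type coercion
+      // rules, e.g. when a ShortType column is compared against an Int literal. Both sides
+      // of the comparison then share the cast's target type, so the `Cast` can be dropped
+      // and the predicate pushed down using the original column name and that target type.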
+      case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeEq.lift(dataType).map(_(name, value))
       case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeEq.lift(dataType).map(_(name, value))
-
+      case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeEq.lift(dataType).map(_(name, value))
+
       case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
         makeNotEq.lift(dataType).map(_(name, value))
+      case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
       case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
         makeNotEq.lift(dataType).map(_(name, value))
+      case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
 
       case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeLt.lift(dataType).map(_(name, value))
+      case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeLt.lift(dataType).map(_(name, value))
       case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeGt.lift(dataType).map(_(name, value))
+      case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeGt.lift(dataType).map(_(name, value))
 
       case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeLtEq.lift(dataType).map(_(name, value))
+      case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeLtEq.lift(dataType).map(_(name, value))
       case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeGtEq.lift(dataType).map(_(name, value))
+      case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeGtEq.lift(dataType).map(_(name, value))
 
       case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeGt.lift(dataType).map(_(name, value))
+      case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeGt.lift(dataType).map(_(name, value))
       case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeLt.lift(dataType).map(_(name, value))
+      case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeLt.lift(dataType).map(_(name, value))
 
       case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeGtEq.lift(dataType).map(_(name, value))
+      case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeGtEq.lift(dataType).map(_(name, value))
       case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeLtEq.lift(dataType).map(_(name, value))
+      case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeLtEq.lift(dataType).map(_(name, value))
 
       case And(lhs, rhs) =>
         (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index c9bc55900de980c888375a3d3cad6b0779a72e5c..e78145f4dda5a5b9126a199d31c36456a2660dfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -21,7 +21,8 @@ import parquet.filter2.predicate.Operators._
 import parquet.filter2.predicate.{FilterPredicate, Operators}
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
 
@@ -93,6 +94,34 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
     }
   }
 
+  test("filter pushdown - short") {
+    withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd =>
+      checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1)
+      checkFilterPredicate(
+        Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+      checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1)
+      checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4)
+      checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1)
+      checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4)
+
+      checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1)
+      checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1)
+      checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4)
+      checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1)
+      checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4)
+
+      checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4)
+      checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4,
+        classOf[Operators.And], 3)
+      checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3,
+        classOf[Operators.Or], Seq(Row(1), Row(4)))
+    }
+  }
+
   test("filter pushdown - integer") {
     withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])