From de221ea03288fb9fb7c14530425f4a9414b1088f Mon Sep 17 00:00:00 2001
From: Yash Datta <Yash.Datta@guavus.com>
Date: Thu, 29 Jan 2015 15:42:23 -0800
Subject: [PATCH] [SPARK-4786][SQL]: Parquet filter pushdown for castable types

Enable Parquet filter pushdown for columns of castable types, such as short and byte, that can be cast to integer.

Author: Yash Datta <Yash.Datta@guavus.com>

Closes #4156 from saucam/filter_short and squashes the following commits:

a403979 [Yash Datta] SPARK-4786: Fix styling issues
d029866 [Yash Datta] SPARK-4786: Add test case
cb2e0d9 [Yash Datta] SPARK-4786: Parquet filter pushdown for castable types
---
 .../spark/sql/parquet/ParquetFilters.scala    | 26 +++++++++++++++++-
 .../sql/parquet/ParquetFilterSuite.scala      | 27 ++++++++++++++++++-
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index f08350878f..0357dcc468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -164,33 +164,57 @@ private[sql] object ParquetFilters {
 
       case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeEq.lift(dataType).map(_(name, value))
+      case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeEq.lift(dataType).map(_(name, value))
       case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeEq.lift(dataType).map(_(name, value))
-
+      case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeEq.lift(dataType).map(_(name, value))
+
       case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) =>
         makeNotEq.lift(dataType).map(_(name, value))
+      case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
       case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) =>
         makeNotEq.lift(dataType).map(_(name, value))
+      case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) =>
+        makeNotEq.lift(dataType).map(_(name, value))
 
       case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeLt.lift(dataType).map(_(name, value))
+      case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeLt.lift(dataType).map(_(name, value))
       case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeGt.lift(dataType).map(_(name, value))
+      case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeGt.lift(dataType).map(_(name, value))
 
       case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeLtEq.lift(dataType).map(_(name, value))
+      case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeLtEq.lift(dataType).map(_(name, value))
       case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeGtEq.lift(dataType).map(_(name, value))
+      case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeGtEq.lift(dataType).map(_(name, value))
 
       case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeGt.lift(dataType).map(_(name, value))
+      case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeGt.lift(dataType).map(_(name, value))
       case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeLt.lift(dataType).map(_(name, value))
+      case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeLt.lift(dataType).map(_(name, value))
 
       case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) =>
         makeGtEq.lift(dataType).map(_(name, value))
+      case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) =>
+        makeGtEq.lift(dataType).map(_(name, value))
       case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) =>
         makeLtEq.lift(dataType).map(_(name, value))
+      case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) =>
+        makeLtEq.lift(dataType).map(_(name, value))
 
       case And(lhs, rhs) =>
         (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index c9bc55900d..e78145f4dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -21,7 +21,8 @@ import parquet.filter2.predicate.Operators._
 import parquet.filter2.predicate.{FilterPredicate, Operators}
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
 
@@ -93,6 +94,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
     }
   }
 
+  test("filter pushdown - short") {
+    withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd =>
+      checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq   [_]], 1)
+      checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+      // Ordering comparisons with the cast column on the left-hand side.
+      checkFilterPredicate(Cast('_1, IntegerType) < 2,  classOf[Lt  [_]], 1)
+      checkFilterPredicate(Cast('_1, IntegerType) > 3,  classOf[Gt  [_]], 4)
+      checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1)
+      checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4)
+      // Literal on the left-hand side; ordering operators are expected to come back mirrored.
+      checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq  [_]], 1)
+      checkFilterPredicate(Literal(2) >   Cast('_1, IntegerType), classOf[Lt  [_]], 1)
+      checkFilterPredicate(Literal(3) <   Cast('_1, IntegerType), classOf[Gt  [_]], 4)
+      checkFilterPredicate(Literal(1) >=  Cast('_1, IntegerType), classOf[LtEq[_]], 1)
+      checkFilterPredicate(Literal(4) <=  Cast('_1, IntegerType), classOf[GtEq[_]], 4)
+      // Negation, conjunction and disjunction of cast comparisons.
+      checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4)
+      checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4,
+        classOf[Operators.And], 3)
+      checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3,
+        classOf[Operators.Or],  Seq(Row(1), Row(4)))
+    }
+  }
+
   test("filter pushdown - integer") {
     withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
-- 
GitLab