diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f08350878f239f9cfff8ff6b28c56b18db6cf9ee..0357dcc4688beb42073516016ab6cc7f14e9d6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -164,33 +164,57 @@ private[sql] object ParquetFilters { case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeEq.lift(dataType).map(_(name, value)) + case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeEq.lift(dataType).map(_(name, value)) - + case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => + makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => + makeNotEq.lift(dataType).map(_(name, value)) case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLt.lift(dataType).map(_(name, value)) + case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLt.lift(dataType).map(_(name, value)) case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGt.lift(dataType).map(_(name, value)) + case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGt.lift(dataType).map(_(name, value)) case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) + case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLt.lift(dataType).map(_(name, value)) + case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLt.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLtEq.lift(dataType).map(_(name, value)) case And(lhs, rhs) => (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index c9bc55900de980c888375a3d3cad6b0779a72e5c..e78145f4dda5a5b9126a199d31c36456a2660dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,7 +21,8 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} +import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} @@ -93,6 +94,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } + test("filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, + classOf[Operators.And], 3) + checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])