From 6f54dee66100e5e58f6649158db257eb5009bd6a Mon Sep 17 00:00:00 2001 From: Cheng Lian <lian@databricks.com> Date: Mon, 16 Feb 2015 12:48:55 -0800 Subject: [PATCH] [SPARK-5296] [SQL] Add more filter types for data sources API This PR adds the following filter types for data sources API: - `IsNull` - `IsNotNull` - `Not` - `And` - `Or` The code which converts Catalyst predicate expressions to data sources filters is very similar to filter conversion logics in `ParquetFilters` which converts Catalyst predicates to Parquet filter predicates. In this way we can support nested AND/OR/NOT predicates without changing current `BaseScan` type hierarchy. <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/4623) <!-- Reviewable:end --> Author: Cheng Lian <lian@databricks.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes #4623 from liancheng/more-fiters and squashes the following commits: 1b296f4 [Cheng Lian] Add more filter types for data sources API --- .../org/apache/spark/sql/SQLContext.scala | 9 ++- .../apache/spark/sql/parquet/newParquet.scala | 5 +- .../sql/sources/DataSourceStrategy.scala | 81 +++++++++++++------ .../apache/spark/sql/sources/filters.scala | 5 ++ .../spark/sql/sources/FilteredScanSuite.scala | 34 +++++++- 5 files changed, 103 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b42a52ebd2..1442250569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -28,16 +28,16 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} +import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ -import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation, _} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.{Partition, SparkContext} @@ -867,7 +867,8 @@ class SQLContext(@transient val sparkContext: SparkContext) val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 9279f5a903..9bb34e2df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} + import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.metadata.CompressionCodecName @@ -42,6 +43,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ @@ -497,7 +499,8 @@ case class ParquetRelation2( _.references.map(_.name).toSet.subsetOf(partitionColumnNames) } - val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true)) + val rawPredicate = + partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) val boundPredicate = InterpretedPredicate(rawPredicate transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 624369afe8..a853385fda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, Strategy, execution} +import org.apache.spark.sql.{Row, Strategy, execution, sources} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -88,7 +88,7 @@ private[sql] object DataSourceStrategy extends Strategy { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(And) + val filterCondition = filterPredicates.reduceLeftOption(expressions.And) val pushedFilters = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. @@ -118,27 +118,60 @@ private[sql] object DataSourceStrategy extends Strategy { } } - /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */ - protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - - case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - - case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - - case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => - LessThanOrEqual(a.name, v) - - case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => - LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => - GreaterThanOrEqual(a.name, v) + /** + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, + * and convert them. + */ + protected[sql] def selectFilters(filters: Seq[Expression]) = { + def translate(predicate: Expression): Option[Filter] = predicate match { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => + Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Literal(v, _), a: Attribute) => + Some(sources.EqualTo(a.name, v)) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThan(a.name, v)) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => + Some(sources.LessThan(a.name, v)) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => + Some(sources.LessThan(a.name, v)) + case expressions.LessThan(Literal(v, _), a: Attribute) => + Some(sources.GreaterThan(a.name, v)) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, v)) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.LessThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, v)) + + case expressions.InSet(a: Attribute, set) => + Some(sources.In(a.name, set.toArray)) + + case expressions.IsNull(a: Attribute) => + Some(sources.IsNull(a.name)) + case expressions.IsNotNull(a: Attribute) => + Some(sources.IsNotNull(a.name)) + + case expressions.And(left, right) => + (translate(left) ++ translate(right)).reduceOption(sources.And) + + case expressions.Or(left, right) => + for { + leftFilter <- translate(left) + rightFilter <- translate(right) + } yield sources.Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translate(child).map(sources.Not) + + case _ => None + } - case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) + filters.flatMap(translate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4a9fefc12b..1e4505e36d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -25,3 +25,8 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter case class LessThan(attribute: String, value: Any) extends Filter case class LessThanOrEqual(attribute: String, value: Any) extends Filter case class In(attribute: String, values: Array[Any]) extends Filter +case class IsNull(attribute: String) extends Filter +case class IsNotNull(attribute: String) extends Filter +case class And(left: Filter, right: Filter) extends Filter +case class Or(left: Filter, right: Filter) extends Filter +case class Not(child: Filter) extends Filter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 390538d35a..41cd35683c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -47,16 +47,22 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL FiltersPushed.list = filters - val filterFunctions = filters.collect { + def translateFilter(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) + case IsNull("a") => (a: Int) => false // Int can't be null + case IsNotNull("a") => (a: Int) => true + case Not(pred) => (a: Int) => !translateFilter(pred)(a) + case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a) + case Or(left, right) => (a: Int) => translateFilter(left)(a) || translateFilter(right)(a) + case _ => (a: Int) => true } - def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false) sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -136,6 +142,26 @@ class FilteredScanSuite extends DataSourceTest { "SELECT * FROM oneToTenFiltered WHERE b = 2", Seq(1).map(i => Row(i, i * 2)).toSeq) + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a IS NULL", + Seq.empty[Row]) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", + (2 to 4).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", + (6 to 10).map(i => Row(i, i * 2)).toSeq) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) @@ -162,6 +188,10 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution -- GitLab