diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ffb206af0e906346c7eee09a0b8105e7defe19dc..6d2b95e83a4440a62825deb3b062b02d29bef280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -213,7 +213,7 @@ private[sql] trait SQLTestUtils */ protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema - val withoutFilters = df.queryExecution.sparkPlan transform { + val withoutFilters = df.queryExecution.sparkPlan.transform { case FilterExec(_, child) => child } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index c025c12a90a2d286ab4650591d8b5752678d4c11..c463bc839480883b802422c8e832c9466c3d3073 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder -import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.spark.internal.Logging import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ /** * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. @@ -56,29 +55,35 @@ import org.apache.spark.sql.sources._ * known to be convertible. */ private[orc] object OrcFilters extends Logging { - def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + def createFilter(schema: StructType, filters: Array[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + // First, tries to convert each filter individually to see whether it's convertible, and then // collect all convertible ones to build the final `SearchArgument`. val convertibleFilters = for { filter <- filters - _ <- buildSearchArgument(filter, SearchArgumentFactory.newBuilder()) + _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) } yield filter for { // Combines all convertible filters using `And` to produce a single conjunction conjunction <- convertibleFilters.reduceOption(And) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() } - private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + private def buildSearchArgument( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder): Option[Builder] = { def newBuilder = SearchArgumentFactory.newBuilder() - def isSearchableLiteral(value: Any): Boolean = value match { - // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. - case _: String | _: Long | _: Double | _: Byte | _: Short | _: Integer | _: Float => true - case _: DateWritable | _: HiveDecimal | _: HiveChar | _: HiveVarchar => true + def isSearchableType(dataType: DataType): Boolean = dataType match { + // Only the values in the Spark types below can be recognized by + // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case ByteType | ShortType | FloatType | DoubleType => true + case IntegerType | LongType | StringType => true case _ => false } @@ -92,55 +97,55 @@ private[orc] object OrcFilters extends Logging { // Pushing one side of AND down is only safe to do at the top level. // You can see ParquetRelation's initializeLocalJobFunc method as an example. for { - _ <- buildSearchArgument(left, newBuilder) - _ <- buildSearchArgument(right, newBuilder) - lhs <- buildSearchArgument(left, builder.startAnd()) - rhs <- buildSearchArgument(right, lhs) + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) } yield rhs.end() case Or(left, right) => for { - _ <- buildSearchArgument(left, newBuilder) - _ <- buildSearchArgument(right, newBuilder) - lhs <- buildSearchArgument(left, builder.startOr()) - rhs <- buildSearchArgument(right, lhs) + _ <- buildSearchArgument(dataTypeMap, left, newBuilder) + _ <- buildSearchArgument(dataTypeMap, right, newBuilder) + lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) + rhs <- buildSearchArgument(dataTypeMap, right, lhs) } yield rhs.end() case Not(child) => for { - _ <- buildSearchArgument(child, newBuilder) - negate <- buildSearchArgument(child, builder.startNot()) + _ <- buildSearchArgument(dataTypeMap, child, newBuilder) + negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) } yield negate.end() // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). - case EqualTo(attribute, value) if isSearchableLiteral(value) => + case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) if isSearchableLiteral(value) => + case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) if isSearchableLiteral(value) => + case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startNot().lessThan(attribute, value).end()) - case IsNull(attribute) => + case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().isNull(attribute).end()) - case IsNotNull(attribute) => + case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) if values.forall(isSearchableLiteral) => + case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 89d258e844280859002ecaaeec2fd7dca99b80fa..fed31503043e9fb48e82bd5b51b108e3645d3941 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -118,7 +118,7 @@ private[sql] class DefaultSource hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates - OrcFilters.createFilter(filters.toArray).foreach { f => + OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => hadoopConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } @@ -272,14 +272,6 @@ private[orc] case class OrcTableScan( val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) val conf = job.getConfiguration - // Tries to push down filters if ORC filter push-down is enabled - if (sparkSession.sessionState.conf.orcFilterPushDown) { - OrcFilters.createFilter(filters).foreach { f => - conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) - conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) - } - } - // Figure out the actual schema from the ORC source (without partition columns) so that we // can pick the correct ordinals. Note that this assumes that all files have the same schema. val orcFormat = new DefaultSource @@ -287,6 +279,15 @@ private[orc] case class OrcTableScan( orcFormat .inferSchema(sparkSession, Map.empty, inputPaths) .getOrElse(sys.error("Failed to read schema from target ORC files.")) + + // Tries to push down filters if ORC filter push-down is enabled + if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(dataSchema, filters).foreach { f => + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + // Sets requested columns OrcRelation.setRequiredColumns(conf, dataSchema, StructType.fromAttributes(attributes)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index ddabab3a14b514a281c599df66386740d59bc056..8c027f9935f879ebf03144fefe06e2aa6b03a9ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ @@ -54,7 +55,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) } @@ -78,10 +79,28 @@ class OrcFilterSuite extends QueryTest with OrcTest { checkFilterPredicate(df, predicate, checkLogicalOperator) } - test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) - } + private def checkNoFilterPredicate + (predicate: Predicate) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") } test("filter pushdown - integer") { @@ -189,16 +208,6 @@ class OrcFilterSuite extends QueryTest with OrcTest { } } - test("filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) - } - - withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => - checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) - } - } - test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked @@ -238,4 +247,40 @@ class OrcFilterSuite extends QueryTest with OrcTest { ) } } + + test("no filter pushdown - non-supported types") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + // ArrayType + withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => + checkNoFilterPredicate('_1.isNull) + } + // DecimalType + withOrcDataFrame((1 to 4).map(i => Tuple1(BigDecimal.valueOf(i)))) { implicit df => + checkNoFilterPredicate('_1 <= BigDecimal.valueOf(4)) + } + // BinaryType + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkNoFilterPredicate('_1 <=> 1.b) + } + // BooleanType + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkNoFilterPredicate('_1 === true) + } + // TimestampType + val stringTimestamp = "2015-08-20 15:57:00" + withOrcDataFrame(Seq(Tuple1(Timestamp.valueOf(stringTimestamp)))) { implicit df => + checkNoFilterPredicate('_1 <=> Timestamp.valueOf(stringTimestamp)) + } + // DateType + val stringDate = "2015-01-01" + withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => + checkNoFilterPredicate('_1 === Date.valueOf(stringDate)) + } + // MapType + withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => + checkNoFilterPredicate('_1.isNotNull) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index fb678be234a2760ff544c6268011a61d7d0ebdbe..aa9c1189db3b1d6abd2babbffc21d635ed88a2be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -443,4 +443,18 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-14962 Produce correct results on array type with isnotnull") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val data = (0 until 10).map(i => Tuple1(Array(i))) + withOrcFile(data) { file => + val actual = sqlContext + .read + .orc(file) + .where("_1 is not null") + val expected = data.toDF() + checkAnswer(actual, expected) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index bdd3428a897427ae0f1bbfc1f52c640e450ccbd4..96a7364437c782aaf96685807b16288b69ddf591 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ case class OrcData(intField: Int, stringField: String) @@ -182,12 +183,16 @@ class OrcSourceSuite extends OrcSuite { test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) assertResult( """leaf-0 = (LESS_THAN a 10) |expr = leaf-0 """.stripMargin.trim ) { - OrcFilters.createFilter(Array( + OrcFilters.createFilter(schema, Array( LessThan("a", 10), StringContains("b", "prefix") )).get.toString @@ -199,7 +204,7 @@ class OrcSourceSuite extends OrcSuite { |expr = leaf-0 """.stripMargin.trim ) { - OrcFilters.createFilter(Array( + OrcFilters.createFilter(schema, Array( LessThan("a", 10), Not(And( GreaterThan("a", 1),