diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 5542a521b173dab286b39c5974503fc7747b2e3a..d12778c7583df66de523e2d29606c133a4d0c179 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.serde.serdeConstants
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StringType, IntegralType}
 
 /**
@@ -312,37 +312,41 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
 
-  override def getPartitionsByFilter(
-      hive: Hive,
-      table: Table,
-      predicates: Seq[Expression]): Seq[Partition] = {
+  /**
+   * Converts catalyst expressions to the format that Hive's getPartitionsByFilter() expects, i.e.
+   * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
+   *
+   * Unsupported predicates are skipped.
+   */
+  def convertFilters(table: Table, filters: Seq[Expression]): String = {
     // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
     val varcharKeys = table.getPartitionKeys
       .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
       .map(col => col.getName).toSet
 
-    // Hive getPartitionsByFilter() takes a string that represents partition
-    // predicates like "str_key=\"value\" and int_key=1 ..."
-    val filter = predicates.flatMap { expr =>
-      expr match {
-        case op @ BinaryComparison(lhs, rhs) => {
-          lhs match {
-            case AttributeReference(_, _, _, _) => {
-              rhs.dataType match {
-                case _: IntegralType =>
-                  Some(lhs.prettyString + op.symbol + rhs.prettyString)
-                case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
-                  Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
-                case _ => None
-              }
-            }
-            case _ => None
-          }
-        }
-        case _ => None
-      }
+    filters.collect {
+      case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
+        s"${a.name} ${op.symbol} $v"
+      case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
+        s"$v ${op.symbol} ${a.name}"
+
+      case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+          if !varcharKeys.contains(a.name) =>
+        s"""${a.name} ${op.symbol} "$v""""
+      case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+          if !varcharKeys.contains(a.name) =>
+        s""""$v" ${op.symbol} ${a.name}"""
     }.mkString(" and ")
+  }
+
+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
+    // Hive getPartitionsByFilter() takes a string that represents partition
+    // predicates like "str_key=\"value\" and int_key=1 ..."
+    val filter = convertFilters(table, predicates)
 
     val partitions =
       if (filter.isEmpty) {
         getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0efcf80bd4ea722e147ccb6a2363bc7cd8e53e0a
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * A set of tests for the filter conversion logic used when pushing partition pruning into the
+ * metastore.
+ */
+class FiltersSuite extends SparkFunSuite with Logging {
+  private val shim = new Shim_v0_13
+
+  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
+  private val varCharCol = new FieldSchema()
+  varCharCol.setName("varchar")
+  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
+  testTable.setPartCols(varCharCol :: Nil)
+
+  filterTest("string filter",
+    (a("stringcol", StringType) > Literal("test")) :: Nil,
+    "stringcol > \"test\"")
+
+  filterTest("string filter backwards",
+    (Literal("test") > a("stringcol", StringType)) :: Nil,
+    "\"test\" > stringcol")
+
+  filterTest("int filter",
+    (a("intcol", IntegerType) === Literal(1)) :: Nil,
+    "intcol = 1")
+
+  filterTest("int filter backwards",
+    (Literal(1) === a("intcol", IntegerType)) :: Nil,
+    "1 = intcol")
+
+  filterTest("int and string filter",
+    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
+    "1 = intcol and \"a\" = strcol")
+
+  filterTest("skip varchar",
+    (Literal("") === a("varchar", StringType)) :: Nil,
+    "")
+
+  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
+    test(name) {
+      val converted = shim.convertFilters(testTable, filters)
+      if (converted != result) {
+        fail(
+          s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
+      }
+    }
+  }
+
+  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
+}
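
Illustration only (not part of the patch): a minimal sketch of what convertFilters produces end
to end, built the same way as FiltersSuite above. The object name, table name ("part"), and
column names (intcol, strcol, vcol) are hypothetical, and because the shim is private[client],
the sketch must live in the org.apache.spark.sql.hive.client package.

    package org.apache.spark.sql.hive.client

    import scala.collection.JavaConversions._

    import org.apache.hadoop.hive.metastore.api.FieldSchema
    import org.apache.hadoop.hive.serde.serdeConstants

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    object ConvertFiltersExample {
      def main(args: Array[String]): Unit = {
        val shim = new Shim_v0_13

        // Hypothetical table whose only declared partition key is a varchar column.
        val table = new org.apache.hadoop.hive.ql.metadata.Table("default", "part")
        val vcol = new FieldSchema()
        vcol.setName("vcol")
        vcol.setType(serdeConstants.VARCHAR_TYPE_NAME)
        table.setPartCols(vcol :: Nil)

        val intcol = AttributeReference("intcol", IntegerType)()
        val strcol = AttributeReference("strcol", StringType)()
        val varcharcol = AttributeReference("vcol", StringType)()

        // Integral literals render unquoted, string literals render quoted, and the
        // varchar predicate is dropped because varchar keys can't be pushed down.
        val converted = shim.convertFilters(table,
          Seq(intcol > Literal(1), strcol === Literal("a"), varcharcol === Literal("x")))
        println(converted) // prints: intcol > 1 and strcol = "a"
      }
    }

Note the ordering guarantee: supported predicates are emitted in their original sequence and
joined with "and", while unsupported ones (the varchar comparison here) simply fall out of the
collect, which is why the result may prune less than the full predicate set.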