diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 7abb9f06b131078816124d0351d36d8a2188d3df..449a303b59eed504126ec99a4e0d2df9f042431c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
+import scala.util.Try
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.conf.HiveConf
@@ -46,6 +47,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTableParti
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
@@ -589,18 +591,67 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
         col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
       .map(col => col.getName).toSet
 
-    filters.collect {
-      case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
-        s"${a.name} ${op.symbol} $v"
-      case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
-        s"$v ${op.symbol} ${a.name}"
-      case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+    object ExtractableLiteral {
+      def unapply(expr: Expression): Option[String] = expr match {
+        case Literal(value, _: IntegralType) => Some(value.toString)
+        case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
+        case _ => None
+      }
+    }
+
+    object ExtractableLiterals {
+      def unapply(exprs: Seq[Expression]): Option[Seq[String]] = {
+        exprs.map(ExtractableLiteral.unapply).foldLeft(Option(Seq.empty[String])) {
+          case (Some(accum), Some(value)) => Some(accum :+ value)
+          case _ => None
+        }
+      }
+    }
+
+    object ExtractableValues {
+      private lazy val valueToLiteralString: PartialFunction[Any, String] = {
+        case value: Byte => value.toString
+        case value: Short => value.toString
+        case value: Int => value.toString
+        case value: Long => value.toString
+        case value: UTF8String => quoteStringLiteral(value.toString)
+      }
+
+      def unapply(values: Set[Any]): Option[Seq[String]] = {
+        values.toSeq.foldLeft(Option(Seq.empty[String])) {
+          case (Some(accum), value) if valueToLiteralString.isDefinedAt(value) =>
+            Some(accum :+ valueToLiteralString(value))
+          case _ => None
+        }
+      }
+    }
+
+    def convertInToOr(a: Attribute, values: Seq[String]): String = {
+      values.map(value => s"${a.name} = $value").mkString("(", " or ", ")")
+    }
+
+    lazy val convert: PartialFunction[Expression, String] = {
+      case In(a: Attribute, ExtractableLiterals(values))
+          if !varcharKeys.contains(a.name) && values.nonEmpty =>
+        convertInToOr(a, values)
+      case InSet(a: Attribute, ExtractableValues(values))
+          if !varcharKeys.contains(a.name) && values.nonEmpty =>
+        convertInToOr(a, values)
+      case op @ BinaryComparison(a: Attribute, ExtractableLiteral(value))
           if !varcharKeys.contains(a.name) =>
-        s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}"""
-      case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+        s"${a.name} ${op.symbol} $value"
+      case op @ BinaryComparison(ExtractableLiteral(value), a: Attribute)
           if !varcharKeys.contains(a.name) =>
-        s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}"""
-    }.mkString(" and ")
+        s"$value ${op.symbol} ${a.name}"
+      case op @ And(expr1, expr2)
+          if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) =>
+        (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")")
+      case op @ Or(expr1, expr2)
+          if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) =>
+        s"(${convert(expr1)} or ${convert(expr2)})"
+    }
+
+    filters.map(convert.lift).collect { case Some(filterString) => filterString }.mkString(" and ")
   }
 
   private def quoteStringLiteral(str: String): String = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala
index e85ea5a59427de77967a4ea5a7e4d29572de1ef1..ae804ce7c7b07fa472ceed20b44fdb10c2437983 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala
@@ -25,9 +25,7 @@ import org.apache.hadoop.util.VersionInfo
 import org.apache.spark.SparkConf
 import org.apache.spark.util.Utils
 
-private[client] class HiveClientBuilder {
-  private val sparkConf = new SparkConf()
-
+private[client] object HiveClientBuilder {
   // In order to speed up test execution during development or in Jenkins, you can specify the path
   // of an existing Ivy cache:
   private val ivyPath: Option[String] = {
@@ -52,7 +50,7 @@ private[client] class HiveClientBuilder {
     IsolatedClientLoader.forVersion(
       hiveMetastoreVersion = version,
       hadoopVersion = VersionInfo.getVersion,
-      sparkConf = sparkConf,
+      sparkConf = new SparkConf(),
       hadoopConf = hadoopConf,
       config = buildConf(extraConf),
       ivyPath = ivyPath).createClient()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
index 4790331168bd284f72b95a79a7786e9e1157127c..6a2c23a015529497a07384ebf615eb6254a5efc5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
@@ -19,21 +19,25 @@ package org.apache.spark.sql.hive.client
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.conf.HiveConf
+import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EmptyRow, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, LessThan, LessThanOrEqual, Like, Literal, Or}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ByteType, IntegerType, StringType}
 
-class HiveClientSuite extends SparkFunSuite {
-  private val clientBuilder = new HiveClientBuilder
+// TODO: Refactor this to `HivePartitionFilteringSuite`
+class HiveClientSuite(version: String)
+    extends HiveVersionSuite(version) with BeforeAndAfterAll {
+  import CatalystSqlParser._
 
   private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
 
-  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
-    val testPartitionCount = 5
+  private val testPartitionCount = 3 * 24 * 4
 
+  private def init(tryDirectSql: Boolean): HiveClient = {
     val storageFormat = CatalogStorageFormat(
       locationUri = None,
       inputFormat = None,
@@ -43,19 +47,214 @@ class HiveClientSuite extends SparkFunSuite {
       properties = Map.empty)
 
     val hadoopConf = new Configuration()
-    hadoopConf.setBoolean(tryDirectSqlKey, false)
-    val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf)
-    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)")
+    hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
+    val client = buildClient(hadoopConf)
+    client
+      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+
+    val partitions =
+      for {
+        ds <- 20170101 to 20170103
+        h <- 0 to 23
+        chunk <- Seq("aa", "ab", "ba", "bb")
+      } yield CatalogTablePartition(Map(
+        "ds" -> ds.toString,
+        "h" -> h.toString,
+        "chunk" -> chunk
+      ), storageFormat)
+    assert(partitions.size == testPartitionCount)
 
-    val partitions = (1 to testPartitionCount).map { part =>
-      CatalogTablePartition(Map("part" -> part.toString), storageFormat)
-    }
     client.createPartitions(
       "default", "test", partitions, ignoreIfExists = false)
+    client
+  }
 
+  override def beforeAll() {
+    client = init(true)
+  }
+
+  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
+    val client = init(false)
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3))))
+      Seq(parseExpression("ds=20170101")))
 
     assert(filteredPartitions.size == testPartitionCount)
   }
+
+  test("getPartitionsByFilter: ds=20170101") {
+    testMetastorePartitionFiltering(
+      "ds=20170101",
+      20170101 to 20170101,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: ds=(20170101 + 1) and h=0") {
+    // Should return all partitions where h=0 because getPartitionsByFilter does not support
+    // comparisons to non-literal values
+    testMetastorePartitionFiltering(
+      "ds=(20170101 + 1) and h=0",
+      20170101 to 20170103,
+      0 to 0,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: chunk='aa'") {
+    testMetastorePartitionFiltering(
+      "chunk='aa'",
+      20170101 to 20170103,
+      0 to 23,
+      "aa" :: Nil)
+  }
+
+  test("getPartitionsByFilter: 20170101=ds") {
+    testMetastorePartitionFiltering(
+      "20170101=ds",
+      20170101 to 20170101,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: ds=20170101 and h=10") {
+    testMetastorePartitionFiltering(
+      "ds=20170101 and h=10",
+      20170101 to 20170101,
+      10 to 10,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
+    testMetastorePartitionFiltering(
+      "ds=20170101 or ds=20170102",
+      20170101 to 20170102,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
+    testMetastorePartitionFiltering(
+      "ds in (20170102, 20170103)",
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
+    testMetastorePartitionFiltering(
+      "ds in (20170102, 20170103)",
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
+        case expr @ In(v, list) if expr.inSetConvertible =>
+          InSet(v, Set() ++ list.map(_.eval(EmptyRow)))
+      })
+  }
+
+  test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
+    testMetastorePartitionFiltering(
+      "chunk in ('ab', 'ba')",
+      20170101 to 20170103,
+      0 to 23,
+      "ab" :: "ba" :: Nil)
+  }
+
+  test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
expression)") { + testMetastorePartitionFiltering( + "chunk in ('ab', 'ba')", + 20170101 to 20170103, + 0 to 23, + "ab" :: "ba" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, Set() ++ list.map(_.eval(EmptyRow))) + }) + } + + test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering( + "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", + day1 :: day2 :: Nil) + } + + test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + // Day 2 should include all hours because we can't build a filter for h<(7+1) + val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering( + "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", + day1 :: day2 :: Nil) + } + + test("getPartitionsByFilter: " + + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { + val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) + val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) + testMetastorePartitionFiltering( + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + day1 :: day2 :: Nil) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedDs: Seq[Int], + expectedH: Seq[Int], + expectedChunks: Seq[String]): Unit = { + testMetastorePartitionFiltering( + filterString, + (expectedDs, expectedH, expectedChunks) :: Nil, + identity) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedDs: Seq[Int], + expectedH: Seq[Int], + expectedChunks: Seq[String], + transform: Expression => Expression): Unit = { + testMetastorePartitionFiltering( + filterString, + (expectedDs, expectedH, expectedChunks) :: Nil, + identity) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { + testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + } + + private def testMetastorePartitionFiltering( + filterString: String, + expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], + transform: Expression => Expression): Unit = { + val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), + Seq( + transform(parseExpression(filterString)) + )) + + val expectedPartitionCount = expectedPartitionCubes.map { + case (expectedDs, expectedH, expectedChunks) => + expectedDs.size * expectedH.size * expectedChunks.size + }.sum + + val expectedPartitions = expectedPartitionCubes.map { + case (expectedDs, expectedH, expectedChunks) => + for { + ds <- expectedDs + h <- expectedH + chunk <- expectedChunks + } yield Set( + "ds" -> ds.toString, + "h" -> h.toString, + "chunk" -> chunk + ) + }.reduce(_ ++ _) + + val actualFilteredPartitionCount = filteredPartitions.size + + assert(actualFilteredPartitionCount == expectedPartitionCount, + s"Expected $expectedPartitionCount partitions but got $actualFilteredPartitionCount") + assert(filteredPartitions.map(_.spec.toSet).toSet == expectedPartitions.toSet) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala new file mode 100644 index 
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import scala.collection.immutable.IndexedSeq
+
+import org.scalatest.Suite
+
+class HiveClientSuites extends Suite with HiveClientVersions {
+  override def nestedSuites: IndexedSeq[Suite] = {
+    // Hive 0.12 does not provide the partition filtering API we call
+    versions.filterNot(_ == "0.12").map(new HiveClientSuite(_))
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2e7dfde8b2fa557c098898a7894f6660352fa249
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import scala.collection.immutable.IndexedSeq
+
+import org.apache.spark.SparkFunSuite
+
+private[client] trait HiveClientVersions {
+  protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1")
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..986c6675cbb63f838bbf2b75d6920bc5de1ad4d5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.Tag
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.hive.HiveUtils
+
+private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite {
+  protected var client: HiveClient = null
+
+  protected def buildClient(hadoopConf: Configuration): HiveClient = {
+    // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and
+    // hive.metastore.schema.verification from false to true since 2.0
+    // For details, see the JIRA HIVE-6113 and HIVE-12463
+    if (version == "2.0" || version == "2.1") {
+      hadoopConf.set("datanucleus.schema.autoCreateAll", "true")
+      hadoopConf.set("hive.metastore.schema.verification", "false")
+    }
+    HiveClientBuilder
+      .buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf))
+  }
+
+  override def suiteName: String = s"${super.suiteName}($version)"
+
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Unit): Unit = {
+    super.test(s"$version: $testName", testTags: _*)(testFun)
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index f109843f5be20917a2ca5d803bbcc451df61b922..82fbdd645ebe00bedadb8d1425ee0b0ceb67b3b3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -47,11 +47,11 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils}
  * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality
  * is not fully tested.
  */
+// TODO: Refactor this to `HiveClientSuite` and make it a subclass of `HiveVersionSuite`
 @ExtendedHiveTest
 class VersionsSuite extends SparkFunSuite with Logging {
 
-  private val clientBuilder = new HiveClientBuilder
-  import clientBuilder.buildClient
+  import HiveClientBuilder.buildClient
 
   /**
   * Creates a temporary directory, which is then passed to `f` and will be deleted after `f`
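Note on the HiveShim change above: the sketch below is a minimal, self-contained illustration of the conversion strategy the new `convert` partial function uses (IN becomes a disjunction of equalities, AND is pushed down if either side converts, OR only if both sides convert). The tiny `Expr` ADT and the names `Attr`, `IntLit`, `StrLit` and `PartitionFilterSketch` are invented for illustration; this is not Catalyst's `Expression` hierarchy and not the code shipped in the patch.

// Standalone sketch, for illustration only: a toy expression tree standing in for
// Catalyst expressions, and a partial function that renders pushdown-able predicates
// into a Hive metastore filter string.
object PartitionFilterSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class IntLit(value: Long) extends Expr
  case class StrLit(value: String) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr
  case class In(attr: Attr, values: Seq[Expr]) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr

  // Only integral and string literals can be rendered into a filter string.
  private def literal(e: Expr): Option[String] = e match {
    case IntLit(v) => Some(v.toString)
    case StrLit(v) => Some("\"" + v + "\"")
    case _ => None
  }

  // IN becomes a disjunction of equalities, mirroring convertInToOr in the patch.
  private def inToOr(a: Attr, values: Seq[String]): String =
    values.map(v => s"${a.name} = $v").mkString("(", " or ", ")")

  // Partial: undefined for predicates the metastore cannot evaluate, which are skipped.
  lazy val convert: PartialFunction[Expr, String] = {
    case In(a, vs) if vs.nonEmpty && vs.forall(literal(_).isDefined) =>
      inToOr(a, vs.flatMap(literal(_).toList))
    case EqualTo(a: Attr, v) if literal(v).isDefined =>
      s"${a.name} = ${literal(v).get}"
    case EqualTo(v, a: Attr) if literal(v).isDefined =>
      s"${literal(v).get} = ${a.name}"
    // AND survives if either side converts; an unconvertible side is dropped, which only
    // widens (never narrows) the set of partitions fetched from the metastore.
    case And(l, r) if convert.isDefinedAt(l) || convert.isDefinedAt(r) =>
      Seq(convert.lift(l), convert.lift(r)).flatten.mkString("(", " and ", ")")
    // OR is only safe to push down if both sides convert.
    case Or(l, r) if convert.isDefinedAt(l) && convert.isDefinedAt(r) =>
      s"(${convert(l)} or ${convert(r)})"
  }

  def main(args: Array[String]): Unit = {
    val filter = And(
      In(Attr("chunk"), Seq(StrLit("ab"), StrLit("ba"))),
      Or(
        And(EqualTo(Attr("ds"), IntLit(20170101)), EqualTo(Attr("h"), IntLit(8))),
        EqualTo(Attr("ds"), IntLit(20170102))))
    // Prints: ((chunk = "ab" or chunk = "ba") and ((ds = 20170101 and h = 8) or ds = 20170102))
    println(convert.lift(filter).getOrElse("<not convertible>"))
  }
}

The printed string has the same shape as the filter strings the shim hands to Hive's getPartitionsByFilter, e.g. the composite predicate exercised by the "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))" test above.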