diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 87f58e5f1aa37e71a186942e8d816426ed31fb3b..dd8e5c6da08c9dd6e362d6c9332ee07dee53487e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -569,13 +569,24 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { s"$v ${op.symbol} ${a.name}" case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") } + private def quoteStringLiteral(str: String): String = { + if (!str.contains("\"")) { + s""""$str"""" + } else if (!str.contains("'")) { + s"""'$str'""" + } else { + throw new UnsupportedOperationException( + """Partition filter cannot have both `"` and `'` characters""") + } + } + override def getPartitionsByFilter( hive: Hive, table: Table, @@ -584,6 +595,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // Hive getPartitionsByFilter() takes a string that represents partition // predicates like "str_key=\"value\" and int_key=1 ..." val filter = convertFilters(table, predicates) + val partitions = if (filter.isEmpty) { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index cd96c85f3e209d5c4990cc3ca8151b6d151602f5..031c1a5ec0ec30657389700faca16093431f62cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -65,6 +65,11 @@ class FiltersSuite extends SparkFunSuite with Logging { (Literal("") === a("varchar", StringType)) :: Nil, "") + filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning", + (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) :: + (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, + """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { val converted = shim.convertFilters(testTable, filters) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index e607af67f93e59a29243f77edcdf6a1e37b06389..1619115007441b1aea0c0327674d5a44ca4872b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2015,4 +2015,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue()) attempt.isSuccess && attempt.get == 0 } + + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { + withTable("spark_19912") { + Seq( + (1, "p1", "q1"), + (2, "'", "q2"), + (3, "\"", "q3"), + (4, "p1\" and q=\"q1", "q4") + ).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("spark_19912") + + val table = spark.table("spark_19912") + checkAnswer(table.filter($"p" === "'").select($"a"), Row(2)) + checkAnswer(table.filter($"p" === "\"").select($"a"), Row(3)) + checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) + } + } }