diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0bde48ce57c86e8722c082d62cbd780ee5b175eb..3f9227a8ae002267fded2bf95fb136633866eca9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -301,6 +302,7 @@ object FunctionRegistry { expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), + expression[XPathBoolean]("xpath_boolean"), // datetime functions expression[AddMonths]("add_months"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a5256c7f56fd6225232afd5fecc12e6f31bcac5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.", + extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue") +case class XPathBoolean(xml: Expression, path: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + + @transient private lazy val xpathUtil = new UDFXPathUtil + + // If the path is a constant, cache the path string so that we don't need to convert path + // from UTF8String to String for every row. + @transient lazy val pathLiteral: String = path match { + case Literal(str: UTF8String, _) => str.toString + case _ => null + } + + override def prettyName: String = "xpath_boolean" + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def left: Expression = xml + override def right: Expression = path + + override protected def nullSafeEval(xml: Any, path: Any): Any = { + val xmlString = xml.asInstanceOf[UTF8String].toString + if (pathLiteral ne null) { + xpathUtil.evalBoolean(xmlString, pathLiteral) + } else { + xpathUtil.evalBoolean(xmlString, path.asInstanceOf[UTF8String].toString) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f7c65c667efbde854af38f4b5d0fe4c8579e9758 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} +import org.apache.spark.sql.types.StringType + +/** + * Test suite for various xpath functions. + */ +class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def testBoolean[T](xml: String, path: String, expected: T): Unit = { + checkEvaluation( + XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + test("xpath_boolean") { + testBoolean("<a><b>b</b></a>", "a/b", true) + testBoolean("<a><b>b</b></a>", "a/c", false) + testBoolean("<a><b>b</b></a>", "a/b = \"b\"", true) + testBoolean("<a><b>b</b></a>", "a/b = \"c\"", false) + testBoolean("<a><b>10</b></a>", "a/b < 10", false) + testBoolean("<a><b>10</b></a>", "a/b = 10", true) + + // null input + testBoolean(null, null, null) + testBoolean(null, "a", null) + testBoolean("<a><b>10</b></a>", null, null) + + // exception handling for invalid input + intercept[Exception] { + testBoolean("<a>/a>", "a", null) + } + } + + test("xpath_boolean path cache invalidation") { + // This is a test to ensure the expression is not reusing the path for different strings + val expr = XPathBoolean(Literal("<a><b>b</b></a>"), 'path.string.at(0)) + checkEvaluation(expr, true, create_row("a/b")) + checkEvaluation(expr, false, create_row("a/c")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..532d48cc265ac71ef3ef9c24efcd9b2749c360ff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/** + * End-to-end tests for XML expressions. + */ +class XmlFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("xpath_boolean") { + val df = Seq("<a><b>b</b></a>" -> "a/b").toDF("xml", "path") + checkAnswer(df.selectExpr("xpath_boolean(xml, path)"), Row(true)) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 2589b9d4a0284648b5e9081196a10fefe661f41c..fa560a044b42a332a037a534ca0126c95986a51f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -241,7 +241,7 @@ private[sql] class HiveSessionCatalog( "elt", "hash", "java_method", "histogram_numeric", "map_keys", "map_values", "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", - "xpath", "xpath_boolean", "xpath_double", "xpath_float", "xpath_int", "xpath_long", + "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", "xpath_number", "xpath_short", "xpath_string", // table generating function