diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index 410e9e51ba20843b1e8ccf921f7afb7dd9bc2cd2..d224332d8a6c9ca2c162292921f16b189db03f05 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -43,7 +43,7 @@ public class UDFXPathUtil { private XPathExpression expression = null; private String oldPath = null; - public Object eval(String xml, String path, QName qname) { + public Object eval(String xml, String path, QName qname) throws XPathExpressionException { if (xml == null || path == null || qname == null) { return null; } @@ -56,7 +56,7 @@ public class UDFXPathUtil { try { expression = xpath.compile(path); } catch (XPathExpressionException e) { - expression = null; + throw new RuntimeException("Invalid XPath '" + path + "': " + e.getMessage(), e); } oldPath = path; } @@ -66,31 +66,30 @@ public class UDFXPathUtil { } reader.set(xml); - try { return expression.evaluate(inputSource, qname); } catch (XPathExpressionException e) { - throw new RuntimeException("Invalid expression '" + oldPath + "'", e); + throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); } } - public Boolean evalBoolean(String xml, String path) { + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } - public String evalString(String xml, String path) { + public String evalString(String xml, String path) throws XPathExpressionException { return (String) eval(xml, path, XPathConstants.STRING); } - public Double evalNumber(String xml, String path) { + public Double evalNumber(String xml, String path) throws XPathExpressionException { return (Double) eval(xml, path, 
XPathConstants.NUMBER); } - public Node evalNode(String xml, String path) { + public Node evalNode(String xml, String path) throws XPathExpressionException { return (Node) eval(xml, path, XPathConstants.NODE); } - public NodeList evalNodeList(String xml, String path) { + public NodeList evalNodeList(String xml, String path) throws XPathExpressionException { return (NodeList) eval(xml, path, XPathConstants.NODESET); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c8bbbf88532dcfbdf54ee6844312015bed7d9318..54568b7445df803a0a417bebaa97725758f4ce9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -310,7 +310,15 @@ object FunctionRegistry { expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), + expression[XPathList]("xpath"), expression[XPathBoolean]("xpath_boolean"), + expression[XPathDouble]("xpath_double"), + expression[XPathDouble]("xpath_number"), + expression[XPathFloat]("xpath_float"), + expression[XPathInt]("xpath_int"), + expression[XPathLong]("xpath_long"), + expression[XPathShort]("xpath_short"), + expression[XPathString]("xpath_string"), // datetime functions expression[AddMonths]("add_months"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala deleted file mode 100644 index 2a5256c7f56fd6225232afd5fecc12e6f31bcac5..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor 
license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.xml - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, StringType} -import org.apache.spark.unsafe.types.UTF8String - - -@ExpressionDescription( - usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.", - extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue") -case class XPathBoolean(xml: Expression, path: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - - @transient private lazy val xpathUtil = new UDFXPathUtil - - // If the path is a constant, cache the path string so that we don't need to convert path - // from UTF8String to String for every row. 
- @transient lazy val pathLiteral: String = path match { - case Literal(str: UTF8String, _) => str.toString - case _ => null - } - - override def prettyName: String = "xpath_boolean" - - override def dataType: DataType = BooleanType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override def left: Expression = xml - override def right: Expression = path - - override protected def nullSafeEval(xml: Any, path: Any): Any = { - val xmlString = xml.asInstanceOf[UTF8String].toString - if (pathLiteral ne null) { - xpathUtil.evalBoolean(xmlString, pathLiteral) - } else { - xpathUtil.evalBoolean(xmlString, path.asInstanceOf[UTF8String].toString) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala new file mode 100644 index 0000000000000000000000000000000000000000..47f039e6a4cc490fc01cdd90df243fb7e2a49289 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Base class for xpath_boolean, xpath_double, xpath_int, etc. + * + * This is not the world's most efficient implementation due to type conversion, but works. + */ +abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + override def left: Expression = xml + override def right: Expression = path + + /** XPath expressions are always nullable, e.g. if the xml string is empty. */ + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!path.foldable) { + TypeCheckFailure("path should be a string literal") + } else { + super.checkInputDataTypes() + } + } + + @transient protected lazy val xpathUtil = new UDFXPathUtil + @transient protected lazy val pathString: String = path.eval().asInstanceOf[UTF8String].toString + + /** Concrete implementations need to override the following two methods. 
*/ + def xml: Expression + def path: Expression +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.", + extended = "> SELECT _FUNC_('<a><b>1</b></a>','a/b');\ntrue") +case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract { + + override def prettyName: String = "xpath_boolean" + override def dataType: DataType = BooleanType + + override def nullSafeEval(xml: Any, path: Any): Any = { + xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString) + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a short value that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3") +case class XPathShort(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_short" + override def dataType: DataType = ShortType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.shortValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns an integer value that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3") +case class XPathInt(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_int" + override def dataType: DataType = IntegerType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.intValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a long value that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3") +case class XPathLong(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_long" + override 
def dataType: DataType = LongType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.longValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a float value that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3.0") +case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_float" + override def dataType: DataType = FloatType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.floatValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a double value that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>1</b><b>2</b></a>','sum(a/b)');\n3.0") +case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_double" + override def dataType: DataType = DoubleType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.doubleValue() + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns the text contents of the first xml node that matches the xpath expression", + extended = "> SELECT _FUNC_('<a><b>b</b><c>cc</c></a>','a/c');\ncc") +// scalastyle:on line.size.limit +case class XPathString(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_string" + override def dataType: DataType = StringType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) + 
UTF8String.fromString(ret) + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a string array of values within xml nodes that match the xpath expression", + extended = "> SELECT _FUNC_('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()');\n['b1','b2','b3']") +// scalastyle:on line.size.limit +case class XPathList(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath" + override def dataType: DataType = ArrayType(StringType, containsNull = false) + + override def nullSafeEval(xml: Any, path: Any): Any = { + val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString) + if (nodeList ne null) { + val ret = new Array[UTF8String](nodeList.getLength) + var i = 0 + while (i < nodeList.getLength) { + ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue) + i += 1 + } + new GenericArrayData(ret) + } else { + null + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index a5614f83844e0aa25c25925b8e662e91e9f7af88..c4cde7091154b1c3bba138be77c9ba25e8b2efc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -43,8 +43,9 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "", STRING) == null) // wrong expression: - assert( - util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text(", STRING) == null) + intercept[RuntimeException] { + util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text(", STRING) + } } test("generic eval") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index f7c65c667efbde854af38f4b5d0fe4c8579e9758..bfa18a0919e45d4b6dc56fc8c7198ad550ea2542 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.xml import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StringType /** @@ -27,35 +26,183 @@ import org.apache.spark.sql.types.StringType */ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - private def testBoolean[T](xml: String, path: String, expected: T): Unit = { - checkEvaluation( - XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)), - expected) + /** A helper function that tests null and error behaviors for xpath expressions. 
*/ + private def testNullAndErrorBehavior[T <: AnyRef](testExpr: (String, String, T) => Unit): Unit = { + // null input should lead to null output + testExpr("<a><b>b1</b><b id='b_2'>b2</b></a>", null, null.asInstanceOf[T]) + testExpr(null, "a", null.asInstanceOf[T]) + testExpr(null, null, null.asInstanceOf[T]) + + // Empty input should also lead to null output + testExpr("", "a", null.asInstanceOf[T]) + testExpr("<a></a>", "", null.asInstanceOf[T]) + testExpr("", "", null.asInstanceOf[T]) + + // Test error message for invalid XML document + val e1 = intercept[RuntimeException] { testExpr("<a>/a>", "a", null.asInstanceOf[T]) } + assert(e1.getCause.getMessage.contains("Invalid XML document") && + e1.getCause.getMessage.contains("<a>/a>")) + + // Test error message for invalid xpath + val e2 = intercept[RuntimeException] { testExpr("<a></a>", "!#$", null.asInstanceOf[T]) } + assert(e2.getCause.getMessage.contains("Invalid XPath") && + e2.getCause.getMessage.contains("!#$")) } test("xpath_boolean") { - testBoolean("<a><b>b</b></a>", "a/b", true) - testBoolean("<a><b>b</b></a>", "a/c", false) - testBoolean("<a><b>b</b></a>", "a/b = \"b\"", true) - testBoolean("<a><b>b</b></a>", "a/b = \"c\"", false) - testBoolean("<a><b>10</b></a>", "a/b < 10", false) - testBoolean("<a><b>10</b></a>", "a/b = 10", true) + def testExpr[T](xml: String, path: String, expected: java.lang.Boolean): Unit = { + checkEvaluation( + XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a><b>b</b></a>", "a/b", true) + testExpr("<a><b>b</b></a>", "a/c", false) + testExpr("<a><b>b</b></a>", "a/b = \"b\"", true) + testExpr("<a><b>b</b></a>", "a/b = \"c\"", false) + testExpr("<a><b>10</b></a>", "a/b < 10", false) + testExpr("<a><b>10</b></a>", "a/b = 10", true) - // null input - testBoolean(null, null, null) - testBoolean(null, "a", null) - testBoolean("<a><b>10</b></a>", null, null) + testNullAndErrorBehavior(testExpr) + } - // exception 
handling for invalid input - intercept[Exception] { - testBoolean("<a>/a>", "a", null) + test("xpath_short") { + def testExpr[T](xml: String, path: String, expected: java.lang.Short): Unit = { + checkEvaluation( + XPathShort(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) } + + testExpr("<a>this is not a number</a>", "a", 0.toShort) + testExpr("<a>try a boolean</a>", "a = 10", 0.toShort) + testExpr( + "<a><b class=\"odd\">10000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>", + "sum(a/b[@class=\"odd\"])", + 10004.toShort) + + testNullAndErrorBehavior(testExpr) } - test("xpath_boolean path cache invalidation") { - // This is a test to ensure the expression is not reusing the path for different strings - val expr = XPathBoolean(Literal("<a><b>b</b></a>"), 'path.string.at(0)) - checkEvaluation(expr, true, create_row("a/b")) - checkEvaluation(expr, false, create_row("a/c")) + test("xpath_int") { + def testExpr[T](xml: String, path: String, expected: java.lang.Integer): Unit = { + checkEvaluation( + XPathInt(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a>this is not a number</a>", "a", 0) + testExpr("<a>try a boolean</a>", "a = 10", 0) + testExpr( + "<a><b class=\"odd\">100000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>", + "sum(a/b[@class=\"odd\"])", + 100004) + + testNullAndErrorBehavior(testExpr) + } + + test("xpath_long") { + def testExpr[T](xml: String, path: String, expected: java.lang.Long): Unit = { + checkEvaluation( + XPathLong(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a>this is not a number</a>", "a", 0L) + testExpr("<a>try a boolean</a>", "a = 10", 0L) + testExpr( + "<a><b class=\"odd\">9000000000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>", + "sum(a/b[@class=\"odd\"])", + 9000000004L) + + testNullAndErrorBehavior(testExpr) + } + + test("xpath_float") { + def 
testExpr[T](xml: String, path: String, expected: java.lang.Float): Unit = { + checkEvaluation( + XPathFloat(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a>this is not a number</a>", "a", Float.NaN) + testExpr("<a>try a boolean</a>", "a = 10", 0.0F) + testExpr("<a><b class=\"odd\">1</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>", + "sum(a/b[@class=\"odd\"])", + 5.0F) + + testNullAndErrorBehavior(testExpr) + } + + test("xpath_double") { + def testExpr[T](xml: String, path: String, expected: java.lang.Double): Unit = { + checkEvaluation( + XPathDouble(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a>this is not a number</a>", "a", Double.NaN) + testExpr("<a>try a boolean</a>", "a = 10", 0.0) + testExpr("<a><b class=\"odd\">1</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>", + "sum(a/b[@class=\"odd\"])", + 5.0) + + testNullAndErrorBehavior(testExpr) + } + + test("xpath_string") { + def testExpr[T](xml: String, path: String, expected: String): Unit = { + checkEvaluation( + XPathString(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a><b>bb</b><c>cc</c></a>", "a", "bbcc") + testExpr("<a><b>bb</b><c>cc</c></a>", "a/b", "bb") + testExpr("<a><b>bb</b><c>cc</c></a>", "a/c", "cc") + testExpr("<a><b>bb</b><c>cc</c></a>", "a/d", "") + testExpr("<a><b>b1</b><b>b2</b></a>", "//b", "b1") + testExpr("<a><b>b1</b><b>b2</b></a>", "a/b[1]", "b1") + testExpr("<a><b>b1</b><b id='b_2'>b2</b></a>", "a/b[@id='b_2']", "b2") + + testNullAndErrorBehavior(testExpr) + } + + test("xpath") { + def testExpr[T](xml: String, path: String, expected: Seq[String]): Unit = { + checkEvaluation( + XPathList(Literal.create(xml, StringType), Literal.create(path, StringType)), + expected) + } + + testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text()", Seq.empty[String]) + 
testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/*/text()", + Seq("b1", "b2", "b3", "c1", "c2")) + testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b/text()", + Seq("b1", "b2", "b3")) + testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/c/text()", Seq("c1", "c2")) + testExpr("<a><b class='bb'>b1</b><b>b2</b><b>b3</b><c class='bb'>c1</c><c>c2</c></a>", + "a/*[@class='bb']/text()", Seq("b1", "c1")) + + testNullAndErrorBehavior(testExpr) + } + + test("accept only literal path") { + def testExpr(exprCtor: (Expression, Expression) => Expression): Unit = { + // Validate that literal (technically this is foldable) paths are supported + val litPath = exprCtor(Literal("abcd"), Concat(Literal("/") :: Literal("/") :: Nil)) + assert(litPath.checkInputDataTypes().isSuccess) + + // Validate that non-foldable paths are not supported. + val nonLitPath = exprCtor(Literal("abcd"), NonFoldableLiteral("/")) + assert(nonLitPath.checkInputDataTypes().isFailure) + } + + testExpr(XPathBoolean) + testExpr(XPathShort) + testExpr(XPathInt) + testExpr(XPathLong) + testExpr(XPathFloat) + testExpr(XPathDouble) + testExpr(XPathString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1d33e7970be8e0a2ed2de94b5e5bff3d4a320fb8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/** + * End-to-end tests for xpath expressions. + */ +class XPathFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("xpath_boolean") { + val df = Seq("<a><b>b</b></a>").toDF("xml") + checkAnswer(df.selectExpr("xpath_boolean(xml, 'a/b')"), Row(true)) + } + + test("xpath_short, xpath_int, xpath_long") { + val df = Seq("<a><b>1</b><b>2</b></a>").toDF("xml") + checkAnswer( + df.selectExpr( + "xpath_short(xml, 'sum(a/b)')", + "xpath_int(xml, 'sum(a/b)')", + "xpath_long(xml, 'sum(a/b)')"), + Row(3.toShort, 3, 3L)) + } + + test("xpath_float, xpath_double, xpath_number") { + val df = Seq("<a><b>1.0</b><b>2.1</b></a>").toDF("xml") + checkAnswer( + df.selectExpr( + "xpath_float(xml, 'sum(a/b)')", + "xpath_double(xml, 'sum(a/b)')", + "xpath_number(xml, 'sum(a/b)')"), + Row(3.1.toFloat, 3.1, 3.1)) + } + + test("xpath_string") { + val df = Seq("<a><b>b</b><c>cc</c></a>").toDF("xml") + checkAnswer(df.selectExpr("xpath_string(xml, 'a/c')"), Row("cc")) + } + + test("xpath") { + val df = Seq("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>").toDF("xml") + checkAnswer(df.selectExpr("xpath(xml, 'a/*/text()')"), Row(Seq("b1", "b2", "b3", "c1", "c2"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala deleted file mode 100644 index 532d48cc265ac71ef3ef9c24efcd9b2749c360ff..0000000000000000000000000000000000000000 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.SharedSQLContext - -/** - * End-to-end tests for XML expressions. - */ -class XmlFunctionsSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - test("xpath_boolean") { - val df = Seq("<a><b>b</b></a>" -> "a/b").toDF("xml", "path") - checkAnswer(df.selectExpr("xpath_boolean(xml, path)"), Row(true)) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9c7f461362d8432e2692f39076d88b3750a32c8c..6f36abc4db0ed66df846dcd53a289bfb8f0a9d0a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -236,8 +236,6 @@ private[sql] class HiveSessionCatalog( // str_to_map, windowingtablefunction. 
private val hiveFunctions = Seq( "hash", "java_method", "histogram_numeric", - "percentile", "percentile_approx", "reflect", "str_to_map", - "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", - "xpath_number", "xpath_short", "xpath_string" + "percentile", "percentile_approx", "reflect", "str_to_map" ) }