From 1b499993ad185b04dd5065facb565cbe7e249521 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@databricks.com> Date: Tue, 9 Jun 2015 16:24:38 +0800 Subject: [PATCH] [SPARK-7886] Add built-in expressions to FunctionRegistry. This patch switches to using FunctionRegistry for built-in expressions. It is based on #6463, but with some work to simplify it along with unit tests. TODOs for future pull requests: - Use static registration so we don't need to register all functions every time we start a new SQLContext - Switch to using this in HiveContext Author: Reynold Xin <rxin@databricks.com> Author: Santiago M. Mola <santi@mola.io> Closes #6710 from rxin/udf-registry and squashes the following commits: 6930822 [Reynold Xin] Fixed Python test. b802c9a [Reynold Xin] Made UDF case insensitive. e60d815 [Reynold Xin] Made UDF case insensitive. 852f9c0 [Reynold Xin] Fixed style violation. e76a3c1 [Reynold Xin] Fixed parser. 52ddaba [Reynold Xin] Fixed compilation. ee7854f [Reynold Xin] Improved error reporting. ff906f2 [Reynold Xin] More robust constructor calling. 77b46f1 [Reynold Xin] Simplified the code. 2a2a149 [Reynold Xin] Merge pull request #6463 from smola/SPARK-7886 8616924 [Santiago M. Mola] [SPARK-7886] Add built-in expressions to FunctionRegistry. --- python/pyspark/sql/dataframe.py | 2 +- .../apache/spark/sql/catalyst/SqlParser.scala | 75 +++++------ .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/FunctionRegistry.scala | 127 +++++++++++++----- .../sql/catalyst/expressions/Expression.scala | 9 ++ .../sql/catalyst/expressions/random.scala | 23 +++- .../expressions/stringOperations.scala | 7 + .../sql/catalyst/util/StringKeyHashMap.scala | 44 ++++++ .../org/apache/spark/sql/SQLContext.scala | 6 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 42 ++++++ .../apache/spark/sql/hive/HiveContext.scala | 9 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 14 +- 12 files changed, 269 insertions(+), 93 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e9dd05e2d0..9615e57649 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -746,7 +746,7 @@ class DataFrame(object): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e85312aee7..f74c17d583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -48,26 +49,21 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") - protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") - protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -80,13 +76,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") - protected val LAST = Keyword("LAST") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") - protected val LOWER = Keyword("LOWER") - protected val MAX = Keyword("MAX") - protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -100,15 +92,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") - protected val SQRT = Keyword("SQRT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") - protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") @@ -277,25 +264,36 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val function: Parser[Expression] = - ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } - | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } - | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } - | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } - | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } - | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } - | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } - | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } - | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } - | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } - | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } - | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } - | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case c ~ t ~ f => If(c, t, f) } + ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => + if (lexical.normalizeKeyword(udfName) == "count") { + Count(Literal(1)) + } else { + throw new AnalysisException(s"invalid expression $udfName(*)") + } + } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => + lexical.normalizeKeyword(udfName) match { + case "sum" => SumDistinct(exprs.head) + case "count" => CountDistinct(exprs) + } + } + | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp) + } else { + throw new AnalysisException(s"invalid function approximate $udfName") + } + } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp, s.toDouble) + } else { + throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") + } + } | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { case casePart ~ altPart ~ elsePart => @@ -304,16 +302,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } ++ elsePart casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } - | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } - | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } - ) + ) protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5883d938b6..02b10c444d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -461,7 +461,9 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + withPosition(u) { + registry.lookupFunction(name, children) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0849faa9bf..406f6fad84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,24 +17,27 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.Expression -import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.StringKeyHashMap + /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - type FunctionBuilder = Seq[Expression] => Expression def registerFunction(name: String, builder: FunctionBuilder): Unit + @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression - - def conf: CatalystConf } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) @@ -45,16 +48,19 @@ trait OverrideFunctionRegistry extends FunctionRegistry { } } -class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry { +class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders(name)(children) + val func = functionBuilders.get(name).getOrElse { + throw new AnalysisException(s"undefined function $name") + } + func(children) } } @@ -70,30 +76,89 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - - override def conf: CatalystConf = throw new UnsupportedOperationException } -/** - * Build a map with String type of key, and it also supports either key case - * sensitive or insensitive. - * TODO move this into util folder? - */ -object StringKeyHashMap { - def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { - case false => new StringKeyHashMap[T](_.toLowerCase) - case true => new StringKeyHashMap[T](identity) - } -} -class StringKeyHashMap[T](normalizer: (String) => String) { - private val base = new collection.mutable.HashMap[String, T]() +object FunctionRegistry { - def apply(key: String): T = base(normalizer(key)) + type FunctionBuilder = Seq[Expression] => Expression - def get(key: String): Option[T] = base.get(normalizer(key)) - def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) - def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator + val expressions: Map[String, FunctionBuilder] = Map( + // Non aggregate functions + expression[Abs]("abs"), + expression[CreateArray]("array"), + expression[Coalesce]("coalesce"), + expression[Explode]("explode"), + expression[Lower]("lower"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Rand]("rand"), + expression[Randn]("randn"), + expression[CreateStruct]("struct"), + expression[Sqrt]("sqrt"), + expression[Upper]("upper"), + + // Math functions + expression[Acos]("acos"), + expression[Asin]("asin"), + expression[Atan]("atan"), + expression[Atan2]("atan2"), + expression[Cbrt]("cbrt"), + expression[Ceil]("ceil"), + expression[Cos]("cos"), + expression[Exp]("exp"), + expression[Expm1]("expm1"), + expression[Floor]("floor"), + expression[Hypot]("hypot"), + expression[Log]("log"), + expression[Log10]("log10"), + expression[Log1p]("log1p"), + expression[Pow]("pow"), + expression[Rint]("rint"), + expression[Signum]("signum"), + expression[Sin]("sin"), + expression[Sinh]("sinh"), + expression[Tan]("tan"), + expression[Tanh]("tanh"), + expression[ToDegrees]("todegrees"), + expression[ToRadians]("toradians"), + + // aggregate functions + expression[Average]("avg"), + expression[Count]("count"), + expression[First]("first"), + expression[Last]("last"), + expression[Max]("max"), + expression[Min]("min"), + expression[Sum]("sum") + ) + + /** See usage above. */ + private def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + // Use the companion class to find apply methods. + val objectClass = Class.forName(tag.runtimeClass.getName + "$") + val companionObj = objectClass.getDeclaredField("MODULE$").get(null) + + // See if we can find an apply that accepts Seq[Expression] + val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption + + val builder = (expressions: Seq[Expression]) => { + if (varargApply.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression] + } else { + // Otherwise, find an apply method that matches the number of arguments, and use that. + val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match { + case Success(e) => + e + case Failure(e) => + throw new AnalysisException(s"Invalid number of arguments for function $name") + } + f.invoke(companionObj, expressions : _*).asInstanceOf[Expression] + } + } + (name, builder) + } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a9a9c0cfb7..f2ed1f0929 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ + +/** + * For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes + * whose constructor arguments are all Expressions types. In addition, if we want to support more + * than one constructor, define those constructors explicitly as apply methods in the companion + * object. + * + * See [[Substring]] for an example. + */ abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index b2647124c4..6e4e9cb1be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -46,11 +47,29 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ -case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Rand(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextDouble() } +object Rand { + def apply(): Rand = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Rand = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} + /** Generate a random column with i.i.d. gaussian random distribution. */ -case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Randn(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextGaussian() } + +object Randn { + def apply(): Randn = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Randn = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index aae122a981..856f56488c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -227,6 +227,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") @@ -287,3 +288,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) case _ => s"SUBSTR($str, $pos, $len)" } } + +object Substring { + def apply(str: Expression, pos: Expression): Substring = { + apply(str, pos, Literal(Integer.MAX_VALUE)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala new file mode 100644 index 0000000000..191d5e6399 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + + def remove(key: String): Option[T] = base.remove(normalizer(key)) + + def iterator: Iterator[(String, T)] = base.toIterator +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ddb54025ba..8cad3885b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -120,7 +120,11 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf) + protected[sql] lazy val functionRegistry: FunctionRegistry = { + val fr = new SimpleFunctionRegistry + FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) } + fr + } @transient protected[sql] lazy val analyzer: Analyzer = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 064c040d2b..703a34c47e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -25,6 +25,48 @@ class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } + test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b8f294c262..3b8cafb4a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,13 +39,12 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -374,10 +373,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - override def conf: CatalystConf = currentSession().conf - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry /* An analyzer that uses the Hive metastore. */ @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 01f47352b2..6e6ac987b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.spark.sql.AnalysisException - import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions @@ -30,8 +27,11 @@ import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ @@ -40,20 +40,18 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) + throw new AnalysisException(s"undefined function $name")) val functionClassName = functionInfo.getFunctionClass.getName -- GitLab