Skip to content
Snippets Groups Projects
Commit b4c99f43 authored by Wenchen Fan's avatar Wenchen Fan Committed by Xiao Li
Browse files

[SPARK-20569][SQL] RuntimeReplaceable functions should not take extra parameters

## What changes were proposed in this pull request?

`RuntimeReplaceable` always has a constructor with the expression to replace with, and this constructor should not be the function builder.

## How was this patch tested?

new regression test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17876 from cloud-fan/minor.
parent 65accb81
No related branches found
No related tags found
No related merge requests found
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis package org.apache.spark.sql.catalyst.analysis
import java.lang.reflect.Modifier
import scala.language.existentials import scala.language.existentials
import scala.reflect.ClassTag import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try} import scala.util.{Failure, Success, Try}
...@@ -455,8 +457,17 @@ object FunctionRegistry { ...@@ -455,8 +457,17 @@ object FunctionRegistry {
private def expression[T <: Expression](name: String) private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
// For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
// constructor and contains non-parameter `child` and should not be used as function builder.
val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
val all = tag.runtimeClass.getConstructors
val maxNumArgs = all.map(_.getParameterCount).max
all.filterNot(_.getParameterCount == maxNumArgs)
} else {
tag.runtimeClass.getConstructors
}
// See if we can find a constructor that accepts Seq[Expression] // See if we can find a constructor that accepts Seq[Expression]
val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
val builder = (expressions: Seq[Expression]) => { val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) { if (varargCtor.isDefined) {
// If there is an apply method that accepts Seq[Expression], use that one. // If there is an apply method that accepts Seq[Expression], use that one.
...@@ -470,11 +481,8 @@ object FunctionRegistry { ...@@ -470,11 +481,8 @@ object FunctionRegistry {
} else { } else {
// Otherwise, find a constructor method that matches the number of arguments, and use that. // Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression]) val params = Seq.fill(expressions.size)(classOf[Expression])
val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
case Success(e) => throw new AnalysisException(s"Invalid number of arguments for function $name")
e
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
} }
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e case Success(e) => e
......
...@@ -2619,4 +2619,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ...@@ -2619,4 +2619,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
new URL(jarFromInvalidFs) new URL(jarFromInvalidFs)
} }
} }
test("RuntimeReplaceable functions should not take extra parameters") {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment