Skip to content
Snippets Groups Projects
Commit d3e07165 authored by gatorsmile's avatar gatorsmile
Browse files

[SPARK-19285][SQL] Implement UDF0

### What changes were proposed in this pull request?
This PR is to implement UDF0. `UDF0` is needed when users need to implement a Java UDF with no argument.

### How was this patch tested?
Added a test case

Author: gatorsmile <gatorsmile@gmail.com>

Closes #18598 from gatorsmile/udf0.
parent 1cad31f0
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.api.java;
import java.io.Serializable;
import org.apache.spark.annotation.InterfaceStability;
/**
 * A Spark SQL UDF that has 0 arguments.
 */
@InterfaceStability.Stable
public interface UDF0<R> extends Serializable {
/**
 * Evaluates the UDF.
 *
 * @return the result of the UDF; Spark maps it to the {@code DataType}
 *         supplied when the function is registered
 * @throws Exception if the UDF computation fails; Spark surfaces it as a
 *                   query execution failure
 */
R call() throws Exception;
}
...@@ -122,25 +122,27 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ...@@ -122,25 +122,27 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}""") }""")
} }
(1 to 22).foreach { i => (0 to 22).foreach { i =>
val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
val version = if (i == 0) "2.3.0" else "1.3.0"
val funcCall = if (i == 0) "() => func" else "func"
println(s""" println(s"""
|/** |/**
| * Register a user-defined function with ${i} arguments. | * Register a user-defined function with ${i} arguments.
| * @since 1.3.0 | * @since $version
| */ | */
|def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
| val func = f$anyCast.call($anyParams) | val func = f$anyCast.call($anyParams)
|def builder(e: Seq[Expression]) = if (e.length == $i) { | def builder(e: Seq[Expression]) = if (e.length == $i) {
| ScalaUDF(func, returnType, e) | ScalaUDF($funcCall, returnType, e)
|} else { | } else {
| throw new AnalysisException("Invalid number of arguments for function " + name + | throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $i; Found: " + e.length) | ". Expected: $i; Found: " + e.length)
|} | }
|functionRegistry.createOrReplaceTempFunction(name, builder) | functionRegistry.createOrReplaceTempFunction(name, builder)
|}""".stripMargin) |}""".stripMargin)
} }
*/ */
...@@ -592,6 +594,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ...@@ -592,6 +594,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
} }
udfInterfaces(0).getActualTypeArguments.length match { udfInterfaces(0).getActualTypeArguments.length match {
case 1 => register(name, udf.asInstanceOf[UDF0[_]], returnType)
case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
...@@ -649,6 +652,21 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ...@@ -649,6 +652,21 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
} }
} }
/**
 * Register a user-defined function with 0 arguments.
 * @since 2.3.0
 */
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
  // Defer evaluation: wrap the call in a closure so the UDF runs per-row at
  // query execution time. Calling `f.call()` here would evaluate the UDF
  // exactly once at registration and freeze its result for every row,
  // breaking non-deterministic zero-argument UDFs (e.g. random/clock UDFs).
  val func = () => f.asInstanceOf[UDF0[Any]].call()
  // Builder validates arity at analysis time before constructing the expression.
  def builder(e: Seq[Expression]) = if (e.length == 0) {
    ScalaUDF(func, returnType, e)
  } else {
    throw new AnalysisException("Invalid number of arguments for function " + name +
      ". Expected: 0; Found: " + e.length)
  }
  functionRegistry.createOrReplaceTempFunction(name, builder)
}
/** /**
* Register a user-defined function with 1 arguments. * Register a user-defined function with 1 arguments.
* @since 1.3.0 * @since 1.3.0
......
...@@ -113,4 +113,12 @@ public class JavaUDFSuite implements Serializable { ...@@ -113,4 +113,12 @@ public class JavaUDFSuite implements Serializable {
spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType); spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList(); List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList();
} }
@SuppressWarnings("unchecked")
@Test
public void udf6Test() {
  // Register a zero-argument UDF and verify it evaluates to the expected constant.
  spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType);
  int actual = spark.sql("SELECT returnOne()").head().getInt(0);
  Assert.assertEquals(1, actual);
}
} }
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment