diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java
new file mode 100644
index 0000000000000000000000000000000000000000..4eeb7be3f5abb9ee5d023a00b61b47f0c791566b
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.io.Serializable;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A Spark SQL UDF that has 0 arguments.
+ */
+@InterfaceStability.Stable
+public interface UDF0<R> extends Serializable {
+  R call() throws Exception;
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index c4d0adb5236f2f07725ae608b693ae54091fe17a..c66d4057b9135c4a2c059514c36bcd322fdc921d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -122,25 +122,27 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
         }""")
     }
 
-    (1 to 22).foreach { i =>
-      val extTypeArgs = (1 to i).map(_ => "_").mkString(", ")
-      val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ")
-      val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]"
+    (0 to 22).foreach { i =>
+      val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
+      val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
+      val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
       val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
+      val version = if (i == 0) "2.3.0" else "1.3.0"
+      val funcCall = if (i == 0) "() => func" else "func"
       println(s"""
         |/**
         | * Register a user-defined function with ${i} arguments.
-        | * @since 1.3.0
+        | * @since $version
         | */
-        |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = {
+        |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
         |  val func = f$anyCast.call($anyParams)
-        |def builder(e: Seq[Expression]) = if (e.length == $i) {
-        |  ScalaUDF(func, returnType, e)
-        |} else {
-        |  throw new AnalysisException("Invalid number of arguments for function " + name +
-        |    ". Expected: $i; Found: " + e.length)
-        |}
-        |functionRegistry.createOrReplaceTempFunction(name, builder)
+        |  def builder(e: Seq[Expression]) = if (e.length == $i) {
+        |    ScalaUDF($funcCall, returnType, e)
+        |  } else {
+        |    throw new AnalysisException("Invalid number of arguments for function " + name +
+        |      ". Expected: $i; Found: " + e.length)
+        |  }
+        |  functionRegistry.createOrReplaceTempFunction(name, builder)
         |}""".stripMargin)
     }
     */
@@ -592,6 +594,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
     }
 
     udfInterfaces(0).getActualTypeArguments.length match {
+      case 1 => register(name, udf.asInstanceOf[UDF0[_]], returnType)
       case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
       case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
       case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
@@ -649,6 +652,21 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
     }
   }
 
+  /**
+   * Register a user-defined function with 0 arguments.
+   * @since 2.3.0
+   */
+  def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
+    val func = f.asInstanceOf[UDF0[Any]].call()
+    def builder(e: Seq[Expression]) = if (e.length == 0) {
+      ScalaUDF(() => func, returnType, e)
+    } else {
+      throw new AnalysisException("Invalid number of arguments for function " + name +
+        ". Expected: 0; Found: " + e.length)
+    }
+    functionRegistry.createOrReplaceTempFunction(name, builder)
+  }
+
   /**
    * Register a user-defined function with 1 arguments.
    * @since 1.3.0
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 4fb2988f24d2681cfefb31adcddc5644150f1712..5bf18888261866ff6f834dfc8dc142255dfd364f 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -113,4 +113,12 @@ public class JavaUDFSuite implements Serializable {
     spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
     List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList();
   }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf6Test() {
+    spark.udf().register("returnOne", () -> 1, DataTypes.IntegerType);
+    Row result = spark.sql("SELECT returnOne()").head();
+    Assert.assertEquals(1, result.getInt(0));
+  }
 }
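
A minimal usage sketch of the zero-argument path added above, written against the public Java API (the class name Udf0Example, the function name "fortyTwo", and the local SparkSession setup are illustrative only and not part of the patch):

import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF0;
import org.apache.spark.sql.types.DataTypes;

public class Udf0Example {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[*]")
      .appName("udf0-example")
      .getOrCreate();

    // The explicit UDF0 cast resolves the lambda against the new
    // register(String, UDF0, DataType) overload introduced by this patch.
    spark.udf().register("fortyTwo", (UDF0<Integer>) () -> 42, DataTypes.IntegerType);

    // A zero-argument UDF is invoked from SQL with empty parentheses.
    Row row = spark.sql("SELECT fortyTwo()").head();
    System.out.println(row.getInt(0));  // prints 42

    spark.stop();
  }
}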