diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index bb8f2a3aa5f71c5cbd43e051091eb6db9e3bdff9..46b512f8aea7e8e6aafb9aaab11f306e5fa431ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -114,7 +114,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String val bucketizer: UserDefinedFunction = udf { (feature: Double) => Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - } + }.withName("bucketizer") val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 5fd7123af3a0394f747c5a948a2f71aefdee0d25..1bceac41b9de7caaa87739d560546c73019bd3af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") } @@ -147,7 +147,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -160,7 +160,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -173,7 +173,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -199,7 +199,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -212,7 +212,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -225,7 +225,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -238,7 +238,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -251,7 +251,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -264,7 +264,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -277,7 +277,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -290,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -303,7 +303,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -316,7 +316,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -329,7 +329,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -355,7 +355,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -368,7 +368,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -381,7 +381,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -394,7 +394,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -407,7 +407,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -420,7 +420,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -433,7 +433,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 5a0f488149ea433060c5e39e964b1dd76c4d5995..0c5f1b436591d3c6957cd4f33c5949c66a264e4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -47,6 +47,7 @@ case class UserDefinedFunction protected[sql] ( dataType: DataType, inputTypes: Option[Seq[DataType]]) { + private var _nameOption: Option[String] = None private var _nullable: Boolean = true /** @@ -67,15 +68,27 @@ case class UserDefinedFunction protected[sql] ( dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil), + udfName = _nameOption, nullable = _nullable)) } private def copyAll(): UserDefinedFunction = { val udf = copy() + udf._nameOption = _nameOption udf._nullable = _nullable udf } + /** + * Updates UserDefinedFunction with a given name. + * + * @since 2.3.0 + */ + def withName(name: String): this.type = { + this._nameOption = Option(name) + this + } + /** * Updates UserDefinedFunction with a given nullability. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6f8723af91cea6840a39782de5f114379b6eb430..b4f744b193ada99869d768c4d8ca6ee629c8a5ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -263,10 +263,12 @@ class UDFSuite extends QueryTest with SharedSQLContext { val sparkPlan = spark.sessionState.executePlan(explain).executedPlan sparkPlan.executeCollect().map(_.getString(0).trim).headOption.getOrElse("") } - val udf1 = "myUdf1" - val udf2 = "myUdf2" - spark.udf.register(udf1, (n: Int) => { n + 1 }) - spark.udf.register(udf2, (n: Int) => { n * 1 }) - assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1(UDF:$udf2(1))")) + val udf1Name = "myUdf1" + val udf2Name = "myUdf2" + val udf1 = spark.udf.register(udf1Name, (n: Int) => n + 1) + val udf2 = spark.udf.register(udf2Name, (n: Int) => n * 1) + assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) + assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) + .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } }