diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590e6d5a1f93f07b2392598e4adf44647..1f3325ad09ef1a98cdbb94fb3ec67c34d7c57069 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -115,8 +115,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c5aed6da9c442a1d8ab35d66f2c1d9e9814166..3572f3c3a1f2c1eb539afcce5a91866826a9fbaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2843,6 +2843,20 @@ object functions extends LegacyFunctions { // scalastyle:on parameter.number // scalastyle:on line.size.limit + /** + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specifcy the output data type, and there is no automatic input type coercion. + * + * @param f A closure in Scala + * @param dataType The output data type of the UDF + * + * @group udf_funcs + * @since 2.0.0 + */ + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) + } + /** * Call an user-defined function. * Example: