diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 111e751588a8bf1dc1b04b12717c439bf7c2727b..ff91e1d74bc2ca61be6de6f8bebb4552b5ad7e58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -605,4 +605,23 @@ object functions { } // scalastyle:on + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + */ + def callUdf(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f5df8c6a59f107fe0e16e1335acae05a73e618ab..b26e22f6229fe325e1c4e27f7196766aa870c511 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -440,6 +440,15 @@ class DataFrameSuite extends QueryTest { ) } + test("call udf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer(