diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd852547492aaeb46e0b6c15601c018b357292e3..a2595ff6c22f4250d2fa39c3a296f7e0b9e2061f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -69,7 +69,9 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
             // local function inside a method
             ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
+              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
           ) ++ Seq(
             // SPARK-8479 Add numNonzeros and numActives to Matrix.
             ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c93a15badae296c7357c256435a4d14698bc764c..abb6522dde7b0fd7a386d012008b9af29e226e70 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.utils import install_exception_handler
+from pyspark.sql.functions import UserDefinedFunction
 
 try:
     import pandas
@@ -191,19 +192,8 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(_c0=4)]
         """
-        func = lambda _, it: map(lambda x: f(*x), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
-        self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_cmd),
-                                            env,
-                                            includes,
-                                            self._sc.pythonExec,
-                                            self._sc.pythonVer,
-                                            bvars,
-                                            self._sc._javaAccumulator,
-                                            returnType.json())
+        udf = UserDefinedFunction(f, returnType, name)
+        self._ssql_ctx.udf().registerPython(name, udf._judf)
 
     def _inferSchemaFromList(self, data):
         """
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fd5a3ba8adab38b79a1bb241f72e1b853171987b..031745a1c4d3b1e0ee4a84cc04d82de3a45f7a93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType):
+    def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
         self._broadcast = None
-        self._judf = self._create_judf()
+        self._judf = self._create_judf(name)
 
-    def _create_judf(self):
-        f = self.func  # put it in closure `func`
-        func = lambda _, it: map(lambda x: f(*x), it)
+    def _create_judf(self, name):
+        f, returnType = self.func, self.returnType  # put them in closure `func`
+        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+        if name is None:
+            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
                                                  sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7a55d801e48e6bf109b63eb7e2e4080710ba751a..ea821f486f13a6cb3b4654fe3e2f60560b052cc8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_udf_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.sc.parallelize([row]).toDF()
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index d35d37d017198dd7163b97e2141286b2fe09e6ba..7cd7421a518c9a5294bd6e04dc70de01bc3fd81f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap}
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
 
   private val functionRegistry = sqlContext.functionRegistry
 
-  protected[sql] def registerPython(
-      name: String,
-      command: Array[Byte],
-      envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      pythonExec: String,
-      pythonVer: String,
-      broadcastVars: JList[Broadcast[PythonBroadcast]],
-      accumulator: Accumulator[JList[Array[Byte]]],
-      stringDataType: String): Unit = {
+  protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
     log.debug(
       s"""
         | Registering new PythonUDF:
        | name: $name
-        | command: ${command.toSeq}
-        | envVars: $envVars
-        | pythonIncludes: $pythonIncludes
-        | pythonExec: $pythonExec
-        | dataType: $stringDataType
+        | command: ${udf.command.toSeq}
+        | envVars: ${udf.envVars}
+        | pythonIncludes: ${udf.pythonIncludes}
+        | pythonExec: ${udf.pythonExec}
+        | dataType: ${udf.dataType}
       """.stripMargin)
-
-    val dataType = sqlContext.parseDataType(stringDataType)
-
-    def builder(e: Seq[Expression]): PythonUDF =
-      PythonUDF(
-        name,
-        command,
-        envVars,
-        pythonIncludes,
-        pythonExec,
-        pythonVer,
-        broadcastVars,
-        accumulator,
-        dataType,
-        e)
-
-    functionRegistry.registerFunction(name, builder)
+    functionRegistry.registerFunction(name, udf.builder)
   }
 
   // scalastyle:off
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index b14e00ab9b163022237f354f9f4db50489940ad7..0f8cd280b5acb7f56dcc251649369386dbe01bdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
@@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {
 
+  def builder(e: Seq[Expression]): PythonUDF = {
+    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+      accumulator, dataType, e)
+  }
+
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
-      broadcastVars, accumulator, dataType, exprs.map(_.expr))
+    val udf = builder(exprs.map(_.expr))
     Column(udf)
   }
 }
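
For context, a minimal PySpark sketch of how the refactored registration path is exercised (assuming a local 1.5-era build; the names strlen, words, and double_len are illustrative and not part of the patch). registerFunction now builds a UserDefinedFunction and hands its _judf to registerPython, so registered and directly-applied UDFs share the same pickling and toInternal conversion logic.

# sketch only: assumes a local SparkContext and the PySpark API of this branch
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import IntegerType

sc = SparkContext("local[2]", "udf-demo")
sqlContext = SQLContext(sc)

# registerFunction delegates to UserDefinedFunction(f, returnType, name)
sqlContext.registerFunction("strlen", lambda s: len(s), IntegerType())
sqlContext.createDataFrame([Row(word="spark")]).registerTempTable("words")
print(sqlContext.sql("SELECT strlen(word) FROM words").collect())  # [Row(_c0=5)]

# an unregistered UDF can still be built and applied directly as a Column
double_len = UserDefinedFunction(lambda s: len(s) * 2, IntegerType())
df = sqlContext.table("words")
print(df.select(double_len(df.word)).collect())

sc.stop()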