From f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Thu, 31 Mar 2016 16:40:20 -0700
Subject: [PATCH] [SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs
 within single batch

## What changes were proposed in this pull request?

This PR support multiple Python UDFs within single batch, also improve the performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```

## How was this patch tested?

Added new tests.

Using the following script to benchmark 1, 2 and 3 udfs,
```
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```
Here is the results:

N | Before | After  | speed up
---- |------------ | -------------|------
1 | 22 s | 7 s |  3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python before before this patch, 4 process after this patch. After this patch, it will use less memory for multiple UDFs than before (less buffering).

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
---
 .../apache/spark/api/python/PythonRDD.scala   | 64 +++++++++++----
 python/pyspark/sql/functions.py               |  3 +-
 python/pyspark/sql/tests.py                   | 12 ++-
 python/pyspark/worker.py                      | 68 ++++++++++++----
 .../spark/sql/execution/SparkStrategies.scala |  4 +-
 .../python/BatchPythonEvaluation.scala        | 78 ++++++++++++++-----
 .../sql/execution/python/EvaluatePython.scala | 28 +++++--
 .../execution/python/ExtractPythonUDFs.scala  | 77 +++++++++---------
 8 files changed, 233 insertions(+), 101 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0f579b4ef5..6faa03c12b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
 
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3211834226..3b20ba5177 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
 # ---------------------------- User Defined Function ----------------------------------
 
 def _wrap_function(sc, func, returnType):
-    ser = AutoBatchedSerializer(PickleSerializer())
-    command = (func, returnType, ser)
+    command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84947560e7..536ef55251 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
-    def test_chained_python_udf(self):
+    def test_chained_udf(self):
         self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.sqlCtx.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ class SQLTests(ReusedPySparkTestCase):
         [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_multiple_udfs(self):
+        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.assertEqual(tuple(row), (2, 4))
+        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        self.assertEqual(tuple(row), (4, 12))
+        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.assertEqual(tuple(row), (6, 5))
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0f05fe31aa..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):
 
 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
+
+    func = lambda _, it: map(mapper, it)
+    ser = BatchedSerializer(PickleSerializer(), 100)
+    # profiling is not supported for UDF
+    return func, None, ser, ser
 
 
 def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        row_based = read_int(infile)
-        num_commands = read_int(infile)
-        if row_based:
-            profiler = None  # profiling is not supported for UDF
-            row_func = None
-            for i in range(num_commands):
-                f, returnType, deserializer = read_command(pickleSer, infile)
-                if row_func is None:
-                    row_func = f
-                else:
-                    row_func = chain(row_func, f)
-            serializer = deserializer
-            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
         else:
-            assert num_commands == 1
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
 
         init_time = time.time()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f9..7a2e2b7382 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udf, child, _) =>
-        python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+      case e @ python.EvaluatePython(udfs, child, _) =>
+        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a76009e7df..c9ab40a0a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 
 /**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * we drain the queue to find the original input row. Note that if the Python process is way too
  * slow, this could lead to the queue growing unbounded and eventually run out of memory.
  */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {
 
   def children: Seq[SparkPlan] = child :: Nil
 
-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
 
@@ -69,19 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
-      val (pyFuncs, children) = collectFunctions(udf)
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(children, child.output)()
-      val fields = children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
 
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
+          }
         }.toArray
         pickle.dumps(toBePickled)
       }
@@ -89,19 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       val context = TaskContext.get()
 
       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
         .compute(inputIterator, context.partitionId(), context)
 
       val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
+      val mutableRow = new GenericMutableRow(1)
       val joined = new JoinedRow
+      val resultType = if (udfs.length == 1) {
+        udfs.head.dataType
+      } else {
+        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+      }
       val resultProj = UnsafeProjection.create(output, output)
 
       outputIterator.flatMap { pickedResult =>
         val unpickledBatch = unpickle.loads(pickedResult)
         unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
       }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        val row = if (udfs.length == 1) {
+          // fast path for single UDF
+          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+          mutableRow
+        } else {
+          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        }
         resultProj(joined(queue.poll(), row))
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index da28ec4f53..f3d1c44b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
  */
 case class EvaluatePython(
-    udf: PythonUDF,
+    udfs: Seq[PythonUDF],
     child: LogicalPlan,
-    resultAttribute: AttributeReference)
+    resultAttribute: Seq[AttributeReference])
   extends logical.UnaryNode {
 
-  def output: Seq[Attribute] = child.output :+ resultAttribute
+  def output: Seq[Attribute] = child.output ++ resultAttribute
 
   // References should not include the produced attribute.
-  override def references: AttributeSet = udf.references
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
 }
 
 
 object EvaluatePython {
-  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
-    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
+    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+      AttributeReference(s"pythonUDF$i", u.dataType)()
+    }
+    new EvaluatePython(udfs, child, resultAttrs)
+  }
 
   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
@@ -66,6 +70,16 @@ object EvaluatePython {
     }
   }
 
+  def needConversionInPython(dt: DataType): Boolean = dt match {
+    case DateType | TimestampType => true
+    case _: StructType => true
+    case _: UserDefinedType[_] => true
+    case ArrayType(elementType, _) => needConversionInPython(elementType)
+    case MapType(keyType, valueType, _) =>
+      needConversionInPython(keyType) || needConversionInPython(valueType)
+    case _ => false
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index c486ce18e8..0934cd135d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -47,10 +49,9 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
-    expr.collect {
-      case udf: PythonUDF if canEvaluateInPython(udf) => udf
-    }
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
+    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+    case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -59,45 +60,43 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
 
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
       } else {
-        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
-        // If there is more than one, we will add another evaluation operator in a subsequent pass.
-        udfs.find(_.resolved) match {
-          case Some(udf) =>
-            var evaluation: EvaluatePython = null
-
-            // Rewrite the child that has the input required for the UDF
-            val newChildren = plan.children.map { child =>
-              // Check to make sure that the UDF can be evaluated with only the input of this child.
-              // Other cases are disallowed as they are ambiguous or would require a cartesian
-              // product.
-              if (udf.references.subsetOf(child.outputSet)) {
-                evaluation = EvaluatePython(udf, child)
-                evaluation
-              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
-                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-              } else {
-                child
-              }
-            }
-
-            assert(evaluation != null, "Unable to evaluate PythonUDF.  Missing input attributes.")
-
-            // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
-
-          case None =>
-            // If there is no Python UDF that is resolved, skip this round.
-            plan
+        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+        // Rewrite the child that has the input required for the UDF
+        val newChildren = plan.children.map { child =>
+          // Pick the UDF we are going to evaluate
+          val validUdfs = udfs.filter { case udf =>
+            // Check to make sure that the UDF can be evaluated with only the input of this child.
+            udf.references.subsetOf(child.outputSet)
+          }
+          if (validUdfs.nonEmpty) {
+            val evaluation = EvaluatePython(validUdfs, child)
+            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
+            evaluation
+          } else {
+            child
+          }
         }
+        // Other cases are disallowed as they are ambiguous or would require a cartesian
+        // product.
+        udfs.filterNot(attributeMap.contains).foreach { udf =>
+          if (udf.references.subsetOf(plan.inputSet)) {
+            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+          } else {
+            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+          }
+        }
+
+        // Trim away the new UDF value if it was only used for filtering or something.
+        logical.Project(
+          plan.output,
+          plan.transformExpressions {
+            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
+          }.withNewChildren(newChildren))
       }
   }
 }
-- 
GitLab