From 354d4c24be892271bd9a9eab6ceedfbc5d671c9c Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Sat, 13 Feb 2016 21:06:31 -0800
Subject: [PATCH] [SPARK-13296][SQL] Move UserDefinedFunction into
 sql.expressions.

This pull request has the following changes:

1. Moved UserDefinedFunction into expressions package. This is more consistent with how we structure the packages for window functions and UDAFs.

2. Moved UserDefinedPythonFunction into execution.python package, so we don't have a random private class in the top level sql package.

3. Move everything in execution/python.scala into the newly created execution.python package.

Most of the diffs are just straight copy-paste.

Author: Reynold Xin <rxin@databricks.com>

Closes #11181 from rxin/SPARK-13296.
---
 project/MimaExcludes.scala                    |   8 +-
 python/pyspark/sql/dataframe.py               |   2 +-
 python/pyspark/sql/functions.py               |   6 +-
 .../org/apache/spark/sql/DataFrame.scala      |   3 +-
 .../org/apache/spark/sql/SQLContext.scala     |   4 +-
 .../apache/spark/sql/UDFRegistration.scala    |   3 +-
 .../spark/sql/execution/SparkStrategies.scala |   4 +-
 .../python/BatchPythonEvaluation.scala        | 104 ++++++++++
 .../EvaluatePython.scala}                     | 187 ++----------------
 .../execution/python/ExtractPythonUDFs.scala  |  79 ++++++++
 .../sql/execution/python/PythonUDF.scala      |  44 +++++
 .../python/UserDefinedPythonFunction.scala    |  51 +++++
 .../UserDefinedFunction.scala                 |  39 +---
 .../org/apache/spark/sql/functions.scala      |   1 +
 .../apache/spark/sql/hive/HiveContext.scala   |   2 +-
 15 files changed, 320 insertions(+), 217 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/{python.scala => python/EvaluatePython.scala} (56%)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 rename sql/core/src/main/scala/org/apache/spark/sql/{ => expressions}/UserDefinedFunction.scala (57%)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8611106db0..6abab7f126 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -235,7 +235,13 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint")
       ) ++ Seq(
         // SPARK-7889
-        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI")
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"),
+        // SPARK-13296
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$")
       )
     case v if v.startsWith("1.6") =>
       Seq(
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3104e41407..83b034fe77 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -262,7 +262,7 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+            port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
                 self._jdf, num)
         return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 416d722bba..5fc1cc2cae 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1652,9 +1652,9 @@ class UserDefinedFunction(object):
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
-                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
-                                                 sc._javaAccumulator, jdt)
+        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
+            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
+            broadcast_vars, sc._javaAccumulator, jdt)
         return judf
 
     def __del__(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index c5b2b7d118..76c09a285d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
+import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d58b99655c..c7d1096a13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -193,7 +193,7 @@ class SQLContext private[sql](
   protected[sql] lazy val analyzer: Analyzer =
     new Analyzer(catalog, functionRegistry, conf) {
       override val extendedResolutionRules =
-        ExtractPythonUDFs ::
+        python.ExtractPythonUDFs ::
         PreInsertCastAndRename ::
         (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)
 
@@ -915,7 +915,7 @@ class SQLContext private[sql](
       rdd: RDD[Array[Any]],
       schema: StructType): DataFrame = {
 
-    val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+    val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
     DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index f87a88d497..ecfc170bee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
 import org.apache.spark.sql.types.DataType
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 598ddd7161..73fd22b38e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -369,8 +369,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         execution.Exchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ EvaluatePython(udf, child, _) =>
-        BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+      case e @ python.EvaluatePython(udf, child, _) =>
+        python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
new file mode 100644
index 0000000000..00df019527
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -0,0 +1,104 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle.{Pickler, Unpickler}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * A physical plan that evalutes a [[PythonUDF]], one partition of tuples at a time.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combine the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and eventually run out of memory.
+ */
+case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+  extends SparkPlan {
+
+  def children: Seq[SparkPlan] = child :: Nil
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+    inputRDD.mapPartitions { iter =>
+      EvaluatePython.registerPicklers()  // register pickler for Row
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+
+      val pickle = new Pickler
+      val currentRow = newMutableProjection(udf.children, child.output)()
+      val fields = udf.children.map(_.dataType)
+      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+
+      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+      // For each row, add it to the queue.
+      val inputIterator = iter.grouped(100).map { inputRows =>
+        val toBePickled = inputRows.map { row =>
+          queue.add(row)
+          EvaluatePython.toJava(currentRow(row), schema)
+        }.toArray
+        pickle.dumps(toBePickled)
+      }
+
+      val context = TaskContext.get()
+
+      // Output iterator for results from Python.
+      val outputIterator = new PythonRunner(
+        udf.command,
+        udf.envVars,
+        udf.pythonIncludes,
+        udf.pythonExec,
+        udf.pythonVer,
+        udf.broadcastVars,
+        udf.accumulator,
+        bufferSize,
+        reuseWorker
+      ).compute(inputIterator, context.partitionId(), context)
+
+      val unpickle = new Unpickler
+      val row = new GenericMutableRow(1)
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(output, output)
+
+      outputIterator.flatMap { pickedResult =>
+        val unpickledBatch = unpickle.loads(pickedResult)
+        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+      }.map { result =>
+        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        resultProj(joined(queue.poll(), row))
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
similarity index 56%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index bf62bb05c3..8c46516594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -15,106 +15,41 @@
 * limitations under the License.
 */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.python
 
 import java.io.OutputStream
-import java.util.{List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 
-import net.razorvine.pickle._
+import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
-import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext}
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil}
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * A serialized version of a Python lambda function.  Suitable for use in a [[PythonRDD]].
+ * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
  */
-private[spark] case class PythonUDF(
-    name: String,
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
-    dataType: DataType,
-    children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging {
-
-  override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
-
-  override def nullable: Boolean = true
-}
+case class EvaluatePython(
+    udf: PythonUDF,
+    child: LogicalPlan,
+    resultAttribute: AttributeReference)
+  extends logical.UnaryNode {
 
-/**
- * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
- * alone in a batch.
- *
- * This has the limitation that the input to the Python UDF is not allowed include attributes from
- * multiple child operators.
- */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    // Skip EvaluatePython nodes.
-    case plan: EvaluatePython => plan
-
-    case plan: LogicalPlan if plan.resolved =>
-      // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
-      if (udfs.isEmpty) {
-        // If there aren't any, we are done.
-        plan
-      } else {
-        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
-        // If there is more than one, we will add another evaluation operator in a subsequent pass.
-        udfs.find(_.resolved) match {
-          case Some(udf) =>
-            var evaluation: EvaluatePython = null
-
-            // Rewrite the child that has the input required for the UDF
-            val newChildren = plan.children.map { child =>
-              // Check to make sure that the UDF can be evaluated with only the input of this child.
-              // Other cases are disallowed as they are ambiguous or would require a cartesian
-              // product.
-              if (udf.references.subsetOf(child.outputSet)) {
-                evaluation = EvaluatePython(udf, child)
-                evaluation
-              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
-                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-              } else {
-                child
-              }
-            }
-
-            assert(evaluation != null, "Unable to evaluate PythonUDF.  Missing input attributes.")
-
-            // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
-
-          case None =>
-            // If there is no Python UDF that is resolved, skip this round.
-            plan
-        }
-      }
-  }
+  def output: Seq[Attribute] = child.output :+ resultAttribute
+
+  // References should not include the produced attribute.
+  override def references: AttributeSet = udf.references
 }
 
+
 object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
@@ -221,7 +156,7 @@ object EvaluatePython {
       if (array.length != fields.length) {
         throw new IllegalStateException(
           s"Input row doesn't have expected number of values required by the schema. " +
-          s"${fields.length} fields are required while ${array.length} values are provided."
+            s"${fields.length} fields are required while ${array.length} values are provided."
         )
       }
       new GenericInternalRow(array.zip(fields).map {
@@ -235,7 +170,6 @@ object EvaluatePython {
     case (c, _) => null
   }
 
-
   private val module = "pyspark.sql.types"
 
   /**
@@ -287,7 +221,7 @@ object EvaluatePython {
 
         out.write(Opcodes.MARK)
         var i = 0
-        while (i < row.values.size) {
+        while (i < row.values.length) {
           pickler.save(row.values(i))
           i += 1
         }
@@ -298,6 +232,7 @@ object EvaluatePython {
   }
 
   private[this] var registered = false
+
   /**
    * This should be called before trying to serialize any above classes un cluster mode,
    * this should be put in the closure
@@ -324,91 +259,3 @@ object EvaluatePython {
     }
   }
 }
-
-/**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
-    udf: PythonUDF,
-    child: LogicalPlan,
-    resultAttribute: AttributeReference)
-  extends logical.UnaryNode {
-
-  def output: Seq[Attribute] = child.output :+ resultAttribute
-
-  // References should not include the produced attribute.
-  override def references: AttributeSet = udf.references
-}
-
-/**
- * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and eventually run out of memory.
- */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
-  extends SparkPlan {
-
-  def children: Seq[SparkPlan] = child :: Nil
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute().map(_.copy())
-    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
-    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
-    inputRDD.mapPartitions { iter =>
-      EvaluatePython.registerPicklers()  // register pickler for Row
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(udf.children, child.output)()
-      val fields = udf.children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
-
-      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-      // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
-
-      val context = TaskContext.get()
-
-      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        udf.command,
-        udf.envVars,
-        udf.pythonIncludes,
-        udf.pythonExec,
-        udf.pythonVer,
-        udf.broadcastVars,
-        udf.accumulator,
-        bufferSize,
-        reuseWorker
-      ).compute(inputIterator, context.partitionId(), context)
-
-      val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
-      val joined = new JoinedRow
-      val resultProj = UnsafeProjection.create(output, output)
-
-      outputIterator.flatMap { pickedResult =>
-        val unpickledBatch = unpickle.loads(pickedResult)
-        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-      }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
-        resultProj(joined(queue.poll(), row))
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
new file mode 100644
index 0000000000..6e76e9569f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -0,0 +1,79 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
+ * alone in a batch.
+ *
+ * This has the limitation that the input to the Python UDF is not allowed include attributes from
+ * multiple child operators.
+ */
+private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    // Skip EvaluatePython nodes.
+    case plan: EvaluatePython => plan
+
+    case plan: LogicalPlan if plan.resolved =>
+      // Extract any PythonUDFs from the current operator.
+      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      if (udfs.isEmpty) {
+        // If there aren't any, we are done.
+        plan
+      } else {
+        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
+        // If there is more than one, we will add another evaluation operator in a subsequent pass.
+        udfs.find(_.resolved) match {
+          case Some(udf) =>
+            var evaluation: EvaluatePython = null
+
+            // Rewrite the child that has the input required for the UDF
+            val newChildren = plan.children.map { child =>
+              // Check to make sure that the UDF can be evaluated with only the input of this child.
+              // Other cases are disallowed as they are ambiguous or would require a cartesian
+              // product.
+              if (udf.references.subsetOf(child.outputSet)) {
+                evaluation = EvaluatePython(udf, child)
+                evaluation
+              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
+                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+              } else {
+                child
+              }
+            }
+
+            assert(evaluation != null, "Unable to evaluate PythonUDF.  Missing input attributes.")
+
+            // Trim away the new UDF value if it was only used for filtering or something.
+            logical.Project(
+              plan.output,
+              plan.transformExpressions {
+                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+              }.withNewChildren(newChildren))
+
+          case None =>
+            // If there is no Python UDF that is resolved, skip this round.
+            plan
+        }
+      }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
new file mode 100644
index 0000000000..0e53a0c473
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.{Accumulator, Logging}
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A serialized version of a Python lambda function.
+ */
+case class PythonUDF(
+    name: String,
+    command: Array[Byte],
+    envVars: java.util.Map[String, String],
+    pythonIncludes: java.util.List[String],
+    pythonExec: String,
+    pythonVer: String,
+    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    dataType: DataType,
+    children: Seq[Expression]) extends Expression with Unevaluable with Logging {
+
+  override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+
+  override def nullable: Boolean = true
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
new file mode 100644
index 0000000000..79ac1c85c0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A user-defined Python function. This is used by the Python API.
+ */
+case class UserDefinedPythonFunction(
+    name: String,
+    command: Array[Byte],
+    envVars: java.util.Map[String, String],
+    pythonIncludes: java.util.List[String],
+    pythonExec: String,
+    pythonVer: String,
+    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    dataType: DataType) {
+
+  def builder(e: Seq[Expression]): PythonUDF = {
+    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+      accumulator, dataType, e)
+  }
+
+  /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
+  def apply(exprs: Column*): Column = {
+    val udf = builder(exprs.map(_.expr))
+    Column(udf)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
similarity index 57%
rename from sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 2fb3bf07aa..bd35d19aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -15,16 +15,12 @@
 * limitations under the License.
 */
 
-package org.apache.spark.sql
+package org.apache.spark.sql.expressions
 
-import java.util.{List => JList, Map => JMap}
-
-import org.apache.spark.Accumulator
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -50,30 +46,3 @@ case class UserDefinedFunction protected[sql] (
     Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
   }
 }
-
-/**
- * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]].
- * This is used by Python API.
- */
-private[sql] case class UserDefinedPythonFunction(
-    name: String,
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
-    dataType: DataType) {
-
-  def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
-      accumulator, dataType, e)
-  }
-
-  /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
-  def apply(exprs: Column*): Column = {
-    val udf = builder(exprs.map(_.expr))
-    Column(udf)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d34d377ab6..e4ab6b4f23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2433b54ffc..ac174aa6bf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -465,7 +465,7 @@ class HiveContext private[hive](
         catalog.ParquetConversions ::
         catalog.CreateTables ::
         catalog.PreInsertionCasts ::
-        ExtractPythonUDFs ::
+        python.ExtractPythonUDFs ::
         PreInsertCastAndRename ::
         (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)
 
-- 
GitLab