From 8c0bfd08fc19fa5de7d77bf8306d19834f907ec0 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Tue, 28 Oct 2014 19:38:16 -0700
Subject: [PATCH] [SPARK-4133] [SQL] [PySpark] type conversionfor python udf

Call Python UDF on ArrayType/MapType/PrimitiveType, the returnType can also be ArrayType/MapType/PrimitiveType.

For StructType, it will act as tuple (without attributes). If returnType is StructType, it also should be tuple.

Author: Davies Liu <davies@databricks.com>

Closes #2973 from davies/udf_array and squashes the following commits:

306956e [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
2c00e43 [Davies Liu] fix merge
11395fa [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
9df50a2 [Davies Liu] address comments
79afb4e [Davies Liu] type conversionfor python udf
---
 python/pyspark/tests.py                       | 16 +++-
 .../org/apache/spark/sql/SQLContext.scala     | 43 +--------
 .../org/apache/spark/sql/SchemaRDD.scala      | 42 +--------
 .../spark/sql/execution/pythonUdfs.scala      | 91 +++++++++++++++++--
 4 files changed, 102 insertions(+), 90 deletions(-)

diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 047d857830..37a128907b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
 from pyspark import shuffle
 
 _have_scipy = False
@@ -690,10 +690,20 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(row[0], 5)
 
     def test_udf2(self):
-        self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
         self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
-        self.assertEqual(u"4", res[0])
+        self.assertEqual(4, res[0])
+
+    def test_udf_with_array_type(self):
+        d = [Row(l=range(3), d={"key": range(5)})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.assertEqual(range(3), l1)
+        self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ca8706ee68..a41a500c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -438,7 +438,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
   private[sql] def applySchemaToPythonRDD(
       rdd: RDD[Array[Any]],
       schema: StructType): SchemaRDD = {
-    import scala.collection.JavaConversions._
 
     def needsConversion(dataType: DataType): Boolean = dataType match {
       case ByteType => true
@@ -452,49 +451,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case other => false
     }
 
-    // Converts value to the type specified by the data type.
-    // Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
-    // and ByteType, we need to explicitly convert values in columns of these data types to the
-    // desired JVM data types.
-    def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-      // TODO: We should check nullable
-      case (null, _) => null
-
-      case (c: java.util.List[_], ArrayType(elementType, _)) =>
-        c.map { e => convert(e, elementType)}: Seq[Any]
-
-      case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any]
-
-      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
-          case (key, value) => (convert(key, keyType), convert(value, valueType))
-        }.toMap
-
-      case (c, StructType(fields)) if c.getClass.isArray =>
-        new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
-          case (e, f) => convert(e, f.dataType)
-        }): Row
-
-      case (c: java.util.Calendar, DateType) =>
-        new java.sql.Date(c.getTime().getTime())
-
-      case (c: java.util.Calendar, TimestampType) =>
-        new java.sql.Timestamp(c.getTime().getTime())
-
-      case (c: Int, ByteType) => c.toByte
-      case (c: Long, ByteType) => c.toByte
-      case (c: Int, ShortType) => c.toShort
-      case (c: Long, ShortType) => c.toShort
-      case (c: Long, IntegerType) => c.toInt
-      case (c: Double, FloatType) => c.toFloat
-      case (c, StringType) if !c.isInstanceOf[String] => c.toString
-
-      case (c, _) => c
-    }
-
     val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
       rdd.map(m => m.zip(schema.fields).map {
-        case (value, field) => convert(value, field.dataType)
+        case (value, field) => EvaluatePython.fromJava(value, field.dataType)
       })
     } else {
       rdd
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 948122d42f..8b96df1096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
 import org.apache.spark.api.java.JavaRDD
 
 /**
@@ -377,47 +377,15 @@ class SchemaRDD(
    */
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
-  /**
-   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
-   */
-  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
-    import scala.collection.Map
-
-    def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-      case (null, _) => null
-
-      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
-
-      case (seq: Seq[Any], array: ArrayType) =>
-        seq.map(x => toJava(x, array.elementType)).asJava
-      case (list: JList[_], array: ArrayType) =>
-        list.map(x => toJava(x, array.elementType)).asJava
-      case (arr, array: ArrayType) if arr.getClass.isArray =>
-        arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
-
-      case (obj: Map[_, _], mt: MapType) => obj.map {
-        case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
-      }.asJava
-
-      // Pyrolite can handle Timestamp
-      case (other, _) => other
-    }
-
-    val fields = structType.fields.map(field => field.dataType)
-    row.zip(fields).map {
-      case (obj, dataType) => toJava(obj, dataType)
-    }.toArray
-  }
-
   /**
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
-    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val fieldTypes = schema.fields.map(_.dataType)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToJArray(row, rowSchema)
+        EvaluatePython.rowToArray(row, fieldTypes)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
@@ -427,10 +395,10 @@ class SchemaRDD(
    * format as javaToPython. It is used by pyspark.
    */
   private[sql] def collectToPython: JList[Array[Byte]] = {
-    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val fieldTypes = schema.fields.map(_.dataType)
     val pickle = new Pickler
     new java.util.ArrayList(collect().map { row =>
-      rowToJArray(row, rowSchema)
+      EvaluatePython.rowToArray(row, fieldTypes)
     }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index be729e5d24..a1961bba18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.execution
 
 import java.util.{List => JList, Map => JMap}
 
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
 import net.razorvine.pickle.{Pickler, Unpickler}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.PythonRDD
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -31,8 +34,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.{Accumulator, Logging => SparkLogging}
 
-import scala.collection.JavaConversions._
-
 /**
  * A serialized version of a Python lambda function.  Suitable for use in a [[PythonRDD]].
  */
@@ -108,6 +109,80 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
 object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan) =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+
+  /**
+   * Helper for converting a Scala object to a java suitable for pyspark serialization.
+   */
+  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    case (null, _) => null
+
+    case (row: Row, struct: StructType) =>
+      val fields = struct.fields.map(field => field.dataType)
+      row.zip(fields).map {
+        case (obj, dataType) => toJava(obj, dataType)
+      }.toArray
+
+    case (seq: Seq[Any], array: ArrayType) =>
+      seq.map(x => toJava(x, array.elementType)).asJava
+    case (list: JList[_], array: ArrayType) =>
+      list.map(x => toJava(x, array.elementType)).asJava
+    case (arr, array: ArrayType) if arr.getClass.isArray =>
+      arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+    case (obj: Map[_, _], mt: MapType) => obj.map {
+      case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+    }.asJava
+
+    // Pyrolite can handle Timestamp
+    case (other, _) => other
+  }
+
+  /**
+   * Convert Row into Java Array (for pickled into Python)
+   */
+  def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+    row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+  }
+
+  // Converts value to the type specified by the data type.
+  // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+  // ByteType, we need to explicitly convert values in columns of these data types to the desired
+  // JVM data types.
+  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    // TODO: We should check nullable
+    case (null, _) => null
+
+    case (c: java.util.List[_], ArrayType(elementType, _)) =>
+      c.map { e => fromJava(e, elementType)}: Seq[Any]
+
+    case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+      c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+
+    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
+      case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
+    }.toMap
+
+    case (c, StructType(fields)) if c.getClass.isArray =>
+      new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
+        case (e, f) => fromJava(e, f.dataType)
+      }): Row
+
+    case (c: java.util.Calendar, DateType) =>
+      new java.sql.Date(c.getTime().getTime())
+
+    case (c: java.util.Calendar, TimestampType) =>
+      new java.sql.Timestamp(c.getTime().getTime())
+
+    case (c: Int, ByteType) => c.toByte
+    case (c: Long, ByteType) => c.toByte
+    case (c: Int, ShortType) => c.toShort
+    case (c: Long, ShortType) => c.toShort
+    case (c: Long, IntegerType) => c.toInt
+    case (c: Double, FloatType) => c.toFloat
+    case (c, StringType) if !c.isInstanceOf[String] => c.toString
+
+    case (c, _) => c
+  }
 }
 
 /**
@@ -141,8 +216,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     val parent = childResults.mapPartitions { iter =>
       val pickle = new Pickler
       val currentRow = newMutableProjection(udf.children, child.output)()
+      val fields = udf.children.map(_.dataType)
       iter.grouped(1000).map { inputRows =>
-        val toBePickled = inputRows.map(currentRow(_).toArray).toArray
+        val toBePickled = inputRows.map { row =>
+          EvaluatePython.rowToArray(currentRow(row), fields)
+        }.toArray
         pickle.dumps(toBePickled)
       }
     }
@@ -165,10 +243,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     }.mapPartitions { iter =>
       val row = new GenericMutableRow(1)
       iter.map { result =>
-        row(0) = udf.dataType match {
-          case StringType => result.toString
-          case other => result
-        }
+        row(0) = EvaluatePython.fromJava(result, udf.dataType)
         row: Row
       }
     }
-- 
GitLab