From 8c0bfd08fc19fa5de7d77bf8306d19834f907ec0 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Tue, 28 Oct 2014 19:38:16 -0700
Subject: [PATCH] [SPARK-4133] [SQL] [PySpark] type conversion for Python UDF

When a Python UDF is called on columns of ArrayType, MapType, or a primitive type, its returnType can also be ArrayType, MapType, or a primitive type. A StructType value is passed to the UDF as a tuple (without attribute names); if returnType is StructType, the UDF should likewise return a tuple.

Author: Davies Liu <davies@databricks.com>

Closes #2973 from davies/udf_array and squashes the following commits:

306956e [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
2c00e43 [Davies Liu] fix merge
11395fa [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
9df50a2 [Davies Liu] address comments
79afb4e [Davies Liu] type conversionfor python udf
---
 python/pyspark/tests.py                       | 16 +++-
 .../org/apache/spark/sql/SQLContext.scala     | 43 +--------
 .../org/apache/spark/sql/SchemaRDD.scala      | 42 +--------
 .../spark/sql/execution/pythonUdfs.scala      | 91 +++++++++++++++++--
 4 files changed, 102 insertions(+), 90 deletions(-)

diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 047d857830..37a128907b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
 from pyspark import shuffle
 
 _have_scipy = False
@@ -690,10 +690,20 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(row[0], 5)
 
     def test_udf2(self):
-        self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
         self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
-        self.assertEqual(u"4", res[0])
+        self.assertEqual(4, res[0])
+
+    def test_udf_with_array_type(self):
+        d = [Row(l=range(3), d={"key": range(5)})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.assertEqual(range(3), l1)
+        self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ca8706ee68..a41a500c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -438,7 +438,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
   private[sql] def applySchemaToPythonRDD(
       rdd: RDD[Array[Any]],
       schema: StructType): SchemaRDD = {
-    import scala.collection.JavaConversions._
 
     def needsConversion(dataType: DataType): Boolean = dataType match {
       case ByteType => true
@@ -452,49 +451,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case other => false
     }
 
-    // Converts value to the type specified by the data type.
-    // Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
-    // and ByteType, we need to explicitly convert values in columns of these data types to the
-    // desired JVM data types.
-    def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-      // TODO: We should check nullable
-      case (null, _) => null
-
-      case (c: java.util.List[_], ArrayType(elementType, _)) =>
-        c.map { e => convert(e, elementType)}: Seq[Any]
-
-      case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any]
-
-      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
-        case (key, value) => (convert(key, keyType), convert(value, valueType))
-      }.toMap
-
-      case (c, StructType(fields)) if c.getClass.isArray =>
-        new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
-          case (e, f) => convert(e, f.dataType)
-        }): Row
-
-      case (c: java.util.Calendar, DateType) =>
-        new java.sql.Date(c.getTime().getTime())
-
-      case (c: java.util.Calendar, TimestampType) =>
-        new java.sql.Timestamp(c.getTime().getTime())
-
-      case (c: Int, ByteType) => c.toByte
-      case (c: Long, ByteType) => c.toByte
-      case (c: Int, ShortType) => c.toShort
-      case (c: Long, ShortType) => c.toShort
-      case (c: Long, IntegerType) => c.toInt
-      case (c: Double, FloatType) => c.toFloat
-      case (c, StringType) if !c.isInstanceOf[String] => c.toString
-
-      case (c, _) => c
-    }
-
     val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
       rdd.map(m => m.zip(schema.fields).map {
-        case (value, field) => convert(value, field.dataType)
+        case (value, field) => EvaluatePython.fromJava(value, field.dataType)
       })
     } else {
       rdd
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 948122d42f..8b96df1096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
 import org.apache.spark.api.java.JavaRDD
 
 /**
@@ -377,47 +377,15 @@ class SchemaRDD(
    */
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
-  /**
-   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
-   */
-  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
-    import scala.collection.Map
-
-    def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-      case (null, _) => null
-
-      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
-
-      case (seq: Seq[Any], array: ArrayType) =>
-        seq.map(x => toJava(x, array.elementType)).asJava
-      case (list: JList[_], array: ArrayType) =>
-        list.map(x => toJava(x, array.elementType)).asJava
-      case (arr, array: ArrayType) if arr.getClass.isArray =>
-        arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
-
-      case (obj: Map[_, _], mt: MapType) => obj.map {
-        case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
-      }.asJava
-
-      // Pyrolite can handle Timestamp
-      case (other, _) => other
-    }
-
-    val fields = structType.fields.map(field => field.dataType)
-    row.zip(fields).map {
-      case (obj, dataType) => toJava(obj, dataType)
-    }.toArray
-  }
-
   /**
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
-    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val fieldTypes = schema.fields.map(_.dataType)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToJArray(row, rowSchema)
+        EvaluatePython.rowToArray(row, fieldTypes)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
@@ -427,10 +395,10 @@ class SchemaRDD(
    * format as javaToPython. It is used by pyspark.
    */
   private[sql] def collectToPython: JList[Array[Byte]] = {
-    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val fieldTypes = schema.fields.map(_.dataType)
     val pickle = new Pickler
     new java.util.ArrayList(collect().map { row =>
-      rowToJArray(row, rowSchema)
+      EvaluatePython.rowToArray(row, fieldTypes)
     }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index be729e5d24..a1961bba18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.execution
 
 import java.util.{List => JList, Map => JMap}
 
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.PythonRDD
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -31,8 +34,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.{Accumulator, Logging => SparkLogging}
 
-import scala.collection.JavaConversions._
-
 /**
  * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
  */
@@ -108,6 +109,80 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
 object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan) =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+
+  /**
+   * Helper for converting a Scala object into a Java object suitable for pyspark serialization.
+   */
+  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    case (null, _) => null
+
+    case (row: Row, struct: StructType) =>
+      val fields = struct.fields.map(field => field.dataType)
+      row.zip(fields).map {
+        case (obj, dataType) => toJava(obj, dataType)
+      }.toArray
+
+    case (seq: Seq[Any], array: ArrayType) =>
+      seq.map(x => toJava(x, array.elementType)).asJava
+    case (list: JList[_], array: ArrayType) =>
+      list.map(x => toJava(x, array.elementType)).asJava
+    case (arr, array: ArrayType) if arr.getClass.isArray =>
+      arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+    case (obj: Map[_, _], mt: MapType) => obj.map {
+      case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+    }.asJava
+
+    // Pyrolite can handle Timestamp
+    case (other, _) => other
+  }
+
+  /**
+   * Convert a Row into a Java Array (for pickling into Python).
+   */
+  def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+    row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+  }
+
+  // Converts value to the type specified by the data type.
+  // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+  // ByteType, we need to explicitly convert values in columns of these data types to the desired
+  // JVM data types.
+  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    // TODO: We should check nullable
+    case (null, _) => null
+
+    case (c: java.util.List[_], ArrayType(elementType, _)) =>
+      c.map { e => fromJava(e, elementType)}: Seq[Any]
+
+    case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+      c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+
+    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
+      case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
+    }.toMap
+
+    case (c, StructType(fields)) if c.getClass.isArray =>
+      new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
+        case (e, f) => fromJava(e, f.dataType)
+      }): Row
+
+    case (c: java.util.Calendar, DateType) =>
+      new java.sql.Date(c.getTime().getTime())
+
+    case (c: java.util.Calendar, TimestampType) =>
+      new java.sql.Timestamp(c.getTime().getTime())
+
+    case (c: Int, ByteType) => c.toByte
+    case (c: Long, ByteType) => c.toByte
+    case (c: Int, ShortType) => c.toShort
+    case (c: Long, ShortType) => c.toShort
+    case (c: Long, IntegerType) => c.toInt
+    case (c: Double, FloatType) => c.toFloat
+    case (c, StringType) if !c.isInstanceOf[String] => c.toString
+
+    case (c, _) => c
+  }
 }
 
 /**
@@ -141,8 +216,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     val parent = childResults.mapPartitions { iter =>
       val pickle = new Pickler
       val currentRow = newMutableProjection(udf.children, child.output)()
+      val fields = udf.children.map(_.dataType)
       iter.grouped(1000).map { inputRows =>
-        val toBePickled = inputRows.map(currentRow(_).toArray).toArray
+        val toBePickled = inputRows.map { row =>
+          EvaluatePython.rowToArray(currentRow(row), fields)
+        }.toArray
         pickle.dumps(toBePickled)
       }
     }
@@ -165,10 +243,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     }.mapPartitions { iter =>
       val row = new GenericMutableRow(1)
       iter.map { result =>
-        row(0) = udf.dataType match {
-          case StringType => result.toString
-          case other => result
-        }
+        row(0) = EvaluatePython.fromJava(result, udf.dataType)
         row: Row
       }
     }
-- 
GitLab
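
For reference (not part of the patch): a minimal PySpark usage sketch of the new behavior, mirroring the test_udf_with_array_type test added above. It assumes the 1.x-era pyspark.sql API and an already-created SparkContext named `sc`:

    from pyspark.sql import SQLContext, Row, ArrayType, IntegerType

    sqlCtx = SQLContext(sc)
    # Infer a schema with an array column and register it as a temp table.
    sqlCtx.inferSchema(sc.parallelize([Row(l=[1, 2, 3])])).registerTempTable("t")
    # The UDF receives the ArrayType column as a Python sequence; declaring the
    # returnType as ArrayType(IntegerType()) lets it return a list as well.
    sqlCtx.registerFunction("rev", lambda l: list(reversed(l)), ArrayType(IntegerType()))
    print(sqlCtx.sql("SELECT rev(l) FROM t").collect())  # one row containing [3, 2, 1]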