diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d87783efd2d01bc237d09f3ec2757743545c0836..0d8453fb184a30ee6a7cf489f81e2e9fcdeb611d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging { def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler - // TODO: Figure out why flatMap is necessay for pyspark iter.flatMap { row => unpickle.loads(row) match { + // in case of objects are pickled in batch mode case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // Incase the partition doesn't have a collection + // not in batch mode case obj: JMap[String @unchecked, _] => Seq(obj.toMap) } } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e89176823b627c751ce1eb50552c8aea7e58..a6b3277db3266a5213c1992dda10b2f035655865 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -47,12 +47,14 @@ class SQLContext: ... ValueError:... - >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L, - ... "boolean" : True}]) + >>> from datetime import datetime + >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L, + ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1}, + ... "list": [1, 2, 3]}]) >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean)) + ... x.boolean, x.time, x.dict["a"], x.list)) >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True) + (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -88,13 +90,13 @@ class SQLContext: >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] + >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}}, + ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}] True >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] + >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]}, + ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}] True """ if (rdd.__class__ is SchemaRDD): @@ -509,8 +511,8 @@ def _test(): {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, - {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) + {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, + {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd2750b626bdbb1a769dcbb7fafad3b..c178dad662532d22368a04e1bddf383c07a2a3f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext) case c: java.lang.Long => LongType case c: java.lang.Double => DoubleType case c: java.lang.Boolean => BooleanType + case c: java.math.BigDecimal => DecimalType + case c: java.sql.Timestamp => TimestampType + case c: java.util.Calendar => TimestampType case c: java.util.List[_] => ArrayType(typeFor(c.head)) - case c: java.util.Set[_] => ArrayType(typeFor(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head MapType(typeFor(key), typeFor(value)) @@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext) ArrayType(typeFor(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } - val schema = rdd.first().map { case (fieldName, obj) => + val firstRow = rdd.first() + val schema = firstRow.map { case (fieldName, obj) => AttributeReference(fieldName, typeFor(obj), true)() }.toSeq - val rowRdd = rdd.mapPartitions { iter => + def needTransform(obj: Any): Boolean = obj match { + case c: java.util.List[_] => true + case c: java.util.Map[_, _] => true + case c if c.getClass.isArray => true + case c: java.util.Calendar => true + case c => false + } + + // convert JList, JArray into Seq, convert JMap into Map + // convert Calendar into Timestamp + def transform(obj: Any): Any = obj match { + case c: java.util.List[_] => c.map(transform).toSeq + case c: java.util.Map[_, _] => c.map { + case (key, value) => (key, transform(value)) + }.toMap + case c if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(transform).toSeq + case c: java.util.Calendar => + new java.sql.Timestamp(c.getTime().getTime()) + case c => c + } + + val need = firstRow.exists {case (key, value) => needTransform(value)} + val transformed = if (need) { + rdd.mapPartitions { iter => + iter.map { + m => m.map {case (key, value) => (key, transform(value))} + } + } + } else rdd + + val rowRdd = transformed.mapPartitions { iter => iter.map { map => new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 31d27bb4f0571f1a2c34f07c248fce70d62a99df..019ff9d300a18463e9b506161616542843172b51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType} +import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD @@ -376,39 +376,27 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + def toJava(obj: Any, dataType: DataType): Any = dataType match { + case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) + case array: ArrayType => obj match { + case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava + case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + case other => other + } + case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type + }.asJava + // Pyrolite can handle Timestamp + case other => obj + } def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { val fields = structType.fields.map(field => (field.name, field.dataType)) val map: JMap[String, Any] = new java.util.HashMap row.zip(fields).foreach { - case (obj, (attrName, dataType)) => - dataType match { - case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => - val arrayValues = obj match { - case seq: Seq[Any] => - seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[_] => - list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[_] => - set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map { - element => rowToMap(element.asInstanceOf[Row], struct) - } - case other => other - } - map.put(attrName, arrayValues) - case array: ArrayType => { - val arrayValues = obj match { - case seq: Seq[Any] => seq.asJava - case other => other - } - map.put(attrName, arrayValues) - } - case other => map.put(attrName, obj) - } + case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType)) } - map }