diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 7daf306f6847908ca7df8f9c29135b01c980b8c9..93fd9d49096b8cea70f6fc21ba4de7ac7102309c 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -49,7 +49,7 @@ from pyspark.traceback_utils import SCCallSiteSync
 
 
 __all__ = [
-    "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+    "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
     "ShortType", "ArrayType", "MapType", "StructField", "StructType",
     "SQLContext", "HiveContext", "SchemaRDD", "Row"]
@@ -132,6 +132,14 @@ class BooleanType(PrimitiveType):
     """
 
 
+class DateType(PrimitiveType):
+
+    """Spark SQL DateType
+
+    The data type representing datetime.date values.
+    """
+
+
 class TimestampType(PrimitiveType):
 
     """Spark SQL TimestampType
@@ -438,7 +446,7 @@ def _parse_datatype_json_value(json_value):
         return _all_complex_types[json_value["type"]].fromJson(json_value)
 
 
-# Mapping Python types to Spark SQL DateType
+# Mapping Python types to Spark SQL DataType
 _type_mappings = {
     bool: BooleanType,
     int: IntegerType,
@@ -448,8 +456,8 @@ _type_mappings = {
     unicode: StringType,
     bytearray: BinaryType,
     decimal.Decimal: DecimalType,
+    datetime.date: DateType,
     datetime.datetime: TimestampType,
-    datetime.date: TimestampType,
     datetime.time: TimestampType,
 }
@@ -656,10 +664,10 @@ def _infer_schema_type(obj, dataType):
     """
     Fill the dataType with types infered from obj
 
-    >>> schema = _parse_schema_abstract("a b c")
-    >>> row = (1, 1.0, "str")
+    >>> schema = _parse_schema_abstract("a b c d")
+    >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
     >>> _infer_schema_type(row, schema)
-    StructType...IntegerType...DoubleType...StringType...
+    StructType...IntegerType...DoubleType...StringType...DateType...
     >>> row = [[1], {"key": (1, 2.0)}]
     >>> schema = _parse_schema_abstract("a[] b{c d}")
     >>> _infer_schema_type(row, schema)
@@ -703,6 +711,7 @@ _acceptable_types = {
     DecimalType: (decimal.Decimal,),
     StringType: (str, unicode),
     BinaryType: (bytearray,),
+    DateType: (datetime.date,),
     TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
@@ -740,7 +749,7 @@ def _verify_type(obj, dataType):
     # subclass of them can not be deserialized in JVM
     if type(obj) not in _acceptable_types[_type]:
-        raise TypeError("%s can not accept abject in type %s"
+        raise TypeError("%s can not accept object in type %s"
                         % (dataType, type(obj)))
 
     if isinstance(dataType, ArrayType):
@@ -767,7 +776,7 @@ def _restore_object(dataType, obj):
     """ Restore object during unpickling. """
     # use id(dataType) as key to speed up lookup in dict
     # Because of batched pickling, dataType will be the
-    # same object in mose cases.
+    # same object in most cases.
     k = id(dataType)
     cls = _cached_cls.get(k)
     if cls is None:
@@ -782,6 +791,10 @@ def _restore_object(dataType, obj):
 
 def _create_object(cls, v):
     """ Create an customized object with class `cls`. """
+    # datetime.date would be deserialized as datetime.datetime
+    # from java type, so we need to set it back.
+    if cls is datetime.date and isinstance(v, datetime.datetime):
+        return v.date()
     return cls(v) if v is not None else v
@@ -795,14 +808,16 @@ def _create_getter(dt, i):
     return getter
 
 
-def _has_struct(dt):
-    """Return whether `dt` is or has StructType in it"""
+def _has_struct_or_date(dt):
+    """Return whether `dt` is or has StructType/DateType in it"""
     if isinstance(dt, StructType):
         return True
     elif isinstance(dt, ArrayType):
-        return _has_struct(dt.elementType)
+        return _has_struct_or_date(dt.elementType)
     elif isinstance(dt, MapType):
-        return _has_struct(dt.valueType)
+        return _has_struct_or_date(dt.valueType)
+    elif isinstance(dt, DateType):
+        return True
     return False
@@ -815,7 +830,7 @@ def _create_properties(fields):
                 or keyword.iskeyword(name)):
             warnings.warn("field name %s can not be accessed in Python,"
                           "use position to access it instead" % name)
-        if _has_struct(f.dataType):
+        if _has_struct_or_date(f.dataType):
             # delay creating object until accessing it
             getter = _create_getter(f.dataType, i)
         else:
@@ -870,6 +885,9 @@ def _create_cls(dataType):
 
         return Dict
 
+    elif isinstance(dataType, DateType):
+        return datetime.date
+
     elif not isinstance(dataType, StructType):
         raise Exception("unexpected data type: %s" % dataType)
@@ -1068,8 +1086,9 @@ class SQLContext(object):
         >>> srdd2.collect()
         [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
 
-        >>> from datetime import datetime
+        >>> from datetime import date, datetime
         >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+        ...     date(2010, 1, 1),
         ...     datetime(2010, 1, 1, 1, 1, 1),
         ...     {"a": 1}, (2,), [1, 2, 3], None)])
         >>> schema = StructType([
@@ -1079,6 +1098,7 @@
         ...     StructField("short2", ShortType(), False),
         ...     StructField("int", IntegerType(), False),
         ...     StructField("float", FloatType(), False),
+        ...     StructField("date", DateType(), False),
         ...     StructField("time", TimestampType(), False),
         ...     StructField("map",
        ...         MapType(StringType(), IntegerType(), False), False),
@@ -1088,10 +1108,11 @@
         ...     StructField("null", DoubleType(), True)])
         >>> srdd = sqlCtx.applySchema(rdd, schema)
         >>> results = srdd.map(
-        ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
-        ...         x.map["a"], x.struct.b, x.list, x.null))
-        >>> results.collect()[0]
-        (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+        ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
+        ...         x.time, x.map["a"], x.struct.b, x.list, x.null))
+        >>> results.collect()[0]  # doctest: +NORMALIZE_WHITESPACE
+        (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
+            datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
 
         >>> srdd.registerTempTable("table2")
         >>> sqlCtx.sql(
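Note on the Python change above: with `datetime.date` now mapped to `DateType` in `_type_mappings`, schema inference stops folding dates into `TimestampType`. A minimal sketch of the expected behaviour, assuming a running `SparkContext` `sc` and `SQLContext` `sqlCtx` as in the doctests (names and output are illustrative, not part of the patch):

```python
# Sketch only: `sc` and `sqlCtx` are assumed to exist, as in the doctests above.
import datetime

from pyspark.sql import Row

rdd = sc.parallelize([Row(name="Alice", birthday=datetime.date(2014, 10, 10))])
srdd = sqlCtx.inferSchema(rdd)
srdd.printSchema()  # the birthday column should now be inferred as a date, not a timestamp
srdd.collect()      # expected: [Row(birthday=datetime.date(2014, 10, 10), name=u'Alice')]
```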
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 7d930fccd52d18d3afa5d46455c565c4e9f68a40..d76c743d3f652072182cf80a9e2f8305edb72a5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -112,6 +112,7 @@ object ScalaReflection {
     case obj: FloatType.JvmType => FloatType
     case obj: DoubleType.JvmType => DoubleType
     case obj: DecimalType.JvmType => DecimalType
+    case obj: DateType.JvmType => DateType
     case obj: TimestampType.JvmType => TimestampType
     case null => NullType
     // For other cases, there is no obvious mapping from the type of the given object to a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 0cf139ebde417c25881730eb032a9ea49573c74d..b9cf37d53ffd29fd5e61f8a281dc653c3c5658e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -91,6 +91,7 @@ object DataType {
     | "BinaryType" ^^^ BinaryType
     | "BooleanType" ^^^ BooleanType
     | "DecimalType" ^^^ DecimalType
+    | "DateType" ^^^ DateType
     | "TimestampType" ^^^ TimestampType
     )
@@ -198,7 +199,8 @@ trait PrimitiveType extends DataType {
 }
 
 object PrimitiveType {
-  private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
+  private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
+    NativeType.all
 
   private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 488e373854bb3c5d28f22c546859e5728fe3cf8c..430f0664b7d58016746e60a5c5f71ca71f96160b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import java.math.BigInteger
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 import org.scalatest.FunSuite
@@ -43,6 +43,7 @@ case class NullableData(
     booleanField: java.lang.Boolean,
     stringField: String,
     decimalField: BigDecimal,
+    dateField: Date,
     timestampField: Timestamp,
     binaryField: Array[Byte])
@@ -96,6 +97,7 @@ class ScalaReflectionSuite extends FunSuite {
         StructField("booleanField", BooleanType, nullable = true),
         StructField("stringField", StringType, nullable = true),
         StructField("decimalField", DecimalType, nullable = true),
+        StructField("dateField", DateType, nullable = true),
         StructField("timestampField", TimestampType, nullable = true),
         StructField("binaryField", BinaryType, nullable = true))),
       nullable = true))
@@ -199,8 +201,11 @@ class ScalaReflectionSuite extends FunSuite {
     // DecimalType
     assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))
 
+    // DateType
+    assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
+
     // TimestampType
-    assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00")))
+    assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00")))
 
     // NullType
     assert(NullType === typeOfObject(null))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c4f4ef01d78dfb64ab24b8c80b91c29ff5a34718..ca8706ee68697fe49624eb4ddccb7b8dc78c90a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -444,6 +444,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ByteType => true
       case ShortType => true
       case FloatType => true
+      case DateType => true
       case TimestampType => true
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
@@ -452,9 +453,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
     }
 
     // Converts value to the type specified by the data type.
-    // Because Python does not have data types for TimestampType, FloatType, ShortType, and
-    // ByteType, we need to explicitly convert values in columns of these data types to the desired
-    // JVM data types.
+    // Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
+    // and ByteType, we need to explicitly convert values in columns of these data types to the
+    // desired JVM data types.
     def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
       // TODO: We should check nullable
       case (null, _) => null
@@ -474,6 +475,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
         case (e, f) => convert(e, f.dataType)
       }): Row
 
+      case (c: java.util.Calendar, DateType) =>
+        new java.sql.Date(c.getTime().getTime())
+
       case (c: java.util.Calendar, TimestampType) =>
         new java.sql.Timestamp(c.getTime().getTime())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index bf32da1b7181e14eaa53a76d901e928b58db5171..047dc85df6c1dd84d0328b401076140cfd2f677d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.json
 import scala.collection.Map
 import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
 import scala.math.BigDecimal
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 import com.fasterxml.jackson.core.JsonProcessingException
 import com.fasterxml.jackson.databind.ObjectMapper
@@ -372,13 +372,20 @@ private[sql] object JsonRDD extends Logging {
     }
   }
 
+  private def toDate(value: Any): Date = {
+    value match {
+      // only support string as date
+      case value: java.lang.String => Date.valueOf(value)
+    }
+  }
+
   private def toTimestamp(value: Any): Timestamp = {
     value match {
-        case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
-        case value: java.lang.Long => new Timestamp(value)
-        case value: java.lang.String => Timestamp.valueOf(value)
-    }
-  }
+      case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
+      case value: java.lang.Long => new Timestamp(value)
+      case value: java.lang.String => Timestamp.valueOf(value)
+    }
+  }
 
   private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
     if (value == null) {
@@ -396,6 +403,7 @@ private[sql] object JsonRDD extends Logging {
       case ArrayType(elementType, _) =>
         value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
       case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
+      case DateType => toDate(value)
       case TimestampType => toTimestamp(value)
     }
   }
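Note on the JSON change above: when a user-supplied schema marks a string field as `DateType`, `JsonRDD.enforceCorrectType` now routes it through `toDate`, i.e. `java.sql.Date.valueOf`, so the string must use the `yyyy-MM-dd` form. A hedged sketch from the Python side (again assuming `sc` and `sqlCtx`; the field names and data are illustrative):

```python
# Sketch only: read a JSON string field as DateType via an explicit schema.
from pyspark.sql import StructType, StructField, StringType, DateType

json_rdd = sc.parallelize(['{"name": "a", "day": "2014-10-15"}'])
schema = StructType([
    StructField("name", StringType(), True),
    StructField("day", DateType(), True),  # parsed with Date.valueOf("2014-10-15")
])
srdd = sqlCtx.jsonRDD(json_rdd, schema)
srdd.collect()  # expected: [Row(name=u'a', day=datetime.date(2014, 10, 15))]
```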
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
index 52d07b5425cc37afa419138e37431f2415e9f8d5..bc5cd66482add6cb6f38950a3834880fc127eea8 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.api.java;
 
 import java.math.BigDecimal;
+import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -39,6 +40,7 @@ public class JavaRowSuite {
   private boolean booleanValue;
   private String stringValue;
   private byte[] binaryValue;
+  private Date dateValue;
   private Timestamp timestampValue;
 
   @Before
@@ -53,6 +55,7 @@ public class JavaRowSuite {
     booleanValue = true;
     stringValue = "this is a string";
     binaryValue = stringValue.getBytes();
+    dateValue = Date.valueOf("2014-06-30");
     timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0");
   }
 
@@ -76,6 +79,7 @@ public class JavaRowSuite {
       new Boolean(booleanValue),
       stringValue, // StringType
       binaryValue, // BinaryType
+      dateValue, // DateType
       timestampValue, // TimestampType
       null // null
     );
@@ -114,9 +118,10 @@ public class JavaRowSuite {
     Assert.assertEquals(stringValue, simpleRow.getString(15));
     Assert.assertEquals(stringValue, simpleRow.get(15));
     Assert.assertEquals(binaryValue, simpleRow.get(16));
-    Assert.assertEquals(timestampValue, simpleRow.get(17));
-    Assert.assertEquals(true, simpleRow.isNullAt(18));
-    Assert.assertEquals(null, simpleRow.get(18));
+    Assert.assertEquals(dateValue, simpleRow.get(17));
+    Assert.assertEquals(timestampValue, simpleRow.get(18));
+    Assert.assertEquals(true, simpleRow.isNullAt(19));
+    Assert.assertEquals(null, simpleRow.get(19));
   }
 
   @Test
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
index d099a48a1f4b6c06f5f3b459feae72dd557790ce..d04396a5f8ec20cd9bb9396fb2da96aa7721b763 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -39,6 +39,7 @@ public class JavaSideDataTypeConversionSuite {
     checkDataType(DataType.StringType);
     checkDataType(DataType.BinaryType);
     checkDataType(DataType.BooleanType);
+    checkDataType(DataType.DateType);
     checkDataType(DataType.TimestampType);
     checkDataType(DataType.DecimalType);
     checkDataType(DataType.DoubleType);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index ff1debff0f8c100e78b498c50e0dc90b7784a859..8415af41be3afd75fb8ef3cdd44f8c5211244171 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -38,6 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
     checkDataType(org.apache.spark.sql.StringType)
     checkDataType(org.apache.spark.sql.BinaryType)
     checkDataType(org.apache.spark.sql.BooleanType)
+    checkDataType(org.apache.spark.sql.DateType)
     checkDataType(org.apache.spark.sql.TimestampType)
     checkDataType(org.apache.spark.sql.DecimalType)
     checkDataType(org.apache.spark.sql.DoubleType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1ae75546aada14b8b11705c0e018bdf30a42b4fd..ce6184f5d8c9db0e5d3a257fdfb01dce6172968d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 class JsonSuite extends QueryTest {
   import TestJsonData._
@@ -58,8 +58,11 @@ class JsonSuite extends QueryTest {
     checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
     checkTypePromotion(new Timestamp(intNumber.toLong),
         enforceCorrectType(intNumber.toLong, TimestampType))
-    val strDate = "2014-09-30 12:34:56"
-    checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType))
+    val strTime = "2014-09-30 12:34:56"
+    checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
+
+    val strDate = "2014-10-15"
+    checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType))
   }
 
   test("Get compatible type") {
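Taken together, these changes let a `DateType` column survive a full PySpark round trip through inference, pickling, and SQL. A hedged end-to-end sketch (the table name and data are illustrative, and assume the same `sc`/`sqlCtx` as above):

```python
# Sketch only: infer a DateType column, register the table, and query it.
import datetime

from pyspark.sql import Row

rdd = sc.parallelize([Row(day=datetime.date(2014, 10, 10)),
                      Row(day=datetime.date(2015, 1, 1))])
srdd = sqlCtx.inferSchema(rdd)
srdd.registerTempTable("events")
sqlCtx.sql("SELECT day FROM events").collect()
# expected: [Row(day=datetime.date(2014, 10, 10)), Row(day=datetime.date(2015, 1, 1))]
```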