Skip to content
Snippets Groups Projects
Commit adfd3668 authored by Davies Liu's avatar Davies Liu Committed by Reynold Xin
Browse files

[SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python

Author: Davies Liu <davies@databricks.com>

Closes #6206 from davies/sql_type and squashes the following commits:

33d6860 [Davies Liu] [SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python
parent cc12a86f
No related branches found
No related tags found
No related merge requests found
...@@ -73,56 +73,74 @@ class DataType(object): ...@@ -73,56 +73,74 @@ class DataType(object):
# This singleton pattern does not work with pickle, you will get # This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle # another object after pickle and unpickle
class PrimitiveTypeSingleton(type): class DataTypeSingleton(type):
"""Metaclass for PrimitiveType""" """Metaclass for DataType"""
_instances = {} _instances = {}
def __call__(cls): def __call__(cls):
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
return cls._instances[cls] return cls._instances[cls]
class PrimitiveType(DataType): class NullType(DataType):
"""Spark SQL PrimitiveType""" """Null type.
__metaclass__ = PrimitiveTypeSingleton The data type representing None, used for the types that cannot be inferred.
"""
__metaclass__ = DataTypeSingleton
class NullType(PrimitiveType):
"""Null type.
The data type representing None, used for the types that cannot be inferred. class AtomicType(DataType):
"""An internal type used to represent everything that is not
null, UDTs, arrays, structs, and maps."""
__metaclass__ = DataTypeSingleton
class NumericType(AtomicType):
"""Numeric data types.
""" """
class StringType(PrimitiveType): class IntegralType(NumericType):
"""Integral data types.
"""
class FractionalType(NumericType):
"""Fractional data types.
"""
class StringType(AtomicType):
"""String data type. """String data type.
""" """
class BinaryType(PrimitiveType): class BinaryType(AtomicType):
"""Binary (byte array) data type. """Binary (byte array) data type.
""" """
class BooleanType(PrimitiveType): class BooleanType(AtomicType):
"""Boolean data type. """Boolean data type.
""" """
class DateType(PrimitiveType): class DateType(AtomicType):
"""Date (datetime.date) data type. """Date (datetime.date) data type.
""" """
class TimestampType(PrimitiveType): class TimestampType(AtomicType):
"""Timestamp (datetime.datetime) data type. """Timestamp (datetime.datetime) data type.
""" """
class DecimalType(DataType): class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type. """Decimal (decimal.Decimal) data type.
""" """
...@@ -150,31 +168,31 @@ class DecimalType(DataType): ...@@ -150,31 +168,31 @@ class DecimalType(DataType):
return "DecimalType()" return "DecimalType()"
class DoubleType(PrimitiveType): class DoubleType(FractionalType):
"""Double data type, representing double precision floats. """Double data type, representing double precision floats.
""" """
class FloatType(PrimitiveType): class FloatType(FractionalType):
"""Float data type, representing single precision floats. """Float data type, representing single precision floats.
""" """
class ByteType(PrimitiveType): class ByteType(IntegralType):
"""Byte data type, i.e. a signed integer in a single byte. """Byte data type, i.e. a signed integer in a single byte.
""" """
def simpleString(self): def simpleString(self):
return 'tinyint' return 'tinyint'
class IntegerType(PrimitiveType): class IntegerType(IntegralType):
"""Int data type, i.e. a signed 32-bit integer. """Int data type, i.e. a signed 32-bit integer.
""" """
def simpleString(self): def simpleString(self):
return 'int' return 'int'
class LongType(PrimitiveType): class LongType(IntegralType):
"""Long data type, i.e. a signed 64-bit integer. """Long data type, i.e. a signed 64-bit integer.
If the values are beyond the range of [-9223372036854775808, 9223372036854775807], If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
...@@ -184,7 +202,7 @@ class LongType(PrimitiveType): ...@@ -184,7 +202,7 @@ class LongType(PrimitiveType):
return 'bigint' return 'bigint'
class ShortType(PrimitiveType): class ShortType(IntegralType):
"""Short data type, i.e. a signed 16-bit integer. """Short data type, i.e. a signed 16-bit integer.
""" """
def simpleString(self): def simpleString(self):
...@@ -426,11 +444,9 @@ class UserDefinedType(DataType): ...@@ -426,11 +444,9 @@ class UserDefinedType(DataType):
return type(self) == type(other) return type(self) == type(other)
_all_primitive_types = dict((v.typeName(), v) _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
for v in list(globals().values()) ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
if (type(v) is type or type(v) is PrimitiveTypeSingleton) _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v) _all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType]) for v in [ArrayType, MapType, StructType])
...@@ -444,7 +460,7 @@ def _parse_datatype_json_string(json_string): ...@@ -444,7 +460,7 @@ def _parse_datatype_json_string(json_string):
... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... assert datatype == python_datatype ... assert datatype == python_datatype
>>> for cls in _all_primitive_types.values(): >>> for cls in _all_atomic_types.values():
... check_datatype(cls()) ... check_datatype(cls())
>>> # Simple ArrayType. >>> # Simple ArrayType.
...@@ -494,8 +510,8 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") ...@@ -494,8 +510,8 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
def _parse_datatype_json_value(json_value): def _parse_datatype_json_value(json_value):
if not isinstance(json_value, dict): if not isinstance(json_value, dict):
if json_value in _all_primitive_types.keys(): if json_value in _all_atomic_types.keys():
return _all_primitive_types[json_value]() return _all_atomic_types[json_value]()
elif json_value == 'decimal': elif json_value == 'decimal':
return DecimalType() return DecimalType()
elif _FIXED_DECIMAL.match(json_value): elif _FIXED_DECIMAL.match(json_value):
...@@ -1125,7 +1141,7 @@ def _create_cls(dataType): ...@@ -1125,7 +1141,7 @@ def _create_cls(dataType):
return lambda datum: dataType.deserialize(datum) return lambda datum: dataType.deserialize(datum)
elif not isinstance(dataType, StructType): elif not isinstance(dataType, StructType):
# no wrapper for primitive types # no wrapper for atomic types
return lambda x: x return lambda x: x
class Row(tuple): class Row(tuple):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment