diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 60141792d499b4c859734952f67cd367afad009a..7dfa17f68a943749d7a6098ca6a4b70fc91d0b91 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -627,7 +627,6 @@ class RDD(object):
     def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
-        # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
         >>> sc.parallelize(tmp).sortByKey().first()
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e3bf0f35ea15ea016d99a3c4dc0a6c4dc0fd7387..2cc0e2d1d7b8dd6abaa0b6505d2a4b892e7826cc 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -33,7 +33,7 @@ from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
 from pyspark.sql.utils import install_exception_handler
 
@@ -514,17 +514,21 @@ class SparkSession(object):
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
-        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
+            verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
+
             def prepare(obj):
-                verify_func(obj, schema)
+                verify_func(obj)
                 return obj
         elif isinstance(schema, DataType):
             dataType = schema
             schema = StructType().add("value", schema)
 
+            verify_func = _make_type_verifier(
+                dataType, name="field value") if verifySchema else lambda _: True
+
             def prepare(obj):
-                verify_func(obj, dataType)
+                verify_func(obj)
                 return obj,
         else:
             if isinstance(schema, list):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index c105969b26b9779785afc011fd51c2145527c3ba..16ba8bd73f40098c219503b3c0854566d4a7b701 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -57,7 +57,7 @@ except:
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type
+from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -852,7 +852,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
     def test_udt(self):
-        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
         from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
 
         def check_datatype(datatype):
@@ -868,8 +868,8 @@ class SQLTests(ReusedPySparkTestCase):
         check_datatype(structtype_with_udt)
         p = ExamplePoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), ExamplePointUDT())
-        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
 
         check_datatype(PythonOnlyUDT())
         structtype_with_udt = StructType([StructField("label", DoubleType(), False),
@@ -877,8 +877,10 @@ class SQLTests(ReusedPySparkTestCase):
         check_datatype(structtype_with_udt)
         p = PythonOnlyPoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), PythonOnlyUDT())
-        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+        self.assertRaises(
+            ValueError,
+            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
 
     def test_simple_udt_in_df(self):
         schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
@@ -2636,6 +2638,195 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             importlib.reload(window)
 
 
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+    def test_verify_type_exception_msg(self):
+        self.assertRaisesRegexp(
+            ValueError,
+            "test_name",
+            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        self.assertRaisesRegexp(
+            TypeError,
+            "field b in field a",
+            lambda: _make_type_verifier(schema)([["data"]]))
+
+    def test_verify_type_ok_nullable(self):
+        obj = None
+        types = [IntegerType(), FloatType(), StringType(), StructType([])]
+        for data_type in types:
+            try:
+                _make_type_verifier(data_type, nullable=True)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
+
+    def test_verify_type_not_nullable(self):
+        import array
+        import datetime
+        import decimal
+
+        schema = StructType([
+            StructField('s', StringType(), nullable=False),
+            StructField('i', IntegerType(), nullable=True)])
+
+        class MyObj:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        # obj, data_type
+        success_spec = [
+            # String
+            ("", StringType()),
+            (u"", StringType()),
+            (1, StringType()),
+            (1.0, StringType()),
+            ([], StringType()),
+            ({}, StringType()),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+            # Boolean
+            (True, BooleanType()),
+
+            # Byte
+            (-(2**7), ByteType()),
+            (2**7 - 1, ByteType()),
+
+            # Short
+            (-(2**15), ShortType()),
+            (2**15 - 1, ShortType()),
+
+            # Integer
+            (-(2**31), IntegerType()),
+            (2**31 - 1, IntegerType()),
+
+            # Long
+            (2**64, LongType()),
+
+            # Float & Double
+            (1.0, FloatType()),
+            (1.0, DoubleType()),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType()),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType()),
+
+            # Date/Timestamp
+            (datetime.date(2000, 1, 2), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+            # Array
+            ([], ArrayType(IntegerType())),
+            (["1", None], ArrayType(StringType(), containsNull=True)),
+            ([1, 2], ArrayType(IntegerType())),
+            ((1, 2), ArrayType(IntegerType())),
+            (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType())),
+            ({"a": 1}, MapType(StringType(), IntegerType())),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+            # Struct
+            ({"s": "a", "i": 1}, schema),
+            ({"s": "a", "i": None}, schema),
+            ({"s": "a"}, schema),
+            ({"s": "a", "f": 1.0}, schema),
+            (Row(s="a", i=1), schema),
+            (Row(s="a", i=None), schema),
+            (Row(s="a", i=1, f=1.0), schema),
+            (["a", 1], schema),
+            (["a", None], schema),
+            (("a", 1), schema),
+            (MyObj(s="a", i=1), schema),
(MyObj(s="a", i=None), schema), + (MyObj(s="a"), schema), + ] + + # obj, data_type, exception class + failure_spec = [ + # String (match anything but None) + (None, StringType(), ValueError), + + # UDT + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Byte + (-(2**7) - 1, ByteType(), ValueError), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Short + (-(2**15) - 1, ShortType(), ValueError), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (2**31, IntegerType(), ValueError), + + # Float & Double + (1, FloatType(), TypeError), + (1, DoubleType(), TypeError), + + # Decimal + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (1, BinaryType(), TypeError), + + # Date/Timestamp + ("2000-01-02", DateType(), TypeError), + (946811040, TimestampType(), TypeError), + + # Array + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, "2"], ArrayType(IntegerType()), TypeError), + + # Map + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": "1"}, schema, TypeError), + (Row(s="a"), schema, ValueError), # Row can't have missing field + (Row(s="a", i="1"), schema, TypeError), + (["a"], schema, ValueError), + (["a", "1"], schema, TypeError), + (MyObj(s="a", i="1"), schema, TypeError), + (MyObj(s=None, i="1"), schema, ValueError), + ] + + # Check success cases + for obj, data_type in success_spec: + try: + _make_type_verifier(data_type, nullable=False)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) + + # Check failure cases + for obj, data_type, exp in failure_spec: + msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + with self.assertRaises(exp, msg=msg): + _make_type_verifier(data_type, nullable=False)(obj) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb370965312c66e4a477319f798364a0e..f5505ed4722ad64fa87e8df7d06385b900f5021e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1249,121 +1249,196 @@ _acceptable_types = { } -def _verify_type(obj, dataType, nullable=True): +def _make_type_verifier(dataType, nullable=True, name=None): """ - Verify the type of obj against dataType, raise a TypeError if they do not match. - - Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed - range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it - will become infinity when cast to Java float if it overflows. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, LongType()) - >>> _verify_type(list(range(3)), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Make a verifier that checks the type of obj against dataType and raises a TypeError if they do + not match. + + This verifier also checks the value of obj against datatype and raises a ValueError if it's not + within the allowed range, e.g. 
+    not checked, so it will become infinity when cast to Java float if it overflows.
+
+    >>> _make_type_verifier(StructType([]))(None)
+    >>> _make_type_verifier(StringType())("")
+    >>> _make_type_verifier(LongType())(0)
+    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
+    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     TypeError:...
-    >>> _verify_type({}, MapType(StringType(), IntegerType()))
-    >>> _verify_type((), StructType([]))
-    >>> _verify_type([], StructType([]))
-    >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
+    >>> _make_type_verifier(StructType([]))(())
+    >>> _make_type_verifier(StructType([]))([])
+    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
     >>> # Check if numeric values are within the allowed range.
-    >>> _verify_type(12, ByteType())
-    >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(ByteType())(12)
+    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(
+    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type({None: 1}, MapType(StringType(), IntegerType()))
+    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
     Traceback (most recent call last):
         ...
     ValueError:...
     >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
-    >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
""" - if obj is None: - if nullable: - return - else: - raise ValueError("This field is not nullable, but got None") - # StringType can work with any types - if isinstance(dataType, StringType): - return + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) - return + def verify_nullability(obj): + if obj is None: + if nullable: + return True + else: + raise ValueError(new_msg("This field is not nullable, but got None")) + else: + return False _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) - if _type is StructType: - # check the type and fields later - pass - else: + def assert_acceptable_types(obj): + assert _type in _acceptable_types, \ + new_msg("unknown datatype: %s for object %r" % (dataType, obj)) + + def verify_acceptable_types(obj): # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError(new_msg("%s can not accept object %r in type %s" + % (dataType, obj, type(obj)))) + + if isinstance(dataType, StringType): + # StringType can work with any types + verify_value = lambda _: _ + + elif isinstance(dataType, UserDefinedType): + verifier = _make_type_verifier(dataType.sqlType(), name=name) - if isinstance(dataType, ByteType): - if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + def verify_udf(obj): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType))) + verifier(dataType.toInternal(obj)) + + verify_value = verify_udf + + elif isinstance(dataType, ByteType): + def verify_byte(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -128 or obj > 127: + raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj)) + + verify_value = verify_byte elif isinstance(dataType, ShortType): - if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + def verify_short(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -32768 or obj > 32767: + raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj)) + + verify_value = verify_short elif isinstance(dataType, IntegerType): - if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + def verify_integer(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -2147483648 or obj > 2147483647: + raise ValueError( + new_msg("object of IntegerType out of range, got: %s" % obj)) + + verify_value = verify_integer elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + element_verifier = _make_type_verifier( + dataType.elementType, dataType.containsNull, name="element in array %s" % name) + + def verify_array(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for i in obj: + element_verifier(i) + + verify_value = verify_array elif isinstance(dataType, 
-        for k, v in obj.items():
-            _verify_type(k, dataType.keyType, False)
-            _verify_type(v, dataType.valueType, dataType.valueContainsNull)
+        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
+        value_verifier = _make_type_verifier(
+            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
+
+        def verify_map(obj):
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            for k, v in obj.items():
+                key_verifier(k)
+                value_verifier(v)
+
+        verify_value = verify_map
 
     elif isinstance(dataType, StructType):
-        if isinstance(obj, dict):
-            for f in dataType.fields:
-                _verify_type(obj.get(f.name), f.dataType, f.nullable)
-        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
-            # the order in obj could be different than dataType.fields
-            for f in dataType.fields:
-                _verify_type(obj[f.name], f.dataType, f.nullable)
-        elif isinstance(obj, (tuple, list)):
-            if len(obj) != len(dataType.fields):
-                raise ValueError("Length of object (%d) does not match with "
-                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
-            for v, f in zip(obj, dataType.fields):
-                _verify_type(v, f.dataType, f.nullable)
-        elif hasattr(obj, "__dict__"):
-            d = obj.__dict__
-            for f in dataType.fields:
-                _verify_type(d.get(f.name), f.dataType, f.nullable)
-        else:
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+        verifiers = []
+        for f in dataType.fields:
+            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
+            verifiers.append((f.name, verifier))
+
+        def verify_struct(obj):
+            assert_acceptable_types(obj)
+
+            if isinstance(obj, dict):
+                for f, verifier in verifiers:
+                    verifier(obj.get(f))
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                # the order in obj could be different than dataType.fields
+                for f, verifier in verifiers:
+                    verifier(obj[f])
+            elif isinstance(obj, (tuple, list)):
+                if len(obj) != len(verifiers):
+                    raise ValueError(
+                        new_msg("Length of object (%d) does not match with "
+                                "length of fields (%d)" % (len(obj), len(verifiers))))
+                for v, (_, verifier) in zip(obj, verifiers):
+                    verifier(v)
+            elif hasattr(obj, "__dict__"):
+                d = obj.__dict__
+                for f, verifier in verifiers:
+                    verifier(d.get(f))
+            else:
+                raise TypeError(new_msg("StructType can not accept object %r in type %s"
+                                        % (obj, type(obj))))
+        verify_value = verify_struct
+
+    else:
+        def verify_default(obj):
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+
+        verify_value = verify_default
+
+    def verify(obj):
+        if not verify_nullability(obj):
+            verify_value(obj)
+
+    return verify
 
 
 # This is used to unpickle a Row from JVM
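
Usage sketch (reviewer note, not part of the patch): a minimal illustration of how the new factory is meant to be used, assuming only that pyspark is importable (no SparkSession is needed). Build the verifier once per schema, then apply it per row; this is why createDataFrame in session.py now constructs verify_func a single time instead of passing the schema on every call. The schema, field names, and values below are illustrative only.

    from pyspark.sql.types import (StructType, StructField, IntegerType, StringType,
                                   _make_type_verifier)

    # Build the verifier once; it closes over one nested verifier per field.
    schema = StructType([
        StructField("age", IntegerType(), nullable=False),
        StructField("name", StringType(), nullable=True)])
    verify = _make_type_verifier(schema)

    verify({"age": 30, "name": "Andy"})  # conforming row: returns None

    try:
        verify({"age": None, "name": "Bob"})  # non-nullable field set to None
    except ValueError as e:
        # The message now names the offending field, e.g.
        # "field age: This field is not nullable, but got None"
        print(e)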