Commit 084dca77 authored by Tathagata Das, committed by Shixiong Zhu

[SPARK-15981][SQL][STREAMING] Fixed bug and added tests in DataStreamReader Python API

## What changes were proposed in this pull request?

- Fixed a bug in the Python API of DataStreamReader. Because a single path was being converted to an array before calling the Java DataStreamReader method (which takes only a string), it raised the following error (a usage sketch of the fixed call is shown after this list).
```
File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 947, in pyspark.sql.readwriter.DataStreamReader.json
Failed example:
    json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'),                 schema = sdf_schema)
Exception raised:
    Traceback (most recent call last):
      File "/System/Library/Frameworks/Python.framework/Versions/2.6/lib/python2.6/doctest.py", line 1253, in __run
        compileflags, 1) in test.globs
      File "<doctest pyspark.sql.readwriter.DataStreamReader.json[0]>", line 1, in <module>
        json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'),                 schema = sdf_schema)
      File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 963, in json
        return self._df(self._jreader.json(path))
      File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__
        answer, self.gateway_client, self.target_id, self.name)
      File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/utils.py", line 63, in deco
        return f(*a, **kw)
      File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 316, in get_return_value
        format(target_id, ".", name, value))
    Py4JError: An error occurred while calling o121.json. Trace:
    py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist
    	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    	at py4j.Gateway.invoke(Gateway.java:272)
    	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
    	at py4j.commands.CallCommand.execute(CallCommand.java:79)
    	at py4j.GatewayConnection.run(GatewayConnection.java:211)
    	at java.lang.Thread.run(Thread.java:744)
```

- Reduced code duplication between DataStreamReader and DataFrameReader by moving the shared JSON/CSV option handling into a `ReaderUtils` mixin
- Added missing Python doctests
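
For reference, a minimal usage sketch of the fixed call path (not part of the patch itself): it mirrors the doctests added below and assumes a local `pyspark` 2.x installation; the schema name is illustrative.

```python
import os
import tempfile

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.master("local[1]").appName("SPARK-15981-sketch").getOrCreate()

# Illustrative schema, mirroring the sdf_schema used by the new doctests.
schema = StructType([StructField("data", StringType(), False)])

# Before this patch, passing a single path string here failed with
# "py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist",
# because the Python side wrapped the string in a list before calling the JVM
# DataStreamReader, which accepts only a single string for streaming sources.
json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'),
                                 schema=schema)
print(json_sdf.isStreaming)        # True
print(json_sdf.schema == schema)   # True
```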

## How was this patch tested?
New tests
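
A hedged sketch of one way to exercise the new doctests locally; it relies on the module's own `_test()` helper (visible at the end of the diff) and assumes a Spark source checkout with `pyspark` and its `python/test_support` data on the path.

```python
# Runs the readwriter doctests the same way Spark's _test() helper does:
# it builds the doctest globals (spark, df, sdf, sdf_schema) and calls
# doctest.testmod on pyspark.sql.readwriter, exiting non-zero on failure.
import pyspark.sql.readwriter

if __name__ == "__main__":
    pyspark.sql.readwriter._test()
```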

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #13703 from tdas/SPARK-15981.
parent a865f6e0
python/pyspark/sql/readwriter.py
@@ -44,7 +44,82 @@ def to_str(value):
         return str(value)
 
 
-class DataFrameReader(object):
+class ReaderUtils(object):
+
+    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
+                       allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                       allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                       mode, columnNameOfCorruptRecord):
+        """
+        Set options based on the Json optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if primitivesAsString is not None:
+            self.option("primitivesAsString", primitivesAsString)
+        if prefersDecimal is not None:
+            self.option("prefersDecimal", prefersDecimal)
+        if allowComments is not None:
+            self.option("allowComments", allowComments)
+        if allowUnquotedFieldNames is not None:
+            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
+        if allowSingleQuotes is not None:
+            self.option("allowSingleQuotes", allowSingleQuotes)
+        if allowNumericLeadingZero is not None:
+            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
+        if allowBackslashEscapingAnyCharacter is not None:
+            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
+        if mode is not None:
+            self.option("mode", mode)
+        if columnNameOfCorruptRecord is not None:
+            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+
+    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
+                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                      ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                      dateFormat, maxColumns, maxCharsPerColumn, mode):
+        """
+        Set options based on the CSV optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if sep is not None:
+            self.option("sep", sep)
+        if encoding is not None:
+            self.option("encoding", encoding)
+        if quote is not None:
+            self.option("quote", quote)
+        if escape is not None:
+            self.option("escape", escape)
+        if comment is not None:
+            self.option("comment", comment)
+        if header is not None:
+            self.option("header", header)
+        if inferSchema is not None:
+            self.option("inferSchema", inferSchema)
+        if ignoreLeadingWhiteSpace is not None:
+            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
+        if ignoreTrailingWhiteSpace is not None:
+            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
+        if nullValue is not None:
+            self.option("nullValue", nullValue)
+        if nanValue is not None:
+            self.option("nanValue", nanValue)
+        if positiveInf is not None:
+            self.option("positiveInf", positiveInf)
+        if negativeInf is not None:
+            self.option("negativeInf", negativeInf)
+        if dateFormat is not None:
+            self.option("dateFormat", dateFormat)
+        if maxColumns is not None:
+            self.option("maxColumns", maxColumns)
+        if maxCharsPerColumn is not None:
+            self.option("maxCharsPerColumn", maxCharsPerColumn)
+        if mode is not None:
+            self.option("mode", mode)
+
+
+class DataFrameReader(ReaderUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -193,26 +268,10 @@ class DataFrameReader(object):
         [('age', 'bigint'), ('name', 'string')]
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                            allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -345,42 +404,11 @@ class DataFrameReader(object):
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
             path = [path]
         return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@@ -764,7 +792,7 @@ class DataFrameWriter(object):
             self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
-class DataStreamReader(object):
+class DataStreamReader(ReaderUtils):
     """
     Interface used to load a streaming :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -791,6 +819,7 @@ class DataStreamReader(object):
         :param source: string, name of the data source, e.g. 'json', 'parquet'.
 
+        >>> s = spark.readStream.format("text")
         """
         self._jreader = self._jreader.format(source)
         return self
@@ -806,6 +835,8 @@ class DataStreamReader(object):
         .. note:: Experimental.
 
         :param schema: a StructType object
+
+        >>> s = spark.readStream.schema(sdf_schema)
         """
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
@@ -818,6 +849,8 @@ class DataStreamReader(object):
         """Adds an input option for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.option("x", 1)
         """
         self._jreader = self._jreader.option(key, to_str(value))
         return self
@@ -827,6 +860,8 @@ class DataStreamReader(object):
         """Adds input options for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.options(x="1", y=2)
         """
         for k in options:
             self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -843,6 +878,13 @@ class DataStreamReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
 
+        >>> json_sdf = spark.readStream.format("json")\
+                                       .schema(sdf_schema)\
+                                       .load(os.path.join(tempfile.mkdtemp(),'data'))
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
         if format is not None:
             self.format(format)
@@ -905,29 +947,18 @@ class DataStreamReader(object):
                                           it uses the value specified in
                                           ``spark.sql.columnNameOfCorruptRecord``.
 
+        >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \
+                schema = sdf_schema)
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, allowSingleQuotes,
+                            allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
-            path = [path]
             return self._df(self._jreader.json(path))
         else:
             raise TypeError("path can be only a single string")
@@ -943,10 +974,15 @@ class DataStreamReader(object):
         .. note:: Experimental.
 
+        >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
+                        .parquet(os.path.join(tempfile.mkdtemp()))
+        >>> parquet_sdf.isStreaming
+        True
+        >>> parquet_sdf.schema == sdf_schema
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.parquet(path))
         else:
             raise TypeError("path can be only a single string")
@@ -964,10 +1000,14 @@ class DataStreamReader(object):
         :param paths: string, or list of strings, for input path(s).
 
+        >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
+        >>> text_sdf.isStreaming
+        True
+        >>> "value" in str(text_sdf.schema)
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.text(path))
         else:
             raise TypeError("path can be only a single string")
@@ -1034,46 +1074,20 @@ class DataStreamReader(object):
                         * ``DROPMALFORMED`` : ignores the whole corrupted records.
                         * ``FAILFAST`` : throws an exception when it meets corrupted records.
 
+        >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \
+                schema = sdf_schema)
+        >>> csv_sdf.isStreaming
+        True
+        >>> csv_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
-            path = [path]
-            return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.csv(path))
         else:
             raise TypeError("path can be only a single string")
@@ -1286,7 +1300,7 @@ def _test():
     globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
     globs['sdf'] = \
         spark.readStream.format('text').load('python/test_support/sql/streaming')
+    globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)