diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 60f62b219b217c40cd3287047143cc24972267ef..a271afe4cf9baf16b8adf18285e922fb03da0a0e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -428,6 +428,19 @@ class SQLContext(object): """ return DataFrameReader(self) + @property + @since(2.0) + def readStream(self): + """ + Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. note:: Experimental. + + :return: :class:`DataStreamReader` + """ + return DataStreamReader(self._wrapped) + @property @since(2.0) def streams(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4fa799ac55bdf57245b8eaf26682925804a9f594..0126faf574829254bacc7d3c5e0048f564b2838e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -33,7 +33,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column -from pyspark.sql.readwriter import DataFrameWriter +from pyspark.sql.readwriter import DataFrameWriter, DataStreamWriter from pyspark.sql.types import * __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -172,12 +172,26 @@ class DataFrame(object): @since(1.4) def write(self): """ - Interface for saving the content of the :class:`DataFrame` out into external storage. + Interface for saving the content of the non-streaming :class:`DataFrame` out into external + storage. :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) + @property + @since(2.0) + def writeStream(self): + """ + Interface for saving the content of the streaming :class:`DataFrame` out into external + storage. + + .. note:: Experimental. + + :return: :class:`DataStreamWriter` + """ + return DataStreamWriter(self) + @property @since(1.3) def schema(self): diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 0f50f672a22d1492362acdfd4baad55845d9223b..ad954d0ad8217872d2c7168dbf72cd1a8c1ad916 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -137,34 +137,6 @@ class DataFrameReader(object): else: return self._df(self._jreader.load()) - @since(2.0) - def stream(self, path=None, format=None, schema=None, **options): - """Loads a data stream from a data source and returns it as a :class`DataFrame`. - - .. note:: Experimental. - - :param path: optional string for file-system backed data sources. - :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`StructType` for the input schema. - :param options: all other string options - - >>> df = spark.read.format('text').stream('python/test_support/sql/streaming') - >>> df.isStreaming - True - """ - if format is not None: - self.format(format) - if schema is not None: - self.schema(schema) - self.options(**options) - if path is not None: - if type(path) != str or len(path.strip()) == 0: - raise ValueError("If the path is provided for stream, it needs to be a " + - "non-empty string. 
List of paths are not supported.") - return self._df(self._jreader.stream(path)) - else: - return self._df(self._jreader.stream()) - @since(1.4) def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, @@ -509,26 +481,6 @@ class DataFrameWriter(object): self._jwrite = self._jwrite.mode(saveMode) return self - @since(2.0) - def outputMode(self, outputMode): - """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - - Options include: - - * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to - the sink - * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink - every time these is some updates - - .. note:: Experimental. - - >>> writer = sdf.write.outputMode('append') - """ - if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: - raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode) - self._jwrite = self._jwrite.outputMode(outputMode) - return self - @since(1.4) def format(self, source): """Specifies the underlying output data source. @@ -571,48 +523,6 @@ class DataFrameWriter(object): self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self - @since(2.0) - def queryName(self, queryName): - """Specifies the name of the :class:`ContinuousQuery` that can be started with - :func:`startStream`. This name must be unique among all the currently active queries - in the associated SparkSession. - - .. note:: Experimental. - - :param queryName: unique name for the query - - >>> writer = sdf.write.queryName('streaming_query') - """ - if not queryName or type(queryName) != str or len(queryName.strip()) == 0: - raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) - self._jwrite = self._jwrite.queryName(queryName) - return self - - @keyword_only - @since(2.0) - def trigger(self, processingTime=None): - """Set the trigger for the stream query. If this is not set it will run the query as fast - as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. - - .. note:: Experimental. - - :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. - - >>> # trigger the query for execution every 5 seconds - >>> writer = sdf.write.trigger(processingTime='5 seconds') - """ - from pyspark.sql.streaming import ProcessingTime - trigger = None - if processingTime is not None: - if type(processingTime) != str or len(processingTime.strip()) == 0: - raise ValueError('The processing time must be a non empty string. Got: %s' % - processingTime) - trigger = ProcessingTime(processingTime) - if trigger is None: - raise ValueError('A trigger was not provided. Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) - return self - @since(1.4) def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. @@ -644,57 +554,6 @@ class DataFrameWriter(object): else: self._jwrite.save(path) - @ignore_unicode_prefix - @since(2.0) - def startStream(self, path=None, format=None, partitionBy=None, queryName=None, **options): - """Streams the contents of the :class:`DataFrame` to a data source. - - The data source is specified by the ``format`` and a set of ``options``. 
- If ``format`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - .. note:: Experimental. - - :param path: the path in a Hadoop supported file system - :param format: the format used to save - - * ``append``: Append contents of this :class:`DataFrame` to existing data. - * ``overwrite``: Overwrite existing data. - * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. - :param partitionBy: names of partitioning columns - :param queryName: unique name for the query - :param options: All other string options. You may want to provide a `checkpointLocation` - for most streams, however it is not required for a `memory` stream. - - >>> cq = sdf.write.format('memory').queryName('this_query').startStream() - >>> cq.isActive - True - >>> cq.name - u'this_query' - >>> cq.stop() - >>> cq.isActive - False - >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream( - ... queryName='that_query', format='memory') - >>> cq.name - u'that_query' - >>> cq.isActive - True - >>> cq.stop() - """ - self.options(**options) - if partitionBy is not None: - self.partitionBy(partitionBy) - if format is not None: - self.format(format) - if queryName is not None: - self.queryName(queryName) - if path is None: - return self._cq(self._jwrite.startStream()) - else: - return self._cq(self._jwrite.startStream(path)) - @since(1.4) def insertInto(self, tableName, overwrite=False): """Inserts the content of the :class:`DataFrame` to the specified table. @@ -905,6 +764,503 @@ class DataFrameWriter(object): self._jwrite.mode(mode).jdbc(url, table, jprop) +class DataStreamReader(object): + """ + Interface used to load a streaming :class:`DataFrame` from external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` + to access this. + + .. note:: Experimental. + + .. versionadded:: 2.0 + """ + + def __init__(self, spark): + self._jreader = spark._ssql_ctx.readStream() + self._spark = spark + + def _df(self, jdf): + from pyspark.sql.dataframe import DataFrame + return DataFrame(jdf, self._spark) + + @since(2.0) + def format(self, source): + """Specifies the input data source format. + + .. note:: Experimental. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + """ + self._jreader = self._jreader.format(source) + return self + + @since(2.0) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + .. note:: Experimental. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._spark._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(2.0) + def option(self, key, value): + """Adds an input option for the underlying data source. + + .. note:: Experimental. + """ + self._jreader = self._jreader.option(key, to_str(value)) + return self + + @since(2.0) + def options(self, **options): + """Adds input options for the underlying data source. + + .. note:: Experimental. 
+ """ + for k in options: + self._jreader = self._jreader.option(k, to_str(options[k])) + return self + + @since(2.0) + def load(self, path=None, format=None, schema=None, **options): + """Loads a data stream from a data source and returns it as a :class`DataFrame`. + + .. note:: Experimental. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`StructType` for the input schema. + :param options: all other string options + + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + if type(path) != str or len(path.strip()) == 0: + raise ValueError("If the path is provided for stream, it needs to be a " + + "non-empty string. List of paths are not supported.") + return self._df(self._jreader.load(path)) + else: + return self._df(self._jreader.load()) + + @since(2.0) + def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, + allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, + allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, + mode=None, columnNameOfCorruptRecord=None): + """ + Loads a JSON file stream (one object per line) and returns a :class`DataFrame`. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + .. note:: Experimental. + + :param path: string represents path to the JSON dataset, + or RDD of Strings storing JSON objects. + :param schema: an optional :class:`StructType` for the input schema. + :param primitivesAsString: infers all primitive values as a string type. If None is set, + it uses the default value, ``false``. + :param prefersDecimal: infers all floating-point values as a decimal type. If the values + do not fit in decimal, then it infers them as doubles. If None is + set, it uses the default value, ``false``. + :param allowComments: ignores Java/C++ style comment in JSON records. If None is set, + it uses the default value, ``false``. + :param allowUnquotedFieldNames: allows unquoted JSON field names. If None is set, + it uses the default value, ``false``. + :param allowSingleQuotes: allows single quotes in addition to double quotes. If None is + set, it uses the default value, ``true``. + :param allowNumericLeadingZero: allows leading zeros in numbers (e.g. 00012). If None is + set, it uses the default value, ``false``. + :param allowBackslashEscapingAnyCharacter: allows accepting quoting of all character + using backslash quoting mechanism. If None is + set, it uses the default value, ``false``. + :param mode: allows a mode for dealing with corrupt records during parsing. If None is + set, it uses the default value, ``PERMISSIVE``. + + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record and puts the malformed string into a new field configured by \ + ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ + ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. + + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. 
+ + """ + if schema is not None: + self.schema(schema) + if primitivesAsString is not None: + self.option("primitivesAsString", primitivesAsString) + if prefersDecimal is not None: + self.option("prefersDecimal", prefersDecimal) + if allowComments is not None: + self.option("allowComments", allowComments) + if allowUnquotedFieldNames is not None: + self.option("allowUnquotedFieldNames", allowUnquotedFieldNames) + if allowSingleQuotes is not None: + self.option("allowSingleQuotes", allowSingleQuotes) + if allowNumericLeadingZero is not None: + self.option("allowNumericLeadingZero", allowNumericLeadingZero) + if allowBackslashEscapingAnyCharacter is not None: + self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter) + if mode is not None: + self.option("mode", mode) + if columnNameOfCorruptRecord is not None: + self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.json(path)) + else: + raise TypeError("path can be only a single string") + + @since(2.0) + def parquet(self, path): + """Loads a Parquet file stream, returning the result as a :class:`DataFrame`. + + You can set the following Parquet-specific option(s) for reading Parquet files: + * ``mergeSchema``: sets whether we should merge schemas collected from all \ + Parquet part-files. This will override ``spark.sql.parquet.mergeSchema``. \ + The default value is specified in ``spark.sql.parquet.mergeSchema``. + + .. note:: Experimental. + + """ + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path))) + else: + raise TypeError("path can be only a single string") + + @ignore_unicode_prefix + @since(2.0) + def text(self, path): + """ + Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a + string column named "value", and followed by partitioned columns if there + are any. + + Each line in the text file is a new row in the resulting DataFrame. + + .. note:: Experimental. + + :param paths: string, or list of strings, for input path(s). + + """ + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path))) + else: + raise TypeError("path can be only a single string") + + @since(2.0) + def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, + comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, + ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, + negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None, mode=None): + """Loads a CSV file stream and returns the result as a :class:`DataFrame`. + + This function will go through the input once to determine the input schema if + ``inferSchema`` is enabled. To avoid going through the entire data once, disable + ``inferSchema`` option or specify the schema explicitly using ``schema``. + + .. note:: Experimental. + + :param path: string, or list of strings, for input path(s). + :param schema: an optional :class:`StructType` for the input schema. + :param sep: sets the single character as a separator for each field and value. + If None is set, it uses the default value, ``,``. + :param encoding: decodes the CSV files by the given encoding type. If None is set, + it uses the default value, ``UTF-8``. 
+ :param quote: sets the single character used for escaping quoted values where the + separator can be part of the value. If None is set, it uses the default + value, ``"``. If you would like to turn off quotations, you need to set an + empty string. + :param escape: sets the single character used for escaping quotes inside an already + quoted value. If None is set, it uses the default value, ``\``. + :param comment: sets the single character used for skipping lines beginning with this + character. By default (None), it is disabled. + :param header: uses the first line as names of columns. If None is set, it uses the + default value, ``false``. + :param inferSchema: infers the input schema automatically from data. It requires one extra + pass over the data. If None is set, it uses the default value, ``false``. + :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. + :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. + :param nullValue: sets the string representation of a null value. If None is set, it uses + the default value, empty string. + :param nanValue: sets the string representation of a non-number value. If None is set, it + uses the default value, ``NaN``. + :param positiveInf: sets the string representation of a positive infinity value. If None + is set, it uses the default value, ``Inf``. + :param negativeInf: sets the string representation of a negative infinity value. If None + is set, it uses the default value, ``Inf``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to both date type and timestamp type. By default, it is None + which means trying to parse times and date by + ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``. + :param maxColumns: defines a hard limit of how many columns a record can have. If None is + set, it uses the default value, ``20480``. + :param maxCharsPerColumn: defines the maximum number of characters allowed for any given + value being read. If None is set, it uses the default value, + ``1000000``. + :param mode: allows a mode for dealing with corrupt records during parsing. If None is + set, it uses the default value, ``PERMISSIVE``. + + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. + When a schema is set by user, it sets ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. 
+ + """ + if schema is not None: + self.schema(schema) + if sep is not None: + self.option("sep", sep) + if encoding is not None: + self.option("encoding", encoding) + if quote is not None: + self.option("quote", quote) + if escape is not None: + self.option("escape", escape) + if comment is not None: + self.option("comment", comment) + if header is not None: + self.option("header", header) + if inferSchema is not None: + self.option("inferSchema", inferSchema) + if ignoreLeadingWhiteSpace is not None: + self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + if ignoreTrailingWhiteSpace is not None: + self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + if nullValue is not None: + self.option("nullValue", nullValue) + if nanValue is not None: + self.option("nanValue", nanValue) + if positiveInf is not None: + self.option("positiveInf", positiveInf) + if negativeInf is not None: + self.option("negativeInf", negativeInf) + if dateFormat is not None: + self.option("dateFormat", dateFormat) + if maxColumns is not None: + self.option("maxColumns", maxColumns) + if maxCharsPerColumn is not None: + self.option("maxCharsPerColumn", maxCharsPerColumn) + if mode is not None: + self.option("mode", mode) + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) + else: + raise TypeError("path can be only a single string") + + +class DataStreamWriter(object): + """ + Interface used to write a streaming :class:`DataFrame` to external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream` + to access this. + + .. note:: Experimental. + + .. versionadded:: 2.0 + """ + + def __init__(self, df): + self._df = df + self._spark = df.sql_ctx + self._jwrite = df._jdf.writeStream() + + def _cq(self, jcq): + from pyspark.sql.streaming import ContinuousQuery + return ContinuousQuery(jcq) + + @since(2.0) + def outputMode(self, outputMode): + """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + + Options include: + + * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to + the sink + * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink + every time these is some updates + + .. note:: Experimental. + + >>> writer = sdf.writeStream.outputMode('append') + """ + if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: + raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode) + self._jwrite = self._jwrite.outputMode(outputMode) + return self + + @since(2.0) + def format(self, source): + """Specifies the underlying output data source. + + .. note:: Experimental. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> writer = sdf.writeStream.format('json') + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(2.0) + def option(self, key, value): + """Adds an output option for the underlying data source. + + .. note:: Experimental. + """ + self._jwrite = self._jwrite.option(key, to_str(value)) + return self + + @since(2.0) + def options(self, **options): + """Adds output options for the underlying data source. + + .. note:: Experimental. + """ + for k in options: + self._jwrite = self._jwrite.option(k, to_str(options[k])) + return self + + @since(2.0) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. 
+ + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + .. note:: Experimental. + + :param cols: name of columns + + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) + return self + + @since(2.0) + def queryName(self, queryName): + """Specifies the name of the :class:`ContinuousQuery` that can be started with + :func:`startStream`. This name must be unique among all the currently active queries + in the associated SparkSession. + + .. note:: Experimental. + + :param queryName: unique name for the query + + >>> writer = sdf.writeStream.queryName('streaming_query') + """ + if not queryName or type(queryName) != str or len(queryName.strip()) == 0: + raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) + self._jwrite = self._jwrite.queryName(queryName) + return self + + @keyword_only + @since(2.0) + def trigger(self, processingTime=None): + """Set the trigger for the stream query. If this is not set it will run the query as fast + as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. + + .. note:: Experimental. + + :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') + """ + from pyspark.sql.streaming import ProcessingTime + trigger = None + if processingTime is not None: + if type(processingTime) != str or len(processingTime.strip()) == 0: + raise ValueError('The processing time must be a non empty string. Got: %s' % + processingTime) + trigger = ProcessingTime(processingTime) + if trigger is None: + raise ValueError('A trigger was not provided. Supported triggers: processingTime.') + self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) + return self + + @ignore_unicode_prefix + @since(2.0) + def start(self, path=None, format=None, partitionBy=None, queryName=None, **options): + """Streams the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + .. note:: Experimental. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + :param queryName: unique name for the query + :param options: All other string options. You may want to provide a `checkpointLocation` + for most streams, however it is not required for a `memory` stream. + + >>> cq = sdf.writeStream.format('memory').queryName('this_query').start() + >>> cq.isActive + True + >>> cq.name + u'this_query' + >>> cq.stop() + >>> cq.isActive + False + >>> cq = sdf.writeStream.trigger(processingTime='5 seconds').start( + ... 
queryName='that_query', format='memory') + >>> cq.name + u'that_query' + >>> cq.isActive + True + >>> cq.stop() + """ + self.options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + if queryName is not None: + self.queryName(queryName) + if path is None: + return self._cq(self._jwrite.start()) + else: + return self._cq(self._jwrite.start(path)) + + def _test(): import doctest import os @@ -929,7 +1285,7 @@ def _test(): globs['spark'] = spark globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned') globs['sdf'] = \ - spark.read.format('text').stream('python/test_support/sql/streaming') + spark.readStream.format('text').load('python/test_support/sql/streaming') (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f0bf0923b8c759d6170f77104799a2b97995e00b..11c815dd9450b96471273a4e24a295ed19e39e9e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -31,7 +31,7 @@ from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame -from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.readwriter import DataFrameReader, DataStreamReader from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -549,6 +549,19 @@ class SparkSession(object): """ return DataFrameReader(self._wrapped) + @property + @since(2.0) + def readStream(self): + """ + Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. note:: Experimental. + + :return: :class:`DataStreamReader` + """ + return DataStreamReader(self._wrapped) + @property @since(2.0) def streams(self): @@ -556,6 +569,8 @@ class SparkSession(object): :class:`ContinuousQuery` ContinuousQueries active on `this` context. .. note:: Experimental. + + :return: :class:`ContinuousQueryManager` """ from pyspark.sql.streaming import ContinuousQueryManager return ContinuousQueryManager(self._jsparkSession.streams()) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bb4e62cdd6a5682e99c692d7d36040e7b723ceee..0edaa515493955e6bec1b47f19e5c8cfb64b62c2 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -119,7 +119,7 @@ class ContinuousQueryManager(object): def active(self): """Returns a list of active queries associated with this SQLContext - >>> cq = df.write.format('memory').queryName('this_query').startStream() + >>> cq = df.writeStream.format('memory').queryName('this_query').start() >>> cqm = spark.streams >>> # get the list of active continuous queries >>> [q.name for q in cqm.active] @@ -134,7 +134,7 @@ class ContinuousQueryManager(object): """Returns an active query from this SQLContext or throws exception if an active query with this name doesn't exist. 
- >>> cq = df.write.format('memory').queryName('this_query').startStream() + >>> cq = df.writeStream.format('memory').queryName('this_query').start() >>> cq.name u'this_query' >>> cq = spark.streams.get(cq.id) @@ -236,7 +236,7 @@ def _test(): globs = pyspark.sql.streaming.__dict__.copy() try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) @@ -245,7 +245,7 @@ def _test(): globs['spark'] = spark globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext) globs['df'] = \ - globs['spark'].read.format('text').stream('python/test_support/sql/streaming') + globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') (failure_count, test_count) = doctest.testmod( pyspark.sql.streaming, globs=globs, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e0acde678317d8affe25c6da6908fa2998706be4..fee960a1a7bb43793d74eea1935d6194e8bba020 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -892,9 +892,9 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_stream_trigger_takes_keyword_args(self): - df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') try: - df.write.trigger('5 seconds') + df.writeStream.trigger('5 seconds') self.fail("Should have thrown an exception") except TypeError: # should throw error @@ -902,22 +902,25 @@ class SQLTests(ReusedPySparkTestCase): def test_stream_read_options(self): schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.read.format('text').option('path', 'python/test_support/sql/streaming')\ - .schema(schema).stream() + df = self.spark.readStream\ + .format('text')\ + .option('path', 'python/test_support/sql/streaming')\ + .schema(schema)\ + .load() self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") def test_stream_read_options_overwrite(self): bad_schema = StructType([StructField("test", IntegerType(), False)]) schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.read.format('csv').option('path', 'python/test_support/sql/fake') \ - .schema(bad_schema).stream(path='python/test_support/sql/streaming', - schema=schema, format='text') + df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema)\ + .load(path='python/test_support/sql/streaming', schema=schema, format='text') self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") def test_stream_save_options(self): - df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() @@ -925,8 +928,8 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').outputMode('append').option('path', out).startStream() + cq = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ + .format('parquet').outputMode('append').option('path', out).start() try: self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) @@ -941,7 
+944,7 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_stream_save_options_overwrite(self): - df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() @@ -951,9 +954,10 @@ class SQLTests(ReusedPySparkTestCase): chk = os.path.join(tmpPath, 'chk') fake1 = os.path.join(tmpPath, 'fake1') fake2 = os.path.join(tmpPath, 'fake2') - cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ + cq = df.writeStream.option('checkpointLocation', fake1)\ + .format('memory').option('path', fake2) \ .queryName('fake_query').outputMode('append') \ - .startStream(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) try: self.assertEqual(cq.name, 'this_query') @@ -971,7 +975,7 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_stream_await_termination(self): - df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() @@ -979,8 +983,8 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + cq = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) try: self.assertTrue(cq.isActive) try: @@ -999,7 +1003,7 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_query_manager_await_termination(self): - df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() @@ -1007,8 +1011,8 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + cq = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) try: self.assertTrue(cq.isActive) try: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dfe31da3f310713f55f2700533f23e2897faad72..2ae854d04f5641ce58e582aefa02a00df3b308d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,7 +22,6 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.Partition -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -30,12 +29,11 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import 
org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType /** * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[SparkSession.read]] to access this. + * key-value stores, etc). Use [[SparkSession.read]] to access this. * * @since 1.4.0 */ @@ -160,36 +158,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { options = extraOptions.toMap).resolveRelation()) } } - - /** - * :: Experimental :: - * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path - * (e.g. external key-value stores). - * - * @since 2.0.0 - */ - @Experimental - def stream(): DataFrame = { - val dataSource = - DataSource( - sparkSession, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap) - Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) - } - - /** - * :: Experimental :: - * Loads input in as a [[DataFrame]], for data streams that read from some path. - * - * @since 2.0.0 - */ - @Experimental - def stream(path: String): DataFrame = { - option("path", path).stream() - } - /** * Construct a [[DataFrame]] representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 171b1378e5f941e73767ae45e0907e77a030f4c0..60a9d1f020b42796d9625fce71b6eab03ecaa657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,20 +23,15 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.Utils /** * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[Dataset.write]] to access this. + * key-value stores, etc). Use [[Dataset.write]] to access this. 
* * @since 1.4.0 */ @@ -54,9 +49,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: SaveMode): DataFrameWriter[T] = { - // mode() is used for non-continuous queries - // outputMode() is used for continuous queries - assertNotStreaming("mode() can only be called on non-continuous queries") this.mode = saveMode this } @@ -71,9 +63,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - // mode() is used for non-continuous queries - // outputMode() is used for continuous queries - assertNotStreaming("mode() can only be called on non-continuous queries") this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append @@ -85,76 +74,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be - * written to the sink - * - `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time these is some updates - * - * @since 2.0.0 - */ - @Experimental - def outputMode(outputMode: OutputMode): DataFrameWriter[T] = { - assertStreaming("outputMode() can only be called on continuous queries") - this.outputMode = outputMode - this - } - - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to - * the sink - * - `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink - * every time these is some updates - * - * @since 2.0.0 - */ - @Experimental - def outputMode(outputMode: String): DataFrameWriter[T] = { - assertStreaming("outputMode() can only be called on continuous queries") - this.outputMode = outputMode.toLowerCase match { - case "append" => - OutputMode.Append - case "complete" => - OutputMode.Complete - case _ => - throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + - "Accepted output modes are 'append' and 'complete'") - } - this - } - - /** - * :: Experimental :: - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run - * the query as fast as possible. - * - * Scala Example: - * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - */ - @Experimental - def trigger(trigger: Trigger): DataFrameWriter[T] = { - assertStreaming("trigger() can only be called on continuous queries") - this.trigger = trigger - this - } - /** * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. 
* @@ -284,7 +203,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { */ def save(): Unit = { assertNotBucketed("save") - assertNotStreaming("save() can only be called on non-continuous queries") val dataSource = DataSource( df.sparkSession, className = source, @@ -294,148 +212,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { dataSource.write(mode, df) } - - /** - * :: Experimental :: - * Specifies the name of the [[ContinuousQuery]] that can be started with `startStream()`. - * This name must be unique among all the currently active queries in the associated SQLContext. - * - * @since 2.0.0 - */ - @Experimental - def queryName(queryName: String): DataFrameWriter[T] = { - assertStreaming("queryName() can only be called on continuous queries") - this.extraOptions += ("queryName" -> queryName) - this - } - - /** - * :: Experimental :: - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - @Experimental - def startStream(path: String): ContinuousQuery = { - option("path", path).startStream() - } - - /** - * :: Experimental :: - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - @Experimental - def startStream(): ContinuousQuery = { - assertNotBucketed("startStream") - assertStreaming("startStream() can only be called on continuous queries") - - if (source == "memory") { - if (extraOptions.get("queryName").isEmpty) { - throw new AnalysisException("queryName must be specified for memory sink") - } - - val sink = new MemorySink(df.schema, outputMode) - val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) - val query = df.sparkSession.sessionState.continuousQueryManager.startQuery( - extraOptions.get("queryName"), - extraOptions.get("checkpointLocation"), - df, - sink, - outputMode, - useTempCheckpointLocation = true, - recoverFromCheckpointLocation = false, - trigger = trigger) - resultDf.createOrReplaceTempView(query.name) - query - } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - df.sparkSession.sessionState.continuousQueryManager.startQuery( - extraOptions.get("queryName"), - extraOptions.get("checkpointLocation"), - df, - dataSource.createSink(outputMode), - outputMode, - trigger = trigger) - } - } - - /** - * :: Experimental :: - * Starts the execution of the streaming query, which will continually send results to the given - * [[ForeachWriter]] as as new data arrives. The [[ForeachWriter]] can be used to send the data - * generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned The returned - * [[ContinuousQuery]] object can be used to interact with the stream. 
- * - * Scala example: - * {{{ - * datasetOfString.write.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }) - * }}} - * - * Java example: - * {{{ - * datasetOfString.write().foreach(new ForeachWriter<String>() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }); - * }}} - * - * @since 2.0.0 - */ - @Experimental - def foreach(writer: ForeachWriter[T]): ContinuousQuery = { - assertNotPartitioned("foreach") - assertNotBucketed("foreach") - assertStreaming( - "foreach() can only be called on streaming Datasets/DataFrames.") - - val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc) - df.sparkSession.sessionState.continuousQueryManager.startQuery( - extraOptions.get("queryName"), - extraOptions.get("checkpointLocation"), - df, - sink, - outputMode, - useTempCheckpointLocation = true, - trigger = trigger) - } - /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. @@ -467,7 +243,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(tableIdent: TableIdentifier): Unit = { assertNotBucketed("insertInto") - assertNotStreaming("insertInto() can only be called on non-continuous queries") val partitions = normalizedParCols.map(_.map(col => col -> (Option.empty[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -586,7 +361,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def saveAsTable(tableIdent: TableIdentifier): Unit = { - assertNotStreaming("saveAsTable() can only be called on non-continuous queries") val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) @@ -629,7 +403,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { assertNotPartitioned("jdbc") assertNotBucketed("jdbc") - assertNotStreaming("jdbc() can only be called on non-continuous queries") val props = new Properties() extraOptions.foreach { case (key, value) => @@ -688,7 +461,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def json(path: String): Unit = { - assertNotStreaming("json() can only be called on non-continuous queries") format("json").save(path) } @@ -708,7 +480,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def parquet(path: String): Unit = { - assertNotStreaming("parquet() can only be called on non-continuous queries") format("parquet").save(path) } @@ -728,7 +499,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @note Currently, this method can only be used after enabling Hive support */ def orc(path: String): Unit = { - assertNotStreaming("orc() can only be called on non-continuous queries") format("orc").save(path) } @@ -752,7 +522,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.6.0 */ def text(path: String): Unit = { - assertNotStreaming("text() can only be called on 
non-continuous queries") format("text").save(path) } @@ -782,7 +551,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def csv(path: String): Unit = { - assertNotStreaming("csv() can only be called on non-continuous queries") format("csv").save(path) } @@ -794,10 +562,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private var mode: SaveMode = SaveMode.ErrorIfExists - private var outputMode: OutputMode = OutputMode.Append - - private var trigger: Trigger = ProcessingTime(0L) - private var extraOptions = new scala.collection.mutable.HashMap[String, String] private var partitioningColumns: Option[Seq[String]] = None @@ -807,21 +571,4 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private var numBuckets: Option[Int] = None private var sortColumnNames: Option[Seq[String]] = None - - /////////////////////////////////////////////////////////////////////////////////////// - // Helper functions - /////////////////////////////////////////////////////////////////////////////////////// - - private def assertNotStreaming(errMsg: String): Unit = { - if (df.isStreaming) { - throw new AnalysisException(errMsg) - } - } - - private def assertStreaming(errMsg: String): Unit = { - if (!df.isStreaming) { - throw new AnalysisException(errMsg) - } - } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 53779df3d9c0030f98221ea7a377c79d8b9f7d09..f9db325ea241fad2290360f538312bd0a23af0e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.streaming.ContinuousQuery +import org.apache.spark.sql.streaming.{ContinuousQuery, DataStreamWriter} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -2407,13 +2407,36 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Interface for saving the content of the Dataset out into external storage or streams. + * Interface for saving the content of the non-streaming Dataset out into external storage. * * @group basic * @since 1.6.0 */ @Experimental - def write: DataFrameWriter[T] = new DataFrameWriter[T](this) + def write: DataFrameWriter[T] = { + if (isStreaming) { + logicalPlan.failAnalysis( + "'write' can not be called on streaming Dataset/DataFrame") + } + new DataFrameWriter[T](this) + } + + /** + * :: Experimental :: + * Interface for saving the content of the streaming Dataset out into external storage. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def writeStream: DataStreamWriter[T] = { + if (!isStreaming) { + logicalPlan.failAnalysis( + "'writeStream' can be called only on streaming Dataset/DataFrame") + } + new DataStreamWriter[T](this) + } + /** * Returns the content of the Dataset as a Dataset of JSON strings. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 58b4e6c5f604e2246da4d194fd5414fa5c6f16a9..33f62915df694f66b573ad5369638aa7af7ea482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.streaming.ContinuousQueryManager +import org.apache.spark.sql.streaming.{ContinuousQueryManager, DataStreamReader} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -491,7 +491,8 @@ class SQLContext private[sql](val sparkSession: SparkSession) } /** - * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a + * [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -502,6 +503,21 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ def read: DataFrameReader = sparkSession.read + + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] that can be used to read streaming data in as a [[DataFrame]]. + * {{{ + * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") + * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") + * }}} + * + * @since 2.0.0 + */ + @Experimental + def readStream: DataStreamReader = sparkSession.readStream + + /** * Creates an external table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 01c2e3ad29e48cccb7f77ae2c589c4d391f44278..9137a735dd4da386d67fb5ce7d3cc572e17eae3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -574,7 +574,8 @@ class SparkSession private( } /** - * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a + * [[DataFrame]]. * {{{ * sparkSession.read.parquet("/path/to/file.parquet") * sparkSession.read.schema(schema).json("/path/to/file.json") @@ -584,6 +585,19 @@ class SparkSession private( */ def read: DataFrameReader = new DataFrameReader(self) + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] that can be used to read streaming data in as a [[DataFrame]]. 
+ * {{{ + * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") + * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") + * }}} + * + * @since 2.0.0 + */ + @Experimental + def readStream: DataStreamReader = new DataStreamReader(self) + // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala new file mode 100644 index 0000000000000000000000000000000000000000..248247a257d94115c829e387af7e26a753dedbab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.types.StructType + +/** + * Interface used to load a streaming [[Dataset]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use [[SparkSession.readStream]] to access this. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { + /** + * :: Experimental :: + * Specifies the input data source format. + * + * @since 2.0.0 + */ + @Experimental + def format(source: String): DataStreamReader = { + this.source = source + this + } + + /** + * :: Experimental :: + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can + * skip the schema inference step, and thus speed up data loading. + * + * @since 2.0.0 + */ + @Experimental + def schema(schema: StructType): DataStreamReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * :: Experimental :: + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: String): DataStreamReader = { + this.extraOptions += (key -> value) + this + } + + /** + * :: Experimental :: + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) + + /** + * :: Experimental :: + * Adds an input option for the underlying data source. 
+ * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Long): DataStreamReader = option(key, value.toString) + + /** + * :: Experimental :: + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Double): DataStreamReader = option(key, value.toString) + + /** + * :: Experimental :: + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def options(options: scala.collection.Map[String, String]): DataStreamReader = { + this.extraOptions ++= options + this + } + + /** + * :: Experimental :: + * Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def options(options: java.util.Map[String, String]): DataStreamReader = { + this.options(options.asScala) + this + } + + + /** + * :: Experimental :: + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + @Experimental + def load(): DataFrame = { + val dataSource = + DataSource( + sparkSession, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) + } + + /** + * :: Experimental :: + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + @Experimental + def load(path: String): DataFrame = { + option("path", path).load() + } + + /** + * :: Experimental :: + * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li> + * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal + * type. If the values do not fit in decimal, then it infers them as doubles.</li> + * <li>`allowComments` (default `false`): ignores Java/C++ style comments in JSON records</li> + * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li> + * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + * </li> + * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)</li> + * <li>`allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all + * characters using the backslash quoting mechanism</li> + * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records + * during parsing.</li> + * <ul> + * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the + * malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.</li> + * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> + * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> + * </ul> + * <li>`columnNameOfCorruptRecord` (default is the value specified in + * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string + * created by `PERMISSIVE` mode.
This overrides `spark.sql.columnNameOfCorruptRecord`.</li> + * + * @since 2.0.0 + */ + @Experimental + def json(path: String): DataFrame = format("json").load(path) + + /** + * :: Experimental :: + * Loads a CSV file stream and returns the result as a [[DataFrame]]. + * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using [[schema]]. + * + * You can set the following CSV-specific options to deal with CSV files: + * <li>`sep` (default `,`): sets the single character as a separator for each + * field and value.</li> + * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding + * type.</li> + * <li>`quote` (default `"`): sets the single character used for escaping quoted values where + * the separator can be part of the value. If you would like to turn off quotations, you need to + * set an empty string, not `null`. This behaviour is different from + * `com.databricks.spark.csv`.</li> + * <li>`escape` (default `\`): sets the single character used for escaping quotes inside + * an already quoted value.</li> + * <li>`comment` (default empty string): sets the single character used for skipping lines + * beginning with this character. By default, it is disabled.</li> + * <li>`header` (default `false`): uses the first line as names of columns.</li> + * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It + * requires one extra pass over the data.</li> + * <li>`ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces + * from values being read should be skipped.</li> + * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing + * whitespaces from values being read should be skipped.</li> + * <li>`nullValue` (default empty string): sets the string representation of a null value.</li> + * <li>`nanValue` (default `NaN`): sets the string representation of a non-number value.</li> + * <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity + * value.</li> + * <li>`negativeInf` (default `-Inf`): sets the string representation of a negative infinity + * value.</li> + * <li>`dateFormat` (default `null`): sets the string that indicates a date format. Custom date + * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type + * and timestamp type. By default, it is `null` which means trying to parse times and date by + * `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.</li> + * <li>`maxColumns` (default `20480`): defines a hard limit of how many columns + * a record can have.</li> + * <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed + * for any given value being read.</li> + * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records + * during parsing.</li> + * <ul> + * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record.
When + * a schema is set by user, it sets `null` for extra fields.</li> + * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li> + * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li> + * </ul> + * + * @since 2.0.0 + */ + @Experimental + def csv(path: String): DataFrame = format("csv").load(path) + + /** + * :: Experimental :: + * Loads a Parquet file stream, returning the result as a [[DataFrame]]. + * + * You can set the following Parquet-specific option(s) for reading Parquet files: + * <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets + * whether we should merge schemas collected from all Parquet part-files. This will override + * `spark.sql.parquet.mergeSchema`.</li> + * + * @since 2.0.0 + */ + @Experimental + def parquet(path: String): DataFrame = { + format("parquet").load(path) + } + + /** + * :: Experimental :: + * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. + * + * Each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.readStream.text("/path/to/directory/") + * + * // Java: + * spark.readStream().text("/path/to/directory/") + * }}} + * + * @since 2.0.0 + */ + @Experimental + def text(path: String): DataFrame = format("text").load(path) + + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sparkSession.sessionState.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..b035ff7938bae273eaf523a91dc0268981f18baa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} + +/** + * :: Experimental :: + * Interface used to write a streaming [[Dataset]] to external storage systems (e.g. file systems, + * key-value stores, etc). Use [[Dataset.writeStream]] to access this. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { + + private val df = ds.toDF() + + /** + * :: Experimental :: + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be + * written to the sink + * - `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time there are some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + this.outputMode = outputMode + this + } + + + /** + * :: Experimental :: + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to + * the sink + * - `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink + * every time there are some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: String): DataStreamWriter[T] = { + this.outputMode = outputMode.toLowerCase match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + + "Accepted output modes are 'append' and 'complete'") + } + this + } + + /** + * :: Experimental :: + * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run + * the query as fast as possible. + * + * Scala Example: + * {{{ + * df.writeStream.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * df.writeStream().trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.0.0 + */ + @Experimental + def trigger(trigger: Trigger): DataStreamWriter[T] = { + this.trigger = trigger + this + } + + + /** + * :: Experimental :: + * Specifies the name of the [[ContinuousQuery]] that can be started with `start()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ + @Experimental + def queryName(queryName: String): DataStreamWriter[T] = { + this.extraOptions += ("queryName" -> queryName) + this + } + + /** + * :: Experimental :: + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 2.0.0 + */ + @Experimental + def format(source: String): DataStreamWriter[T] = { + this.source = source + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme.
As an example, when we + * partition a dataset by year and then month, the directory layout would look like: + * + * - year=2016/month=01/ + * - year=2016/month=02/ + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. + * + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataStreamWriter[T] = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * :: Experimental :: + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: String): DataStreamWriter[T] = { + this.extraOptions += (key -> value) + this + } + + /** + * :: Experimental :: + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) + + /** + * :: Experimental :: + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) + + /** + * :: Experimental :: + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) + + /** + * :: Experimental :: + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + this.extraOptions ++= options + this + } + + /** + * :: Experimental :: + * Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + @Experimental + def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + this.options(options.asScala) + this + } + + /** + * :: Experimental :: + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + @Experimental + def start(path: String): ContinuousQuery = { + option("path", path).start() + } + + /** + * :: Experimental :: + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. 
+ * + * @since 2.0.0 + */ + @Experimental + def start(): ContinuousQuery = { + if (source == "memory") { + assertNotPartitioned("memory") + if (extraOptions.get("queryName").isEmpty) { + throw new AnalysisException("queryName must be specified for memory sink") + } + + val sink = new MemorySink(df.schema, outputMode) + val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) + val query = df.sparkSession.sessionState.continuousQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + sink, + outputMode, + useTempCheckpointLocation = true, + recoverFromCheckpointLocation = false, + trigger = trigger) + resultDf.createOrReplaceTempView(query.name) + query + } else if (source == "foreach") { + assertNotPartitioned("foreach") + val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + df.sparkSession.sessionState.continuousQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else { + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + df.sparkSession.sessionState.continuousQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + dataSource.createSink(outputMode), + outputMode, + trigger = trigger) + } + } + + /** + * :: Experimental :: + * Starts the execution of the streaming query, which will continually send results to the given + * [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data + * generated by the [[DataFrame]]/[[Dataset]] to an external system. + * + * Scala example: + * {{{ + * datasetOfString.writeStream.foreach(new ForeachWriter[String] { + * + * def open(partitionId: Long, version: Long): Boolean = { + * // open connection + * } + * + * def process(record: String) = { + * // write string to connection + * } + * + * def close(errorOrNull: Throwable): Unit = { + * // close the connection + * } + * }).start() + * }}} + * + * Java example: + * {{{ + * datasetOfString.writeStream().foreach(new ForeachWriter<String>() { + * + * @Override + * public boolean open(long partitionId, long version) { + * // open connection + * } + * + * @Override + * public void process(String value) { + * // write string to connection + * } + * + * @Override + * public void close(Throwable errorOrNull) { + * // close the connection + * } + * }).start(); + * }}} + * + * @since 2.0.0 + */ + @Experimental + def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + this.source = "foreach" + this.foreachWriter = if (writer != null) { + ds.sparkSession.sparkContext.clean(writer) + } else { + throw new IllegalArgumentException("foreach writer cannot be null") + } + this + } + + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards.
+ */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotPartitioned(operation: String): Unit = { + if (partitioningColumns.isDefined) { + throw new AnalysisException(s"'$operation' does not support partitioning") + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName + + private var outputMode: OutputMode = OutputMode.Append + + private var trigger: Trigger = ProcessingTime(0L) + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var foreachWriter: ForeachWriter[T] = null + + private var partitioningColumns: Option[Seq[String]] = None +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index e1fb3b947837bc28bcac5f9b163cf29ed22b3768..6ff597c16bb28d3c787a6ac600d8e7a918b151c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -38,9 +38,10 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf test("foreach") { withTempDir { checkpointDir => val input = MemoryStream[Int] - val query = input.toDS().repartition(2).write + val query = input.toDS().repartition(2).writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .foreach(new TestForeachWriter()) + .start() input.addData(1, 2, 3, 4) query.processAllAvailable() @@ -70,14 +71,14 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf test("foreach with error") { withTempDir { checkpointDir => val input = MemoryStream[Int] - val query = input.toDS().repartition(1).write + val query = input.toDS().repartition(1).writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) throw new RuntimeException("error") } - }) + }).start() input.addData(1, 2, 3, 4) query.processAllAvailable() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index f81608bdb85e3a02465f76102524e208bf86dde2..ef2fcbf73e360b1e9a3a535488d11d590752012f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -225,12 +225,12 @@ class ContinuousQueryManagerSuite extends StreamTest with BeforeAndAfter { val metadataRoot = Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath query = - df.write + df.writeStream .format("memory") .queryName(s"query$i") .option("checkpointLocation", metadataRoot) .outputMode("append") - .startStream() + .start() .asInstanceOf[StreamExecution] } 
catch { case NonFatal(e) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala index 43a88576cf9f44c2b924c0c64bc3e7c8ebf3a67c..ad6bc277295973093d21828b89b778706450115f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -39,12 +39,12 @@ class ContinuousQuerySuite extends StreamTest with BeforeAndAfter { def startQuery(queryName: String): ContinuousQuery = { val metadataRoot = Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath - val writer = mapped.write + val writer = mapped.writeStream writer .queryName(queryName) .format("memory") .option("checkpointLocation", metadataRoot) - .startStream() + .start() } val q1 = startQuery("q1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index bb3063dc34ae38816e9c6e82ddfdca1db2c53bb8..a5acc970e3a78320713a49f98f7c8c23f8605933 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -128,10 +128,10 @@ class FileStreamSinkSuite extends StreamTest { try { query = - df.write - .format("parquet") + df.writeStream .option("checkpointLocation", checkpointDir) - .startStream(outputDir) + .format("parquet") + .start(outputDir) inputData.addData(1, 2, 3) @@ -162,11 +162,11 @@ class FileStreamSinkSuite extends StreamTest { query = ds.map(i => (i, i * 1000)) .toDF("id", "value") - .write - .format("parquet") + .writeStream .partitionBy("id") .option("checkpointLocation", checkpointDir) - .startStream(outputDir) + .format("parquet") + .start(outputDir) inputData.addData(1, 2, 3) failAfter(streamingTimeout) { @@ -246,13 +246,13 @@ class FileStreamSinkSuite extends StreamTest { val writer = ds.map(i => (i, i * 1000)) .toDF("id", "value") - .write + .writeStream if (format.nonEmpty) { writer.format(format.get) } query = writer .option("checkpointLocation", checkpointDir) - .startStream(outputDir) + .start(outputDir) } finally { if (query != null) { query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index f681b8878d9edab66c7f1cbb88cd676dde4577ae..6971f93b230f147d195cc3d29a218c7f068b2eac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -107,11 +107,11 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { schema: Option[StructType] = None): DataFrame = { val reader = if (schema.isDefined) { - spark.read.format(format).schema(schema.get) + spark.readStream.format(format).schema(schema.get) } else { - spark.read.format(format) + spark.readStream.format(format) } - reader.stream(path) + reader.load(path) } protected def getSourceFromFileStream(df: DataFrame): FileStreamSource = { @@ -153,14 +153,14 @@ class FileStreamSourceSuite extends FileStreamSourceTest { format: Option[String], path: Option[String], schema: Option[StructType] = None): StructType = { - val reader = spark.read + val reader = spark.readStream format.foreach(reader.format) 
schema.foreach(reader.schema) val df = if (path.isDefined) { - reader.stream(path.get) + reader.load(path.get) } else { - reader.stream() + reader.load() } df.queryExecution.analyzed .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 1c0fb34dd0191a143b0a7aa8560fdfa69bde4c53..0e157cf7267dccb771999e4583e3aabb0ffb68bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -98,7 +98,7 @@ class FileStressSuite extends StreamTest { } writer.start() - val input = spark.read.format("text").stream(inputDir) + val input = spark.readStream.format("text").load(inputDir) def startStream(): ContinuousQuery = { val output = input @@ -116,17 +116,17 @@ class FileStressSuite extends StreamTest { if (partitionWrites) { output - .write + .writeStream .partitionBy("id") .format("parquet") .option("checkpointLocation", checkpoint) - .startStream(outputDir) + .start(outputDir) } else { output - .write + .writeStream .format("parquet") .option("checkpointLocation", checkpoint) - .startStream(outputDir) + .start(outputDir) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 9aada0b18dd8d6a30e60f2304b942a16ccc4531b..310d75630272bce8f663e53ae0a5082208964ace 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -140,11 +140,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("registering as a table in Append output mode") { val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") .outputMode("append") .queryName("memStream") - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() @@ -166,11 +166,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { val query = input.toDF() .groupBy("value") .count() - .write + .writeStream .format("memory") .outputMode("complete") .queryName("memStream") - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() @@ -191,10 +191,10 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Ignore the stress test as it takes several minutes to run (0 until 1000).foreach { _ => val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") .queryName("memStream") - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() @@ -215,9 +215,9 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("error when no name is specified") { val error = intercept[AnalysisException] { val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") - .startStream() + .start() } assert(error.message contains "queryName must be specified") @@ -227,21 +227,21 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { val location = Utils.createTempDir(namePrefix = "steaming.checkpoint").getCanonicalPath val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") 
.queryName("memStream") .option("checkpointLocation", location) - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() query.stop() intercept[AnalysisException] { - input.toDF().write + input.toDF().writeStream .format("memory") .queryName("memStream") .option("checkpointLocation", location) - .startStream() + .start() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9414b1ce4019b60d09d6711ac2ebfceda35fb286..786404a5895816b13742c6f2eafb63fd5ec8d9f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -89,9 +89,9 @@ class StreamSuite extends StreamTest { def assertDF(df: DataFrame) { withTempDir { outputDir => withTempDir { checkpointDir => - val query = df.write.format("parquet") + val query = df.writeStream.format("parquet") .option("checkpointLocation", checkpointDir.getAbsolutePath) - .startStream(outputDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) try { query.processAllAvailable() val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] @@ -103,7 +103,7 @@ class StreamSuite extends StreamTest { } } - val df = spark.read.format(classOf[FakeDefaultSource].getName).stream() + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() assertDF(df) assertDF(df) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 8681199817fe680d7d0e7f1e733281a3a4f6c607..7f44227ec46fee2491c18f127c21e50e2b3c52c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -40,6 +40,8 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { import testImplicits._ + + test("simple count, update mode") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala similarity index 55% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 6e0d66ae7f19a42fcdd2d28020db31851e590779..c6d374f75467a6e64509da1f302b252198ec7e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -101,7 +101,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { +class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -110,25 +110,38 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { spark.streams.active.foreach(_.stop()) } + test("write cannot be called on streaming datasets") { + val e = intercept[AnalysisException] { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + .write + .save() 
+ } + Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + test("resolve default source") { - spark.read + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() - .write + .load() + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() } test("resolve full class") { - spark.read + spark.readStream .format("org.apache.spark.sql.streaming.test.DefaultSource") - .stream() - .write + .load() + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() } @@ -136,12 +149,12 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .stream() + .load() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -149,13 +162,13 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() assert(LastOptions.parameters("opt1") == "1") @@ -164,84 +177,84 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { } test("partitioning") { - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Nil) - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("a") - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Seq("a")) withSQLConf("spark.sql.caseSensitive" -> "false") { - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("A") - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("b") - .startStream() + .start() .stop() } } test("stream paths") { - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .stream("/test") + .load("/test") assert(LastOptions.parameters("path") == "/test") LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream("/test") + .start("/test") .stop() assert(LastOptions.parameters("path") == "/test") } test("test different data types for options") { - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) .option("doubleOpt", 6.7) - .stream("/test") + .load("/test") assert(LastOptions.parameters("intOpt") == "56") assert(LastOptions.parameters("boolOpt") == 
"false") assert(LastOptions.parameters("doubleOpt") == "6.7") LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) .option("doubleOpt", 6.7) .option("checkpointLocation", newMetadataDir) - .startStream("/test") + .start("/test") .stop() assert(LastOptions.parameters("intOpt") == "56") @@ -253,25 +266,25 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { /** Start a query with a specific name */ def startQueryWithName(name: String = ""): ContinuousQuery = { - spark.read + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write + .load("/test") + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .queryName(name) - .startStream() + .start() } /** Start a query without specifying a name */ def startQueryWithoutName(): ContinuousQuery = { - spark.read + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write + .load("/test") + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() } /** Get the names of active streams */ @@ -311,24 +324,24 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { } test("trigger") { - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") + .load("/test") - var q = df.write + var q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .trigger(ProcessingTime(10.seconds)) - .startStream() + .start() q.stop() assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) - q = df.write + q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) - .startStream() + .start() q.stop() assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) @@ -339,19 +352,19 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { val checkpointLocation = newMetadataDir - val df1 = spark.read + val df1 = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - val df2 = spark.read + val df2 = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - val q = df1.union(df2).write + val q = df1.union(df2).writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", checkpointLocation) .trigger(ProcessingTime(10.seconds)) - .startStream() + .start() q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( @@ -371,76 +384,12 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - test("check trigger() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) - assert(e.getMessage == "trigger() can only be called on continuous queries;") - } - - test("check queryName() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.queryName("queryName")) - assert(e.getMessage == 
"queryName() can only be called on continuous queries;") - } - - test("check startStream() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.startStream()) - assert(e.getMessage == "startStream() can only be called on continuous queries;") - } - - test("check startStream(path) can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.startStream("non_exist_path")) - assert(e.getMessage == "startStream() can only be called on continuous queries;") - } - - test("check mode(SaveMode) can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.mode(SaveMode.Append)) - assert(e.getMessage == "mode() can only be called on non-continuous queries;") - } - - test("check mode(string) can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.mode("append")) - assert(e.getMessage == "mode() can only be called on non-continuous queries;") - } - - test("check outputMode(OutputMode) can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.outputMode(OutputMode.Append)) - Seq("outputmode", "continuous queries").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } - } - - test("check outputMode(string) can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.outputMode("append")) - Seq("outputmode", "continuous queries").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } - } - test("check outputMode(string) throws exception on unsupported modes") { def testError(outputMode: String): Unit = { - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write + .load() + val w = df.writeStream val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) Seq("output mode", "unknown", outputMode).foreach { s => assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) @@ -450,159 +399,46 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { testError("Xyz") } - test("check bucketBy() can only be called on non-continuous queries") { - val df = spark.read + test("check foreach() catches null writers") { + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.bucketBy(1, "text").startStream()) - assert(e.getMessage == "'startStream' does not support bucketing right now;") - } - - test("check sortBy() can only be called on non-continuous queries;") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.sortBy("text").startStream()) - assert(e.getMessage == "'startStream' does not support bucketing right now;") - } + .load() - test("check save(path) can only be called on non-continuous queries") { - val 
df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.save("non_exist_path")) - assert(e.getMessage == "save() can only be called on non-continuous queries;") - } - - test("check save() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.save()) - assert(e.getMessage == "save() can only be called on non-continuous queries;") - } - - test("check insertInto() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.insertInto("non_exsit_table")) - assert(e.getMessage == "insertInto() can only be called on non-continuous queries;") - } - - test("check saveAsTable() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table")) - assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;") - } - - test("check jdbc() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.jdbc(null, null, null)) - assert(e.getMessage == "jdbc() can only be called on non-continuous queries;") - } - - test("check json() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.json("non_exist_path")) - assert(e.getMessage == "json() can only be called on non-continuous queries;") - } - - test("check parquet() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.parquet("non_exist_path")) - assert(e.getMessage == "parquet() can only be called on non-continuous queries;") - } - - test("check orc() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.orc("non_exist_path")) - assert(e.getMessage == "orc() can only be called on non-continuous queries;") - } - - test("check text() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.text("non_exist_path")) - assert(e.getMessage == "text() can only be called on non-continuous queries;") - } - - test("check csv() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.csv("non_exist_path")) - assert(e.getMessage == "csv() can only be called on non-continuous queries;") - } - - test("check foreach() does not support partitioning or bucketing") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - - var w = df.write.partitionBy("value") - var e = intercept[AnalysisException](w.foreach(null)) - Seq("foreach", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } 
- - w = df.write.bucketBy(2, "value") - e = intercept[AnalysisException](w.foreach(null)) - Seq("foreach", "bucketing").foreach { s => + var w = df.writeStream + var e = intercept[IllegalArgumentException](w.foreach(null)) + Seq("foreach", "null").foreach { s => assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) } } - test("check jdbc() does not support partitioning or bucketing") { - val df = spark.read.text(newTextInput) - var w = df.write.partitionBy("value") - var e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + test("check foreach() does not support partitioning") { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + val foreachWriter = new ForeachWriter[Row] { + override def open(partitionId: Long, version: Long): Boolean = false + override def process(value: Row): Unit = {} + override def close(errorOrNull: Throwable): Unit = {} } - - w = df.write.bucketBy(2, "value") - e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + var w = df.writeStream.partitionBy("value") + var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) + Seq("foreach", "partitioning").foreach { s => assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) } } test("ConsoleSink can be correctly loaded") { LastOptions.clear() - val df = spark.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - val cq = df.write + val cq = df.writeStream .format("console") .option("checkpointLocation", newMetadataDir) .trigger(ProcessingTime(2.seconds)) - .startStream() + .start() cq.awaitTermination(2000L) } @@ -611,10 +447,11 @@ class DataFrameReaderWriterSuite extends StreamTest with BeforeAndAfter { withTempDir { dir => val path = dir.getCanonicalPath intercept[AnalysisException] { - spark.range(10).write.format("parquet").mode("overwrite").partitionBy("id").save(path) - } - intercept[AnalysisException] { - spark.range(10).write.format("orc").mode("overwrite").partitionBy("id").save(path) + spark.range(10).writeStream + .outputMode("append") + .partitionBy("id") + .format("parquet") + .start(path) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..98e57b38044f25c032fee9e3453936d335873282 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.util.Utils + + +object LastOptions { + + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var saveMode: SaveMode = null + + def clear(): Unit = { + parameters = null + schema = null + saveMode = null + } +} + + +/** Dummy provider. */ +class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + + case class FakeRelation(sqlContext: SQLContext) extends BaseRelation { + override def schema: StructType = StructType(Seq(StructField("a", StringType))) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType + ): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = Some(schema) + FakeRelation(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String] + ): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = None + FakeRelation(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = None + LastOptions.saveMode = mode + FakeRelation(sqlContext) + } +} + + +class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext { + + private def newMetadataDir = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + + test("writeStream cannot be called on non-streaming datasets") { + val e = intercept[AnalysisException] { + spark.read + .format("org.apache.spark.sql.test") + .load() + .writeStream + .start() + } + Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + + test("resolve default source") { + spark.read + .format("org.apache.spark.sql.test") + .load() + .write + .format("org.apache.spark.sql.test") + .save() + } + + test("resolve full class") { + spark.read + .format("org.apache.spark.sql.test.DefaultSource") + .load() + .write + .format("org.apache.spark.sql.test") + .save() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = spark.read + .format("org.apache.spark.sql.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .load() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .save() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("save mode") { + val df = spark.read + .format("org.apache.spark.sql.test") + .load() + + df.write + .format("org.apache.spark.sql.test") + .mode(SaveMode.ErrorIfExists) + .save() + assert(LastOptions.saveMode === SaveMode.ErrorIfExists) + } + + test("paths") { + val df = spark.read + .format("org.apache.spark.sql.test") + .option("checkpointLocation", newMetadataDir) + .load("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.clear() + + df.write + 
.format("org.apache.spark.sql.test") + .option("checkpointLocation", newMetadataDir) + .save("/test") + + assert(LastOptions.parameters("path") == "/test") + } + + test("test different data types for options") { + val df = spark.read + .format("org.apache.spark.sql.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .load("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.clear() + df.write + .format("org.apache.spark.sql.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .option("checkpointLocation", newMetadataDir) + .save("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } + + test("check jdbc() does not support partitioning or bucketing") { + val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) + + var w = df.write.partitionBy("value") + var e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "partitioning").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + + w = df.write.bucketBy(2, "value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "bucketing").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("prevent all column partitioning") { + withTempDir { dir => + val path = dir.getCanonicalPath + intercept[AnalysisException] { + spark.range(10).write.format("parquet").mode("overwrite").partitionBy("id").save(path) + } + intercept[AnalysisException] { + spark.range(10).write.format("orc").mode("overwrite").partitionBy("id").save(path) + } + } + } +}