Commit 96c3500c authored by Shixiong Zhu, committed by Tathagata Das

[SPARK-15935][PYSPARK] Enable test for sql/streaming.py and fix these tests

## What changes were proposed in this pull request?

This PR enables the tests for `sql/streaming.py` and fixes the failures they uncovered.

## How was this patch tested?

Existing unit tests.
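For reference, two ways to run the newly enabled tests locally — both are the standard Spark dev workflow, assumed here rather than shown in this diff:

```python
# Assumed local invocations (standard Spark dev workflow, not part of
# this diff):
#
#   python/run-tests --modules=pyspark-sql    # run the whole pyspark-sql module
#   python python/pyspark/sql/streaming.py    # run only this file's _test()
```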

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #13655 from zsxwing/python-streaming-test.
parent a87a56f5
dev/sparktestsupport/modules.py:

```diff
@@ -337,6 +337,7 @@ pyspark_sql = Module(
         "pyspark.sql.group",
         "pyspark.sql.functions",
         "pyspark.sql.readwriter",
+        "pyspark.sql.streaming",
         "pyspark.sql.window",
         "pyspark.sql.tests",
     ]
```
python/pyspark/sql/context.py:

```diff
@@ -433,6 +433,8 @@ class SQLContext(object):
     def streams(self):
         """Returns a :class:`ContinuousQueryManager` that allows managing all the
         :class:`ContinuousQuery` ContinuousQueries active on `this` context.
+
+        .. note:: Experimental.
         """
         from pyspark.sql.streaming import ContinuousQueryManager
         return ContinuousQueryManager(self._ssql_ctx.streams())
```
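A usage sketch for the property documented above, built only from calls that appear elsewhere in this commit's doctests; the query name and app name are illustrative:

```python
# Sketch: managing active continuous queries through SQLContext.streams
# (2.0-era API; startStream/queryName match the doctests in this commit).
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext('local[4]', 'streams-demo')
sqlContext = SQLContext(sc)

df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
cq = df.write.format('memory').queryName('demo_query').startStream()

cqm = sqlContext.streams              # the ContinuousQueryManager
print([q.name for q in cqm.active])   # -> [u'demo_query']
cq.stop()
```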
python/pyspark/sql/session.py:

```diff
@@ -549,6 +549,17 @@ class SparkSession(object):
         """
         return DataFrameReader(self._wrapped)
 
+    @property
+    @since(2.0)
+    def streams(self):
+        """Returns a :class:`ContinuousQueryManager` that allows managing all the
+        :class:`ContinuousQuery` ContinuousQueries active on `this` context.
+
+        .. note:: Experimental.
+        """
+        from pyspark.sql.streaming import ContinuousQueryManager
+        return ContinuousQueryManager(self._jsparkSession.streams())
+
     @since(2.0)
     def stop(self):
         """Stop the underlying :class:`SparkContext`.
```
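The property body imports `ContinuousQueryManager` locally rather than at module top, presumably to avoid a circular import between `session.py` and `streaming.py` (an inference from the code shape, not stated in the commit). After this change the manager is reachable from either entry point; a small sketch:

```python
# Sketch: the two equivalent routes to the manager after this commit.
# In Spark 2.0 a SQLContext wraps a SparkSession, so both properties
# should hand back a manager over the same JVM-side object (an
# assumption based on the delegation shown in these hunks).
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
manager = spark.streams          # added by this commit
# sqlContext.streams             # the pre-existing SQLContext route
print(manager.active)            # no queries started yet -> []
```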
python/pyspark/sql/streaming.py:

```diff
@@ -15,6 +15,12 @@
 # limitations under the License.
 #
 
+import sys
+if sys.version >= '3':
+    intlike = int
+else:
+    intlike = (int, long)
+
 from abc import ABCMeta, abstractmethod
 
 from pyspark import since
```
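This shim exists because Python 2 has two integer types while Python 3 has one; query ids crossing the Py4J boundary can arrive as `long` on Python 2 (the Py4J detail is an inference — the version split itself is visible in the code), so `isinstance` checks must accept both. A quick illustration:

```python
# Illustration of the intlike shim above: a version-neutral
# "is this an integer?" check usable with isinstance().
import sys

if sys.version >= '3':
    intlike = int
else:
    intlike = (int, long)  # noqa: F821 -- `long` only exists on Python 2

assert isinstance(42, intlike)
assert not isinstance("42", intlike)
# On Python 2, large values are `long` and still pass:
# isinstance(10**20, intlike) -> True
```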
python/pyspark/sql/streaming.py:

```diff
@@ -36,10 +42,18 @@ class ContinuousQuery(object):
     def __init__(self, jcq):
         self._jcq = jcq
 
+    @property
+    @since(2.0)
+    def id(self):
+        """The id of the continuous query. This id is unique across all queries that have been
+        started in the current process.
+        """
+        return self._jcq.id()
+
     @property
     @since(2.0)
     def name(self):
-        """The name of the continuous query.
+        """The name of the continuous query. This name is unique across all active queries.
         """
         return self._jcq.name()
```
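Per the docstrings above, the two identifiers have different scopes: `id` is unique across every query started in the current process, including stopped ones, while `name` is only unique among currently active queries. A usage sketch, assuming the streaming `df` fixture that the doctests in this file set up:

```python
# Sketch: inspecting both identifiers on a started query (assumes the
# doctest fixture: a streaming `df` over python/test_support/sql/streaming).
cq = df.write.format('memory').queryName('this_query').startStream()
print(cq.id)     # process-unique integer, usable with streams.get()
print(cq.name)   # u'this_query' -- unique only among *active* queries
cq.stop()        # a stopped query's name may be reused; per the docstring,
                 # its id stays unique within the process
```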
python/pyspark/sql/streaming.py:

```diff
@@ -106,7 +120,7 @@ class ContinuousQueryManager(object):
         """Returns a list of active queries associated with this SQLContext
 
         >>> cq = df.write.format('memory').queryName('this_query').startStream()
-        >>> cqm = sqlContext.streams
+        >>> cqm = spark.streams
         >>> # get the list of active continuous queries
         >>> [q.name for q in cqm.active]
         [u'this_query']
```
python/pyspark/sql/streaming.py:

```diff
@@ -114,20 +128,26 @@ class ContinuousQueryManager(object):
         """
         return [ContinuousQuery(jcq) for jcq in self._jcqm.active()]
 
     @ignore_unicode_prefix
     @since(2.0)
-    def get(self, name):
+    def get(self, id):
         """Returns an active query from this SQLContext or throws exception if an active query
         with this name doesn't exist.
 
-        >>> df.write.format('memory').queryName('this_query').startStream()
-        >>> cq = sqlContext.streams.get('this_query')
+        >>> cq = df.write.format('memory').queryName('this_query').startStream()
         >>> cq.name
         u'this_query'
+        >>> cq = spark.streams.get(cq.id)
+        >>> cq.isActive
+        True
+        >>> cq = sqlContext.streams.get(cq.id)
         >>> cq.isActive
         True
+        >>> cq.stop()
         """
-        if type(name) != str or len(name.strip()) == 0:
-            raise ValueError("The name for the query must be a non-empty string. Got: %s" % name)
-        return ContinuousQuery(self._jcqm.get(name))
+        if not isinstance(id, intlike):
+            raise ValueError("The id for the query must be an integer. Got: %d" % id)
+        return ContinuousQuery(self._jcqm.get(id))
 
     @since(2.0)
     def awaitAnyTermination(self, timeout=None):
```
python/pyspark/sql/streaming.py:

```diff
@@ -162,7 +182,7 @@ class ContinuousQueryManager(object):
         """Forget about past terminated queries so that :func:`awaitAnyTermination()` can be used
         again to wait for new terminations.
 
-        >>> sqlContext.streams.resetTerminated()
+        >>> spark.streams.resetTerminated()
         """
         self._jcqm.resetTerminated()
```
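Taken together, the updated doctests exercise a full manager round-trip: start a named query, look it up again by its integer id, stop it, and clear the terminated-query bookkeeping so `awaitAnyTermination()` waits for fresh terminations. A condensed sketch under the same assumptions as the doctests (`spark` and a streaming `df` fixture; `timeout` in seconds per the 2.0-era API):

```python
# Condensed round-trip over the manager APIs touched in this file
# (assumes the doctest fixtures `spark` and streaming `df`).
cq = df.write.format('memory').queryName('this_query').startStream()

same = spark.streams.get(cq.id)       # id-based lookup (new signature)
assert same.isActive

cq.stop()
spark.streams.resetTerminated()       # forget queries that already ended
spark.streams.awaitAnyTermination(timeout=2)  # wait up to 2s for the next one
```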
python/pyspark/sql/streaming.py:

```diff
@@ -209,27 +229,28 @@ def _test():
     import doctest
     import os
     import tempfile
-    from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext, HiveContext
-    import pyspark.sql.readwriter
+    from pyspark.sql import Row, SparkSession, SQLContext
+    import pyspark.sql.streaming
 
     os.chdir(os.environ["SPARK_HOME"])
 
-    globs = pyspark.sql.readwriter.__dict__.copy()
-    sc = SparkContext('local[4]', 'PythonTest')
+    globs = pyspark.sql.streaming.__dict__.copy()
+    try:
+        spark = SparkSession.builder.enableHiveSupport().getOrCreate()
+    except py4j.protocol.Py4JError:
+        spark = SparkSession(sc)
 
     globs['tempfile'] = tempfile
     globs['os'] = os
-    globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
-    globs['hiveContext'] = HiveContext._createForTesting(sc)
+    globs['spark'] = spark
+    globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext)
     globs['df'] = \
-        globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+        globs['spark'].read.format('text').stream('python/test_support/sql/streaming')
 
     (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.readwriter, globs=globs,
+        pyspark.sql.streaming, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
-    globs['sc'].stop()
+    globs['spark'].stop()
     if failure_count:
         exit(-1)
```
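One caveat in the hunk above: the new fallback branch references `py4j` and `sc`, but neither name is defined in `_test()` after this change (the `SparkContext` import and the `sc = SparkContext(...)` line are removed), so the `except` path would itself raise `NameError` if `getOrCreate()` ever failed. A self-contained variant of the same setup, as a sketch rather than the committed code:

```python
# Sketch: the same session bootstrap with the missing names defined.
import py4j
from pyspark.context import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext('local[4]', 'PythonTest')
try:
    # Reuses the context created above; Hive support may be unavailable.
    spark = SparkSession.builder.enableHiveSupport().getOrCreate()
except py4j.protocol.Py4JError:
    spark = SparkSession(sc)  # fall back to a plain session without Hive
```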