From 5bf8881b34a18f25acc10aeb28a06af4c44a6ac8 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Tue, 28 Jun 2016 18:33:37 -0700
Subject: [PATCH] [SPARK-16268][PYSPARK] SQLContext should import
 DataStreamReader

## What changes were proposed in this pull request?

Fixed the following error, caused by `SQLContext.readStream` referencing `DataStreamReader` without importing it:
```
>>> sqlContext.readStream
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...", line 442, in readStream
    return DataStreamReader(self._wrapped)
NameError: global name 'DataStreamReader' is not defined
```
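
With the import added (and the reader constructed from the `SQLContext` itself rather than its `_wrapped` session), the property resolves. A minimal usage sketch, mirroring the doctest added below and assuming an active `SparkContext` named `sc`:
```python
import os
import tempfile

from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)  # `sc` is an assumed, already-running SparkContext
# readStream now returns a DataStreamReader; reading a directory as text
# yields a streaming DataFrame.
text_sdf = sqlContext.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
print(text_sdf.isStreaming)  # True
```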

## How was this patch tested?

The doctest added to `SQLContext.readStream`.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #13958 from zsxwing/fix-import.
---
 python/pyspark/sql/context.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8a1a874884..b5dde13ed7 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -26,7 +26,7 @@ from pyspark import since
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.readwriter import DataFrameReader, DataStreamReader
 from pyspark.sql.types import Row, StringType
 from pyspark.sql.utils import install_exception_handler
 
@@ -438,8 +438,12 @@ class SQLContext(object):
         .. note:: Experimental.
 
         :return: :class:`DataStreamReader`
+
+        >>> text_sdf = sqlContext.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
+        >>> text_sdf.isStreaming
+        True
         """
-        return DataStreamReader(self._wrapped)
+        return DataStreamReader(self)
 
     @property
     @since(2.0)
@@ -515,6 +519,7 @@ class UDFRegistration(object):
 def _test():
     import os
     import doctest
+    import tempfile
     from pyspark.context import SparkContext
     from pyspark.sql import Row, SQLContext
     import pyspark.sql.context
@@ -523,6 +528,8 @@ def _test():
 
     globs = pyspark.sql.context.__dict__.copy()
     sc = SparkContext('local[4]', 'PythonTest')
+    globs['tempfile'] = tempfile
+    globs['os'] = os
     globs['sc'] = sc
     globs['sqlContext'] = SQLContext(sc)
     globs['rdd'] = rdd = sc.parallelize(
-- 
GitLab