Commit 5ab9fcfb authored by Yin Huai


[SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode

https://issues.apache.org/jira/browse/SPARK-8532

This PR makes two changes. First, it fixes the bug that the save actions (i.e. `save/saveAsTable/json/parquet/jdbc`) always override the save mode set earlier on the writer. Second, it adds the input argument `partitionBy` to `save/saveAsTable/parquet`.
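Taken together, the two changes make the following usage sketch work as written; the temp path and column name are illustrative, and `df` is assumed to be an existing DataFrame:

import os, tempfile

path = os.path.join(tempfile.mkdtemp(), 'data')

# Change 1: an explicitly chosen mode now survives the save action's own default.
df.write.mode('append').parquet(path)

# Change 2: partitionBy is accepted directly by save/saveAsTable/parquet.
df.write.save(path, format='parquet', mode='overwrite', partitionBy=['name'])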

Author: Yin Huai <yhuai@databricks.com>

Closes #6937 from yhuai/SPARK-8532 and squashes the following commits:

f972d5d [Yin Huai] davies's comment.
d37abd2 [Yin Huai] style.
d21290a [Yin Huai] Python doc.
889eb25 [Yin Huai] Minor refactoring and add partitionBy to save, saveAsTable, and parquet.
7fbc24b [Yin Huai] Use None instead of "error" as the default value of mode since JVM-side already uses "error" as the default value.
d696dff [Yin Huai] Python style.
88eb6c4 [Yin Huai] If mode is "error", do not call mode method.
c40c461 [Yin Huai] Regression test.
parent da7bbb94
python/pyspark/sql/readwriter.py

@@ -218,7 +218,10 @@ class DataFrameWriter(object):
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite = self._jwrite.mode(saveMode)
+        # At the JVM side, the default value of mode is already set to "error".
+        # So, if the given saveMode is None, we will not call JVM-side's mode method.
+        if saveMode is not None:
+            self._jwrite = self._jwrite.mode(saveMode)
         return self
 
     @since(1.4)
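The guard above is the heart of the mode fix: `None` now means "leave the JVM-side default alone" instead of re-sending `"error"`. A minimal standalone sketch of this sentinel pattern, with hypothetical names rather than Spark's internals:

# Hypothetical builder showing the None-as-"don't touch" sentinel.
class Writer(object):
    def __init__(self):
        self._mode = "error"          # backend default, mirroring the JVM side

    def mode(self, saveMode=None):
        # Only forward a mode the caller actually supplied, so a later
        # defaulted call cannot clobber an earlier explicit choice.
        if saveMode is not None:
            self._mode = saveMode
        return self

w = Writer().mode("append").mode()    # the second call passes None
assert w._mode == "append"            # the explicit "append" survives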
@@ -253,11 +256,12 @@ class DataFrameWriter(object):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        if len(cols) > 0:
+            self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
         return self
 
     @since(1.4)
-    def save(self, path=None, format=None, mode="error", **options):
+    def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
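`partitionBy` gets the same defensive treatment: the empty tuple that `save` and friends now pass by default is a no-op rather than a JVM call with zero columns. A small illustration of the guard, assuming nothing Spark-specific:

# Sketch of the len(cols) > 0 guard: empty input never reaches the backend.
def partition_by(backend_calls, *cols):
    if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
        cols = cols[0]                # accept a single list/tuple of names too
    if len(cols) > 0:
        backend_calls.append(tuple(cols))

calls = []
partition_by(calls, ())               # the default () from save() is ignored
partition_by(calls, 'year', 'month')
assert calls == [('year', 'month')]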
@@ -272,11 +276,12 @@ class DataFrameWriter(object):
         * ``overwrite``: Overwrite existing data.
         * ``ignore``: Silently ignore this operation if data already exists.
         * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         if path is None:
@@ -296,7 +301,7 @@ class DataFrameWriter(object):
         self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
 
     @since(1.4)
-    def saveAsTable(self, name, format=None, mode="error", **options):
+    def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
         """Saves the content of the :class:`DataFrame` as the specified table.
 
         In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ class DataFrameWriter(object):
         :param name: the table name
         :param format: the format used to save
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         self._jwrite.saveAsTable(name)
 
     @since(1.4)
-    def json(self, path, mode="error"):
+    def json(self, path, mode=None):
         """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -333,10 +339,10 @@
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).json(path)
+        self.mode(mode)._jwrite.json(path)
 
     @since(1.4)
-    def parquet(self, path, mode="error"):
+    def parquet(self, path, mode=None, partitionBy=()):
         """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
 
         :param path: the path in any Hadoop supported file system
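Note the subtle change in `json` above: it now goes through the Python-side `mode()` wrapper (`self.mode(mode)._jwrite.json(path)`) instead of calling `self._jwrite.mode(mode)` directly, so the `None` check applies uniformly to every save action. A sketch of why that matters, with a hypothetical stand-in for the JVM writer:

class FakeJWrite(object):
    # Hypothetical stand-in for the JVM DataFrameWriter.
    def __init__(self):
        self._mode = "error"
    def mode(self, m):
        self._mode = m                # the JVM side applies whatever it is given
        return self

def py_mode(jw, saveMode):
    # The Python wrapper: skip the JVM call entirely when mode is None.
    if saveMode is not None:
        jw.mode(saveMode)
    return jw

jw = FakeJWrite()
jw.mode("append")                     # user called .mode('append') earlier
py_mode(jw, None)                     # json(path) passes its None default
assert jw._mode == "append"           # the old direct call would have reset it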
@@ -346,13 +352,15 @@
         * ``overwrite``: Overwrite existing data.
         * ``ignore``: Silently ignore this operation if data already exists.
         * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
 
         >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).parquet(path)
+        self.partitionBy(partitionBy).mode(mode)
+        self._jwrite.parquet(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode="error", properties={}):
+    def jdbc(self, url, table, mode=None, properties={}):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
......
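With `partitionBy` plumbed through, a partitioned Parquet write produces the usual Hive-style directory layout. An illustrative example (the path and column names are not from the patch):

# Illustrative only: one subdirectory per distinct partition value.
df.write.parquet('/tmp/events', mode='overwrite', partitionBy=['year', 'month'])
# Expected on-disk layout:
#   /tmp/events/year=2015/month=6/part-*.parquet
#   /tmp/events/year=2015/month=7/part-*.parquet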
python/pyspark/sql/tests.py

@@ -539,6 +539,38 @@ class SQLTests(ReusedPySparkTestCase):
         shutil.rmtree(tmpPath)
 
+    def test_save_and_load_builder(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+                .format("json").save(path=tmpPath)
+        actual =\
+            self.sqlCtx.read.format("json")\
+                       .load(path=tmpPath, noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
......
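For a quick end-to-end check of the regression outside the test suite, the original failure can be reproduced in a PySpark shell; `sqlContext` is the shell's built-in context and the data is illustrative:

import os, tempfile

path = os.path.join(tempfile.mkdtemp(), 'data')
df = sqlContext.createDataFrame([(1, 'a'), (2, 'b')], ['key', 'value'])

df.write.json(path)                    # first write succeeds
df.write.mode('overwrite').json(path)  # before this fix, json() re-applied its
                                       # "error" default and this call raised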