From aaf50d05c7616e4f8f16654b642500ae06cdd774 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Tue, 10 Feb 2015 17:29:52 -0800
Subject: [PATCH] [SPARK-5658][SQL] Finalize DDL and write support APIs

https://issues.apache.org/jira/browse/SPARK-5658

Author: Yin Huai <yhuai@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #4446 from yhuai/writeSupportFollowup and squashes the following commits:

f3a96f7 [Yin Huai] davies's comments.
225ff71 [Yin Huai] Use Scala TestHiveContext to initialize the Python HiveContext in Python tests.
2306f93 [Yin Huai] Style.
2091fcd [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
537e28f [Yin Huai] Correctly clean up temp data.
ae4649e [Yin Huai] Fix Python test.
609129c [Yin Huai] Doc format.
92b6659 [Yin Huai] Python doc and other minor updates.
cbc717f [Yin Huai] Rename dataSourceName to source.
d1c12d3 [Yin Huai] No need to delete the duplicate rule since it has been removed in master.
22cfa70 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
d91ecb8 [Yin Huai] Fix test.
4c76d78 [Yin Huai] Simplify APIs.
3abc215 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
0832ce4 [Yin Huai] Fix test.
98e7cdb [Yin Huai] Python style.
2bf44ef [Yin Huai] Python APIs.
c204967 [Yin Huai] Format
a10223d [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
9ff97d8 [Yin Huai] Add SaveMode to saveAsTable.
9b6e570 [Yin Huai] Update doc.
c2be775 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
99950a2 [Yin Huai] Use Java enum for SaveMode.
4679665 [Yin Huai] Remove duplicate rule.
77d89dc [Yin Huai] Update doc.
e04d908 [Yin Huai] Move import and add (Scala-specific) to scala APIs.
cf5703d [Yin Huai] Add checkAnswer to Java tests.
7db95ff [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
6dfd386 [Yin Huai] Add java test.
f2f33ef [Yin Huai] Fix test.
e702386 [Yin Huai] Apache header.
b1e9b1b [Yin Huai] Format.
ed4e1b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
af9e9b3 [Yin Huai] DDL and write support API followup.
2a6213a [Yin Huai] Update API names.
e6a0b77 [Yin Huai] Update test.
43bae01 [Yin Huai] Remove createTable from HiveContext.
5ffc372 [Yin Huai] Add more load APIs to SQLContext.
5390743 [Yin Huai] Add more save APIs to DataFrame.
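
As a quick orientation before the diff, here is a rough, illustrative sketch of how
the write and load APIs added by this patch can be driven from PySpark. It is not
part of the change itself: the SQLContext instance sqlCtx, the DataFrame df, and the
paths below are placeholder names.

    # Write df with an explicit data source and save mode; "error" (ErrorIfExists)
    # fails if the target already exists, "overwrite" replaces existing data,
    # "append" adds to it, and "ignore" leaves existing data untouched.
    df.save("/tmp/people_json", source="org.apache.spark.sql.json", mode="error")

    # Read it back as a DataFrame.
    people = sqlCtx.load("/tmp/people_json", source="org.apache.spark.sql.json")

    # When source is omitted, the data source configured by
    # spark.sql.sources.default (Parquet unless reconfigured) is used.
    df.save("/tmp/people_default", mode="error")
    people_default = sqlCtx.load("/tmp/people_default")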
---
 python/pyspark/sql/context.py                      |  68 ++++++++
 python/pyspark/sql/dataframe.py                    |  72 +++++++-
 python/pyspark/sql/tests.py                        | 107 +++++++++++-
 .../apache/spark/sql/sources/SaveMode.java         |  45 +++++
 .../org/apache/spark/sql/DataFrame.scala           | 160 ++++++++++++++---
 .../org/apache/spark/sql/DataFrameImpl.scala       |  61 ++-----
 .../apache/spark/sql/IncomputableColumn.scala      |  27 +--
 .../scala/org/apache/spark/sql/SQLConf.scala       |   2 +-
 .../org/apache/spark/sql/SQLContext.scala          | 164 +++++++++++++++++-
 .../spark/sql/execution/SparkStrategies.scala      |  14 +-
 .../apache/spark/sql/json/JSONRelation.scala       |  30 +++-
 .../apache/spark/sql/parquet/newParquet.scala      |  45 ++++-
 .../org/apache/spark/sql/sources/ddl.scala         |  40 ++++-
 .../apache/spark/sql/sources/interfaces.scala      |  19 ++
 .../spark/sql/sources/JavaSaveLoadSuite.java       |  97 +++++++++++
 .../org/apache/spark/sql/QueryTest.scala           |  92 ++++++----
 .../sources/CreateTableAsSelectSuite.scala         |  29 +++-
 .../spark/sql/sources/SaveLoadSuite.scala          |  59 +++++--
 .../apache/spark/sql/hive/HiveContext.scala        |  76 --------
 .../spark/sql/hive/HiveStrategies.scala            |  13 +-
 .../spark/sql/hive/execution/commands.scala        | 105 ++++++++---
 .../spark/sql/hive/{ => test}/TestHive.scala       |  20 +--
 .../hive/JavaMetastoreDataSourcesSuite.java        | 147 ++++++++++++++++
 .../org/apache/spark/sql/QueryTest.scala           |  64 +++++--
 .../sql/hive/InsertIntoHiveTableSuite.scala        |  33 ++--
 .../sql/hive/MetastoreDataSourcesSuite.scala       | 118 +++++++++++--
 26 files changed, 1357 insertions(+), 350 deletions(-)
 create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
 create mode 100644 sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
 rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{ => test}/TestHive.scala (99%)
 create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 49f016a9cf..882c0f98ea 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -21,6 +21,7 @@ from array import array
 from itertools import imap
 
 from py4j.protocol import Py4JError
+from py4j.java_collections import MapConverter
 
 from pyspark.rdd import _prepare_for_python_RDD
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -87,6 +88,18 @@ class SQLContext(object):
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
+    def setConf(self, key, value):
+        """Sets the given Spark SQL configuration property.
+        """
+        self._ssql_ctx.setConf(key, value)
+
+    def getConf(self, key, defaultValue):
+        """Returns the value of Spark SQL configuration property for the given key.
+
+        If the key is not set, returns defaultValue.
+        """
+        return self._ssql_ctx.getConf(key, defaultValue)
+
     def registerFunction(self, name, f, returnType=StringType()):
         """Registers a lambda function as a UDF so it can be used in SQL statements.
 
@@ -455,6 +468,61 @@ class SQLContext(object):
         df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
         return DataFrame(df, self)
 
+    def load(self, path=None, source=None, schema=None, **options):
+        """Returns the dataset in a data source as a DataFrame.
+
+        The data source is specified by the `source` and a set of `options`.
+        If `source` is not specified, the default data source configured by
+        spark.sql.sources.default will be used.
+
+        Optionally, a schema can be provided as the schema of the returned DataFrame.
+ """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.load(source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.load(source, scala_datatype, joptions) + return DataFrame(df, self) + + def createExternalTable(self, tableName, path=None, source=None, + schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Optionally, a schema can be provided as the schema of the returned DataFrame and + created external table. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + if schema is None: + df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, + joptions) + return DataFrame(df, self) + def sql(self, sqlQuery): """Return a L{DataFrame} representing the result of the given query. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 04be65fe24..3eef0cc376 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -146,9 +146,75 @@ class DataFrame(object): """ self._jdf.insertInto(tableName, overwrite) - def saveAsTable(self, tableName): - """Creates a new table with the contents of this DataFrame.""" - self._jdf.saveAsTable(tableName) + def _java_save_mode(self, mode): + """Returns the Java save mode based on the Python save mode represented by a string. + """ + jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode + jmode = jSaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = jSaveMode.Append + elif mode == "overwrite": + jmode = jSaveMode.Overwrite + elif mode == "ignore": + jmode = jSaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + return jmode + + def saveAsTable(self, tableName, source=None, mode="append", **options): + """Saves the contents of the DataFrame to a data source as a table. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Additionally, mode is used to specify the behavior of the saveAsTable operation when + table already exists in the data source. There are four modes: + + * append: Contents of this DataFrame are expected to be appended to existing table. + * overwrite: Data in the existing table is expected to be overwritten by the contents of \ + this DataFrame. + * error: An exception is expected to be thrown. 
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \ + to not change the existing table. + """ + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self.sql_ctx._sc._gateway._gateway_client) + self._jdf.saveAsTable(tableName, source, jmode, joptions) + + def save(self, path=None, source=None, mode="append", **options): + """Saves the contents of the DataFrame to a data source. + + The data source is specified by the `source` and a set of `options`. + If `source` is not specified, the default data source configured by + spark.sql.sources.default will be used. + + Additionally, mode is used to specify the behavior of the save operation when + data already exists in the data source. There are four modes: + + * append: Contents of this DataFrame are expected to be appended to existing data. + * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. + * error: An exception is expected to be thrown. + * ignore: The save operation is expected to not save the contents of the DataFrame and \ + to not change the existing data. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + joptions = MapConverter().convert(options, + self._sc._gateway._gateway_client) + self._jdf.save(source, jmode, joptions) def schema(self): """Returns the schema of this DataFrame (represented by diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d25c6365ed..bc945091f7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -34,10 +34,9 @@ if sys.version_info[:2] <= (2, 6): else: import unittest - -from pyspark.sql import SQLContext, Column +from pyspark.sql import SQLContext, HiveContext, Column from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType, LongType + UserDefinedType, DoubleType, LongType, StringType from pyspark.tests import ReusedPySparkTestCase @@ -286,6 +285,37 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + defaultDataSourceName = 
self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -296,5 +326,76 @@ class SQLTests(ReusedPySparkTestCase): pydoc.render_doc(df.take(1)) +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + print "type", type(cls.sc) + print "type", type(cls.sc._jsc) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, + "org.apache.spark.sql.json") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", + source="org.apache.spark.sql.json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.select("value").collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + 
if __name__ == "__main__": unittest.main() diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java new file mode 100644 index 0000000000..3109f5716d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources; + +/** + * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. + */ +public enum SaveMode { + /** + * Append mode means that when saving a DataFrame to a data source, if data/table already exists, + * contents of the DataFrame are expected to be appended to existing data. + */ + Append, + /** + * Overwrite mode means that when saving a DataFrame to a data source, + * if data/table already exists, existing data is expected to be overwritten by the contents of + * the DataFrame. + */ + Overwrite, + /** + * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, + * an exception is expected to be thrown. + */ + ErrorIfExists, + /** + * Ignore mode means that when saving a DataFrame to a data source, if data already exists, + * the save operation is expected to not save the contents of the DataFrame and to not + * change the existing data. + */ + Ignore +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 04e0d09947..ca8d552c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -17,19 +17,19 @@ package org.apache.spark.sql +import scala.collection.JavaConversions._ import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.sources.SaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -import scala.util.control.NonFatal - - private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { new DataFrameImpl(sqlContext, logicalPlan) @@ -574,8 +574,64 @@ trait DataFrame extends RDDApi[Row] { /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame. This will fail if the table already - * exists. + * Creates a table from the the contents of this DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * This will fail if the table already exists. 
+ * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + def saveAsTable(tableName: String): Unit = { + saveAsTable(tableName, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame, using the default data source + * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + def saveAsTable(tableName: String, mode: SaveMode): Unit = { + if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { + // If table already exists and the save mode is Append, + // we will just call insertInto to append the contents of this DataFrame. + insertInto(tableName, overwrite = false) + } else { + val dataSourceName = sqlContext.conf.defaultDataSourceName + saveAsTable(tableName, dataSourceName, mode) + } + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source and a set of options, + * using [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + def saveAsTable( + tableName: String, + source: String): Unit = { + saveAsTable(tableName, source, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. Instead you can write @@ -583,12 +639,17 @@ trait DataFrame extends RDDApi[Row] { * be the target of an `insertInto`. */ @Experimental - def saveAsTable(tableName: String): Unit + def saveAsTable( + tableName: String, + source: String, + mode: SaveMode): Unit = { + saveAsTable(tableName, source, mode, Map.empty[String, String]) + } /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame based on a given data source and - * a set of options. This will fail if the table already exists. + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. 
Instead you can write @@ -598,14 +659,17 @@ trait DataFrame extends RDDApi[Row] { @Experimental def saveAsTable( tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + saveAsTable(tableName, source, mode, options.toMap) + } /** * :: Experimental :: - * Creates a table from the the contents of this DataFrame based on a given data source and - * a set of options. This will fail if the table already exists. + * (Scala-specific) + * Creates a table from the the contents of this DataFrame based on a given data source, + * [[SaveMode]] specified by mode, and a set of options. * * Note that this currently only works with DataFrames that are created from a HiveContext as * there is no notion of a persisted catalog in a standard SQL context. Instead you can write @@ -615,22 +679,76 @@ trait DataFrame extends RDDApi[Row] { @Experimental def saveAsTable( tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit + source: String, + mode: SaveMode, + options: Map[String, String]): Unit + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path, + * using the default data source configured by spark.sql.sources.default and + * [[SaveMode.ErrorIfExists]] as the save mode. + */ + @Experimental + def save(path: String): Unit = { + save(path, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, + * using the default data source configured by spark.sql.sources.default. + */ + @Experimental + def save(path: String, mode: SaveMode): Unit = { + val dataSourceName = sqlContext.conf.defaultDataSourceName + save(path, dataSourceName, mode) + } + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source, + * using [[SaveMode.ErrorIfExists]] as the save mode. + */ + @Experimental + def save(path: String, source: String): Unit = { + save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source and + * [[SaveMode]] specified by mode. + */ @Experimental - def save(path: String): Unit + def save(path: String, source: String, mode: SaveMode): Unit = { + save(source, mode, Map("path" -> path)) + } + /** + * :: Experimental :: + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options. 
+ */ @Experimental def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + save(source, mode, options.toMap) + } + /** + * :: Experimental :: + * (Scala-specific) + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options + */ @Experimental def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit + source: String, + mode: SaveMode, + options: Map[String, String]): Unit /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 1ee16ad516..11f9334556 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -28,13 +28,14 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, ResolvedStar, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{NumericType, StructType} @@ -341,68 +342,34 @@ private[sql] class DataFrameImpl protected[sql]( override def saveAsParquetFile(path: String): Unit = { if (sqlContext.conf.parquetUseDataSourceApi) { - save("org.apache.spark.sql.parquet", "path" -> path) + save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) } else { sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd } } - override def saveAsTable(tableName: String): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - val cmd = - CreateTableUsingAsLogicalPlan( - tableName, - dataSourceName, - temporary = false, - Map.empty, - allowExisting = false, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd - } - override def saveAsTable( tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = { + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { val cmd = CreateTableUsingAsLogicalPlan( tableName, - dataSourceName, + source, temporary = false, - (option +: options).toMap, - allowExisting = false, + mode, + options, logicalPlan) sqlContext.executePlan(cmd).toRdd } - override def saveAsTable( - tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*) - } - - override def save(path: String): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - save(dataSourceName, "path" -> path) - } - - override def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = { - ResolvedDataSource(sqlContext, dataSourceName, (option +: 
options).toMap, this) - } - override def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - save(dataSourceName, opts.head, opts.tail:_*) + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { + ResolvedDataSource(sqlContext, source, mode, options, this) } override def insertInto(tableName: String, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index ce0557b881..494e49c131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedSt import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.sources.SaveMode import org.apache.spark.sql.types.StructType - private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column { def this(name: String) = this(name match { @@ -156,29 +156,16 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def saveAsParquetFile(path: String): Unit = err() - override def saveAsTable(tableName: String): Unit = err() - - override def saveAsTable( - tableName: String, - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = err() - override def saveAsTable( tableName: String, - dataSourceName: String, - options: java.util.Map[String, String]): Unit = err() - - override def save(path: String): Unit = err() - - override def save( - dataSourceName: String, - option: (String, String), - options: (String, String)*): Unit = err() + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = err() override def save( - dataSourceName: String, - options: java.util.Map[String, String]): Unit = err() + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = err() override def insertInto(tableName: String, overwrite: Boolean): Unit = err() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 180f5e765f..39f6c2f4bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -50,7 +50,7 @@ private[spark] object SQLConf { val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource" + val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" // Whether to perform eager analysis on a DataFrame. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 97e3777f93..801505bceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -401,27 +401,173 @@ class SQLContext(@transient val sparkContext: SparkContext) jsonRDD(json.rdd, samplingRatio); } + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. 
+ */ @Experimental def load(path: String): DataFrame = { val dataSourceName = conf.defaultDataSourceName - load(dataSourceName, ("path", path)) + load(path, dataSourceName) } + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, + * using the given data source. + */ @Experimental - def load( - dataSourceName: String, - option: (String, String), - options: (String, String)*): DataFrame = { - val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap) + def load(path: String, source: String): DataFrame = { + load(source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Returns the dataset specified by the given data source and a set of options as a DataFrame. + */ + @Experimental + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + load(source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Returns the dataset specified by the given data source and a set of options as a DataFrame. + */ + @Experimental + def load(source: String, options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, None, source, options) DataFrame(this, LogicalRelation(resolved.relation)) } + /** + * :: Experimental :: + * Returns the dataset specified by the given data source and a set of options as a DataFrame, + * using the given schema as the schema of the DataFrame. + */ @Experimental def load( - dataSourceName: String, + source: String, + schema: StructType, options: java.util.Map[String, String]): DataFrame = { - val opts = options.toSeq - load(dataSourceName, opts.head, opts.tail:_*) + load(source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Returns the dataset specified by the given data source and a set of options as a DataFrame, + * using the given schema as the schema of the DataFrame. + */ + @Experimental + def load( + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, Some(schema), source, options) + DataFrame(this, LogicalRelation(resolved.relation)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + */ + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. + */ + @Experimental + def createExternalTable( + tableName: String, + path: String, + source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. 
+ */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } + + /** + * :: Experimental :: + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index edf8a5be64..e915e0e6a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -309,7 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) => + case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( tableName, userSpecifiedSchema, provider, opts)) :: Nil @@ -318,24 +318,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) => + case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) => val logicalPlan = sqlContext.parseSql(query) val cmd = - CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan) + CreateTempTableUsingAsSelect(tableName, provider, mode, opts, logicalPlan) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. 
Use a HiveContext instead.") - case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting => - sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) => + case CreateTableUsingAsLogicalPlan(tableName, provider, true, mode, opts, query) => val cmd = - CreateTempTableUsingAsSelect(tableName, provider, opts, query) + CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsLogicalPlan if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") - case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting => - sys.error("allowExisting should be set to false when creating a temporary table.") case LogicalDescribeCommand(table, isExtended) => val resultPlan = self.sqlContext.executePlan(table).executedPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index c4e14c6c92..f828bcdd65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -29,6 +28,10 @@ import org.apache.spark.sql.types.StructType private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) + } + /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, @@ -52,15 +55,30 @@ private[sql] class DefaultSource override def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { - val path = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) { - sys.error(s"path $path already exists.") + val doSave = if (fs.exists(filesystemPath)) { + mode match { + case SaveMode.Append => + sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") + case SaveMode.Overwrite => + fs.delete(filesystemPath, true) + true + case SaveMode.ErrorIfExists => + sys.error(s"path $path already exists.") + case SaveMode.Ignore => false + } + } else { + true + } + if (doSave) { + // Only save data when the save mode is not ignore. 
+ data.toJSON.saveAsTextFile(path) } - data.toJSON.saveAsTextFile(path) createRelation(sqlContext, parameters, data.schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 04804f78f5..aef9c10fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -80,18 +80,45 @@ class DefaultSource override def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { val path = checkPath(parameters) - ParquetRelation.createEmpty( - path, - data.schema.toAttributes, - false, - sqlContext.sparkContext.hadoopConfiguration, - sqlContext) - - val relation = createRelation(sqlContext, parameters, data.schema) - relation.asInstanceOf[ParquetRelation2].insert(data, true) + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val doSave = if (fs.exists(filesystemPath)) { + mode match { + case SaveMode.Append => + sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") + case SaveMode.Overwrite => + fs.delete(filesystemPath, true) + true + case SaveMode.ErrorIfExists => + sys.error(s"path $path already exists.") + case SaveMode.Ignore => false + } + } else { + true + } + + val relation = if (doSave) { + // Only save data when the save mode is not ignore. + ParquetRelation.createEmpty( + path, + data.schema.toAttributes, + false, + sqlContext.sparkContext.hadoopConfiguration, + sqlContext) + + val createdRelation = createRelation(sqlContext, parameters, data.schema) + createdRelation.asInstanceOf[ParquetRelation2].insert(data, true) + + createdRelation + } else { + // If the save mode is Ignore, we will just create the relation based on existing data. + createRelation(sqlContext, parameters) + } + relation } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 9f64f76100..6487c14b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -119,11 +119,20 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { throw new DDLException( "a CREATE TABLE AS SELECT statement does not allow column definitions.") } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. 
+ val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(tableName, provider, temp.isDefined, + mode, options, - allowExisting.isDefined, query.get) } else { val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) @@ -133,7 +142,8 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { provider, temp.isDefined, options, - allowExisting.isDefined) + allowExisting.isDefined, + managedIfNoPath = false) } } ) @@ -264,6 +274,7 @@ object ResolvedDataSource { def apply( sqlContext: SQLContext, provider: String, + mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { val loader = Utils.getContextOrSparkClassLoader @@ -277,7 +288,7 @@ object ResolvedDataSource { val relation = clazz.newInstance match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, options, data) + dataSource.createRelation(sqlContext, mode, options, data) case _ => sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") } @@ -307,28 +318,40 @@ private[sql] case class DescribeCommand( new MetadataBuilder().putString("comment", "comment of the column").build())()) } +/** + * Used to represent the operation of create table using a data source. + * @param tableName + * @param userSpecifiedSchema + * @param provider + * @param temporary + * @param options + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + * @param managedIfNoPath + */ private[sql] case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, options: Map[String, String], - allowExisting: Boolean) extends Command + allowExisting: Boolean, + managedIfNoPath: Boolean) extends Command private[sql] case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: String) extends Command private[sql] case class CreateTableUsingAsLogicalPlan( tableName: String, provider: String, temporary: Boolean, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: LogicalPlan) extends Command private [sql] case class CreateTempTableUsing( @@ -348,12 +371,13 @@ private [sql] case class CreateTempTableUsing( private [sql] case class CreateTempTableUsingAsSelect( tableName: String, provider: String, + mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { def run(sqlContext: SQLContext) = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, options, df) + val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df) sqlContext.registerRDDAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) @@ -364,7 +388,7 @@ private [sql] case class CreateTempTableUsingAsSelect( /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 5eecc303ef..37fda7ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -79,8 +79,27 @@ trait SchemaRelationProvider { @DeveloperApi trait CreatableRelationProvider { + /** + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + * + * @param sqlContext + * @param mode + * @param parameters + * @param data + * @return + */ def createRelation( sqlContext: SQLContext, + mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation } diff --git a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java new file mode 100644 index 0000000000..852baf0e09 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.sources; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaSaveLoadSuite { + + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + String originalDefaultSource; + File path; + DataFrame df; + + private void checkAnswer(DataFrame actual, List<Row> expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + + List<String> jsonObjects = new ArrayList<String>(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD<String> rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @Test + public void saveAndLoad() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + } + + @Test + public void saveAndLoadWithSchema() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + List<StructField> fields = new ArrayList<>(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + + checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9ddd2ca5c..dfb6858957 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} +import scala.collection.JavaConversions._ + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation @@ -52,9 +54,51 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
*/ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(rdd, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param rdd the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -70,18 +114,20 @@ class QueryTest extends PlanTest { } val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: |${rdd.logicalPlan} |== Analyzed Plan == @@ -90,37 +136,21 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } + return None } - /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. 
- */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached + def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(rdd, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index b02389978b..29caed9337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -77,12 +77,10 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") - - val message = intercept[RuntimeException]{ + val message = intercept[DDLException]{ sql( s""" - |CREATE TEMPORARY TABLE jsonTable + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( | path '${path.toString}' @@ -91,10 +89,25 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) }.getMessage assert( - message.contains(s"path ${path.toString} already exists."), + message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") - // Explicitly delete it. + // Overwrite the temporary table. + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) + + dropTempTable("jsonTable") + // Explicitly delete the data. 
if (path.exists()) Utils.deleteRecursively(path) sql( @@ -104,12 +117,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { |OPTIONS ( | path '${path.toString}' |) AS - |SELECT a * 4 FROM jt + |SELECT b FROM jt """.stripMargin) checkAnswer( sql("SELECT * FROM jsonTable"), - sql("SELECT a * 4 FROM jt").collect()) + sql("SELECT b FROM jt").collect()) dropTempTable("jsonTable") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index fe2f76cc39..a510045671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.DataFrame -import org.apache.spark.util.Utils - import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{SQLConf, DataFrame} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { @@ -38,42 +38,60 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll(): Unit = { originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") path = util.getTempFilePath("datasource").getCanonicalFile val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) df = jsonRDD(rdd) + df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } after { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) if (path.exists()) Utils.deleteRecursively(path) } def checkLoad(): Unit = { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") checkAnswer(load(path.toString), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect()) + + // Test if we can pick up the data source name passed in load. 
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) + checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + val schema = StructType(StructField("b", StringType, true) :: Nil) + checkAnswer( + load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + sql("SELECT b FROM jsonTable").collect()) } - test("save with overwrite and load") { + test("save with path and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") df.save(path.toString) - checkLoad + checkLoad() + } + + test("save with path and datasource, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() } test("save with data source and options, and load") { - df.save("org.apache.spark.sql.json", ("path", path.toString)) - checkLoad + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + checkLoad() } test("save and save again") { - df.save(path.toString) + df.save(path.toString, "org.apache.spark.sql.json") - val message = intercept[RuntimeException] { - df.save(path.toString) + var message = intercept[RuntimeException] { + df.save(path.toString, "org.apache.spark.sql.json") }.getMessage assert( @@ -82,7 +100,18 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString) - checkLoad + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() + + df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + checkLoad() + + message = intercept[RuntimeException] { + df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + }.getMessage + + assert( + message.contains("Append mode is not supported"), + "We should complain that 'Append mode is not supported' for JSON source.") } } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2c00659496..7ae6ed6f84 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,18 +79,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } - /** - * Creates a table using the schema of the given class. - * - * @param tableName The name of the table to create. - * @param allowExisting When false, an exception will be thrown if the table already exists. - * @tparam A A case class that is used to describe the schema of the table to be created. - */ - @Deprecated - def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) - } - /** * Invalidate and refresh all the cached the metadata of the given table. 
For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -107,70 +95,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.invalidateTable("default", tableName) } - @Experimental - def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = { - val dataSourceName = conf.defaultDataSourceName - createTable(tableName, dataSourceName, allowExisting, ("path", path)) - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = None, - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - option: (String, String), - options: (String, String)*): Unit = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - dataSourceName, - temporary = false, - (option +: options).toMap, - allowExisting) - executePlan(cmd).toRdd - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*) - } - - @Experimental - def createTable( - tableName: String, - dataSourceName: String, - schema: StructType, - allowExisting: Boolean, - options: java.util.Map[String, String]): Unit = { - val opts = options.toSeq - createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*) - } - /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 95abc363ae..cb138be90e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -216,20 +216,21 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) => + case CreateTableUsing( + tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => ExecutedCommand( CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil + tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) => + case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) => val logicalPlan = hiveContext.parseSql(query) val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan) + CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, logicalPlan) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) => + case CreateTableUsingAsLogicalPlan(tableName, provider, false, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query) + CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 95dcaccefd..f6bea1c6a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.sources.ResolvedDataSource +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -105,7 +107,8 @@ case class CreateMetastoreDataSource( userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], - allowExisting: Boolean) extends RunnableCommand { + allowExisting: Boolean, + managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] @@ -120,7 +123,7 @@ case class CreateMetastoreDataSource( var isExternal = true val optionsWithPath = - if (!options.contains("path")) { + if (!options.contains("path") && managedIfNoPath) { isExternal = false options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) } else { @@ -141,22 +144,13 @@ case class CreateMetastoreDataSource( case class CreateMetastoreDataSourceAsSelect( tableName: String, provider: String, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: LogicalPlan) extends 
RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - - if (hiveContext.catalog.tableExists(tableName :: Nil)) { - if (allowExisting) { - return Seq.empty[Row] - } else { - sys.error(s"Table $tableName already exists.") - } - } - - val df = DataFrame(hiveContext, query) + var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { @@ -166,15 +160,82 @@ case class CreateMetastoreDataSourceAsSelect( options } - // Create the relation based on the data of df. - ResolvedDataSource(sqlContext, provider, optionsWithPath, df) + if (sqlContext.catalog.tableExists(Seq(tableName))) { + // Check if we need to throw an exception or just return. + mode match { + case SaveMode.ErrorIfExists => + sys.error(s"Table $tableName already exists. " + + s"If you want to append to it, please set mode to SaveMode.Append. " + + s"Or, if you want to overwrite it, please set mode to SaveMode.Overwrite.") + case SaveMode.Ignore => + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty[Row] + case SaveMode.Append => + // Check if the specified data source matches the data source of the existing table. + val resolved = + ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) + val createdRelation = LogicalRelation(resolved.relation) + EliminateAnalysisOperators(sqlContext.table(tableName).logicalPlan) match { + case l @ LogicalRelation(i: InsertableRelation) => + if (l.schema != createdRelation.schema) { + val errorDescription = + s"Cannot append to table $tableName because the schema of this " + + s"DataFrame does not match the schema of table $tableName." + val errorMessage = + s""" + |$errorDescription + |== Schemas == + |${sideBySide( + s"== Expected Schema ==" +: + l.schema.treeString.split("\\\n"), + s"== Actual Schema ==" +: + createdRelation.schema.treeString.split("\\\n")).mkString("\n")} + """.stripMargin + sys.error(errorMessage) + } else if (i != createdRelation.relation) { + val errorDescription = + s"Cannot append to table $tableName because the resolved relation does not " + + s"match the existing relation of $tableName. " + + s"You can use insertInto($tableName, false) to append this DataFrame to the " + + s"table $tableName using its data source and options." + val errorMessage = + s""" + |$errorDescription + |== Relations == + |${sideBySide( + s"== Expected Relation ==" :: + l.toString :: Nil, + s"== Actual Relation ==" :: + createdRelation.toString :: Nil).mkString("\n")} + """.stripMargin + sys.error(errorMessage) + } + case o => + sys.error(s"Saving data in ${o.toString} is not supported.") + } + case SaveMode.Overwrite => + hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + // Need to create the table again. + createMetastoreTable = true + } + } else { + // The table does not exist. We need to create it in the metastore. + createMetastoreTable = true + } - hiveContext.catalog.createDataSourceTable( - tableName, - None, - provider, - optionsWithPath, - isExternal) + val df = DataFrame(hiveContext, query) + + // Create the relation based on the data of df.
+ ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df) + + if (createMetastoreTable) { + hiveContext.catalog.createDataSourceTable( + tableName, + Some(df.schema), + provider, + optionsWithPath, + isExternal) + } Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala similarity index 99% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7c1d1133c3..840fbc1972 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import scala.collection.mutable -import scala.language.implicitConversions - import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table @@ -30,16 +27,18 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +import scala.collection.mutable +import scala.language.implicitConversions /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -224,11 +223,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } }), TestTable("src_thrift", () => { - import org.apache.thrift.protocol.TBinaryProtocol - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.mapred.SequenceFileInputFormat - import org.apache.hadoop.mapred.SequenceFileOutputFormat + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} + import org.apache.thrift.protocol.TBinaryProtocol val srcThrift = new Table("default", "src_thrift") srcThrift.setFields(Nil) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java new file mode 100644 index 0000000000..9744a2aa3f --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.sources.SaveMode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.QueryTest$; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaMetastoreDataSourcesSuite { + private transient JavaSparkContext sc; + private transient HiveContext sqlContext; + + String originalDefaultSource; + File path; + Path hiveManagedPath; + FileSystem fs; + DataFrame df; + + private void checkAnswer(DataFrame actual, List<Row> expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestHive$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); + if (fs.exists(hiveManagedPath)){ + fs.delete(hiveManagedPath, true); + } + + List<String> jsonObjects = new ArrayList<String>(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD<String> rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. 
+ sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } + + @Test + public void saveExternalTableAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + df.collectAsList()); + } + + @Test + public void saveExternalTableWithSchemaAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + List<StructField> fields = new ArrayList<>(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); + + checkAnswer( + loadedDF, + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + } + + @Test + public void saveTableAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index ba39129388..0270e63557 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql -import org.scalatest.FunSuite +import scala.collection.JavaConversions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -55,9 +53,36 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(rdd, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. 
+ * If there was an exception during the execution or the contents of the DataFrame do not + * match the expected result, an error message will be returned. Otherwise, [[None]] will + * be returned. + * @param rdd the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -73,18 +98,20 @@ class QueryTest extends PlanTest { } val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: |${rdd.logicalPlan} |== Analyzed Plan == @@ -93,22 +120,21 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) + return None } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(rdd, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 869d01eb39..43da7519ac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,7 +19,11 @@ package org.apache.spark.sql.hive import java.io.File +import org.scalatest.BeforeAndAfter + import com.google.common.io.Files + +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ @@ -29,15 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) -class InsertIntoHiveTableSuite extends QueryTest { +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerTempTable("testData") + + before { + // Since we are doing tests for DDL statements, + // it is better
to reset before every test. + TestHive.reset() + // Register the testData, which will be used in every test. + testData.registerTempTable("testData") + } test("insertInto() HiveTable") { - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. testData.insertInto("createAndInsertTest") @@ -68,16 +79,18 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("Double create fails when allowExisting = false") { - createTable[TestData]("doubleCreateAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[org.apache.hadoop.hive.ql.metadata.HiveException] { - createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false) - } + intercept[QueryExecutionException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + } } test("Double create does not fail when allowExisting = true") { - createTable[TestData]("createAndInsertTest") - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { @@ -98,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("SPARK-4203:random partition directory order") { - createTable[TestData]("tmp_table") + sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Files.createTempDir() sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 9ce058909f..f94aabd29a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.sql.sources.SaveMode import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ @@ -41,11 +43,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { override def afterEach(): Unit = { reset() - if (ctasPath.exists()) Utils.deleteRecursively(ctasPath) + if (tempPath.exists()) Utils.deleteRecursively(tempPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile + var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile test ("persistent JSON table") { sql( @@ -270,7 +272,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -297,7 +299,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING
org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -309,7 +311,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -325,7 +327,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE IF NOT EXISTS ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT a FROM jsonTable """.stripMargin) @@ -400,38 +402,122 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("DROP TABLE jsonTable").collect().foreach(println) } - test("save and load table") { + test("save table") { val originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) val df = jsonRDD(rdd) + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). df.saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), df.collect()) - createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false) + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.saveAsTable("savedJsonTable", SaveMode.Append) + } + + // We can overwrite it. + df.saveAsTable("savedJsonTable", SaveMode.Overwrite) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) + assert(df.schema === table("savedJsonTable").schema) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + + // Create an external table by specifying the path. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Data should not be deleted after we drop the table. 
+ sql("DROP TABLE savedJsonTable") + checkAnswer( + jsonFile(tempPath.toString), + df.collect()) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("create external table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + createExternalTable("createdJsonTable", tempPath.toString) assert(table("createdJsonTable").schema === df.schema) checkAnswer( sql("SELECT * FROM createdJsonTable"), df.collect()) - val message = intercept[RuntimeException] { - createTable("createdJsonTable", filePath.toString, false) + var message = intercept[RuntimeException] { + createExternalTable("createdJsonTable", filePath.toString) }.getMessage assert(message.contains("Table createdJsonTable already exists."), "We should complain that ctasJsonTable already exists") - createTable("createdJsonTable", filePath.toString, true) - // createdJsonTable should be not changed. - assert(table("createdJsonTable").schema === df.schema) + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") checkAnswer( - sql("SELECT * FROM createdJsonTable"), + jsonFile(tempPath.toString), df.collect()) - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + // Try to specify the schema. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable").collect()) + + sql("DROP TABLE createdJsonTable") + + message = intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage + assert( + message.contains("Option 'path' not specified"), + "We should complain that path is not specified.") + + sql("DROP TABLE savedJsonTable") + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } } -- GitLab