diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 2ae854d04f5641ce58e582aefa02a00df3b308d2..841503b260c35c8742aed6d2896a4035a6149c6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -119,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def load(): DataFrame = {
-    val dataSource =
-      DataSource(
-        sparkSession,
-        userSpecifiedSchema = userSpecifiedSchema,
-        className = source,
-        options = extraOptions.toMap)
-    Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
+    load(Seq.empty: _*) // force invocation of `load(...varargs...)`
   }
 
   /**
@@ -135,7 +129,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def load(path: String): DataFrame = {
-    option("path", path).load()
+    load(Seq(path): _*) // force invocation of `load(...varargs...)`
   }
 
   /**
@@ -146,18 +140,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    */
   @scala.annotation.varargs
   def load(paths: String*): DataFrame = {
-    if (paths.isEmpty) {
-      sparkSession.emptyDataFrame
-    } else {
-      sparkSession.baseRelationToDataFrame(
-        DataSource.apply(
-          sparkSession,
-          paths = paths,
-          userSpecifiedSchema = userSpecifiedSchema,
-          className = source,
-          options = extraOptions.toMap).resolveRelation())
-    }
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        userSpecifiedSchema = userSpecifiedSchema,
+        className = source,
+        options = extraOptions.toMap).resolveRelation())
   }
 
+
   /**
    * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
    * url named table and connection properties.
    * @since 1.4.0
    */
@@ -245,6 +236,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     sparkSession.baseRelationToDataFrame(relation)
   }
 
+  /**
+   * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
+   * See the documentation on the overloaded `json()` method with varargs for more details.
+   *
+   * @since 1.4.0
+   */
+  def json(path: String): DataFrame = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    json(Seq(path): _*)
+  }
+
   /**
    * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
    *
@@ -252,6 +254,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * schema in advance, use the version that specifies the schema to avoid the extra scan.
    *
    * You can set the following JSON-specific options to deal with non-standard JSON files:
+   * <ul>
    * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
    * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
    * type. If the values do not fit in decimal, then it infers them as doubles.</li>
@@ -266,17 +269,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
    * during parsing.</li>
    * <ul>
-   * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
-   * malformed string into a new field configured by `columnNameOfCorruptRecord`. When
+   * <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+   * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
    * a schema is set by user, it sets `null` for extra fields.</li>
-   * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
-   * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+   * <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+   * <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
    * </ul>
    * <li>`columnNameOfCorruptRecord` (default is the value specified in
    * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
    * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
-   *
-   * @since 1.6.0
+   * </ul>
+   * @since 2.0.0
    */
   @scala.annotation.varargs
   def json(paths: String*): DataFrame = format("json").load(paths : _*)
@@ -326,6 +329,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         parsedOptions))(sparkSession))
   }
 
+  /**
+   * Loads a CSV file and returns the result as a [[DataFrame]]. See the documentation on the
+   * other overloaded `csv()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def csv(path: String): DataFrame = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    csv(Seq(path): _*)
+  }
+
   /**
    * Loads a CSV file and returns the result as a [[DataFrame]].
    *
@@ -334,6 +348,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * specify the schema explicitly using [[schema]].
    *
    * You can set the following CSV-specific options to deal with CSV files:
+   * <ul>
    * <li>`sep` (default `,`): sets the single character as a separator for each
    * field and value.</li>
    * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
@@ -370,26 +385,37 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
    * during parsing.</li>
    * <ul>
-   * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
+   * <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
    * a schema is set by user, it sets `null` for extra fields.</li>
-   * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
-   * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+   * <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+   * <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
+   * </ul>
    * </ul>
-   *
    * @since 2.0.0
    */
   @scala.annotation.varargs
   def csv(paths: String*): DataFrame = format("csv").load(paths : _*)
 
   /**
-   * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
-   * [[DataFrame]] if no paths are passed in.
+   * Loads a Parquet file, returning the result as a [[DataFrame]]. See the documentation
+   * on the other overloaded `parquet()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def parquet(path: String): DataFrame = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    parquet(Seq(path): _*)
+  }
+
+  /**
+   * Loads a Parquet file, returning the result as a [[DataFrame]].
    *
    * You can set the following Parquet-specific option(s) for reading Parquet files:
+   * <ul>
    * <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
    * whether we should merge schemas collected from all Parquet part-files. This will override
    * `spark.sql.parquet.mergeSchema`.</li>
-   *
+   * </ul>
    * @since 1.4.0
    */
   @scala.annotation.varargs
@@ -404,7 +430,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.5.0
    * @note Currently, this method can only be used after enabling Hive support.
    */
-  def orc(path: String): DataFrame = format("orc").load(path)
+  def orc(path: String): DataFrame = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    orc(Seq(path): _*)
+  }
+
+  /**
+   * Loads an ORC file and returns the result as a [[DataFrame]].
+   *
+   * @param paths input paths
+   * @since 2.0.0
+   * @note Currently, this method can only be used after enabling Hive support.
+   */
+  @scala.annotation.varargs
+  def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
 
   /**
    * Returns the specified table as a [[DataFrame]].
@@ -417,6 +456,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }
 
+  /**
+   * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
+   * "value", and followed by partitioned columns if there are any. See the documentation on
+   * the other overloaded `text()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def text(path: String): DataFrame = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    text(Seq(path): _*)
+  }
+
   /**
    * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
    * "value", and followed by partitioned columns if there are any.
@@ -430,12 +481,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * spark.read().text("/path/to/spark/README.md")
    * }}}
    *
-   * @param paths input path
+   * @param paths input paths
    * @since 1.6.0
    */
   @scala.annotation.varargs
   def text(paths: String*): DataFrame = format("text").load(paths : _*)
 
+  /**
+   * Loads text files and returns a [[Dataset]] of String. See the documentation on the
+   * other overloaded `textFile()` method for more details.
+   * @since 2.0.0
+   */
+  def textFile(path: String): Dataset[String] = {
+    // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+    textFile(Seq(path): _*)
+  }
+
   /**
    * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
    * contains a single string column named "value".
@@ -457,6 +518,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    */
   @scala.annotation.varargs
   def textFile(paths: String*): Dataset[String] = {
+    if (userSpecifiedSchema.nonEmpty) {
+      throw new AnalysisException("User specified schema not supported with `textFile`")
+    }
     text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder)
   }
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..7babf7573c07579a277b2561fdb240f14f8b014b
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
@@ -0,0 +1,158 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package test.org.apache.spark.sql;
+
+import java.io.File;
+import java.util.HashMap;
+
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.test.TestSparkSession;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaDataFrameReaderWriterSuite {
+  private SparkSession spark = new TestSparkSession();
+  private StructType schema = new StructType().add("s", "string");
+  private transient String input;
+  private transient String output;
+
+  @Before
+  public void setUp() {
+    input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString();
+    File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "output");
+    f.delete();
+    output = f.toString();
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+
+  @Test
+  public void testFormatAPI() {
+    spark
+        .read()
+        .format("org.apache.spark.sql.test")
+        .load()
+        .write()
+        .format("org.apache.spark.sql.test")
+        .save();
+  }
+
+  @Test
+  public void testOptionsAPI() {
+    HashMap<String, String> map = new HashMap<String, String>();
+    map.put("e", "1");
+    spark
+        .read()
+        .option("a", "1")
+        .option("b", 1)
+        .option("c", 1.0)
+        .option("d", true)
+        .options(map)
+        .text()
+        .write()
+        .option("a", "1")
+        .option("b", 1)
+        .option("c", 1.0)
+        .option("d", true)
+        .options(map)
+        .format("org.apache.spark.sql.test")
+        .save();
+  }
+
+  @Test
+  public void testSaveModeAPI() {
+    spark
+        .range(10)
+        .write()
+        .format("org.apache.spark.sql.test")
+        .mode(SaveMode.ErrorIfExists)
+        .save();
+  }
+
+  @Test
+  public void testLoadAPI() {
+    spark.read().format("org.apache.spark.sql.test").load();
+    spark.read().format("org.apache.spark.sql.test").load(input);
+    spark.read().format("org.apache.spark.sql.test").load(input, input, input);
+    spark.read().format("org.apache.spark.sql.test").load(new String[]{input, input});
+  }
+
+  @Test
+  public void testTextAPI() {
+    spark.read().text();
+    spark.read().text(input);
+    spark.read().text(input, input, input);
+    spark.read().text(new String[]{input, input})
+        .write().text(output);
+  }
+
+  @Test
+  public void testTextFileAPI() {
+    spark.read().textFile();
+    spark.read().textFile(input);
+    spark.read().textFile(input, input, input);
+    spark.read().textFile(new String[]{input, input});
+  }
+
+  @Test
+  public void testCsvAPI() {
+    spark.read().schema(schema).csv();
+    spark.read().schema(schema).csv(input);
+    spark.read().schema(schema).csv(input, input, input);
+    spark.read().schema(schema).csv(new String[]{input, input})
+        .write().csv(output);
+  }
+
+  @Test
+  public void testJsonAPI() {
+    spark.read().schema(schema).json();
+    spark.read().schema(schema).json(input);
+    spark.read().schema(schema).json(input, input, input);
+    spark.read().schema(schema).json(new String[]{input, input})
+        .write().json(output);
+  }
+
+  @Test
+  public void testParquetAPI() {
+    spark.read().schema(schema).parquet();
+    spark.read().schema(schema).parquet(input);
+    spark.read().schema(schema).parquet(input, input, input);
+    spark.read().schema(schema).parquet(new String[] { input, input })
+        .write().parquet(output);
+  }
+
+  /**
+   * This only tests whether the API compiles, but does not run it as orc()
+   * cannot be run without Hive classes.
+   */
+  public void testOrcAPI() {
+    spark.read().schema(schema).orc();
+    spark.read().schema(schema).orc(input);
+    spark.read().schema(schema).orc(input, input, input);
+    spark.read().schema(schema).orc(new String[]{input, input})
+        .write().orc(output);
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 98e57b38044f25c032fee9e3453936d335873282..3fa3864bc969049000161e8a24786c8c16e6a49c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.test
 
+import java.io.File
+
+import org.scalatest.BeforeAndAfter
+
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -79,10 +83,19 @@ class DefaultSource
   }
 }
 
-class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
+class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
+
 
-  private def newMetadataDir =
-    Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+  private val userSchema = new StructType().add("s", StringType)
+  private val textSchema = new StructType().add("value", StringType)
+  private val data = Seq("1", "2", "3")
+  private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
+  private implicit var enc: Encoder[String] = _
+
+  before {
+    enc = spark.implicits.newStringEncoder
+    Utils.deleteRecursively(new File(dir))
+  }
 
   test("writeStream cannot be called on non-streaming datasets") {
     val e = intercept[AnalysisException] {
@@ -157,24 +170,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
     assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
   }
 
-  test("paths") {
-    val df = spark.read
-      .format("org.apache.spark.sql.test")
-      .option("checkpointLocation", newMetadataDir)
-      .load("/test")
-
-    assert(LastOptions.parameters("path") == "/test")
-
-    LastOptions.clear()
-
-    df.write
-      .format("org.apache.spark.sql.test")
-      .option("checkpointLocation", newMetadataDir)
-      .save("/test")
-
-    assert(LastOptions.parameters("path") == "/test")
-  }
-
   test("test different data types for options") {
     val df = spark.read
       .format("org.apache.spark.sql.test")
@@ -193,7 +188,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
       .option("intOpt", 56)
       .option("boolOpt", false)
       .option("doubleOpt", 6.7)
-      .option("checkpointLocation", newMetadataDir)
       .save("/test")
 
     assert(LastOptions.parameters("intOpt") == "56")
@@ -228,4 +222,152 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("load API") {
+    spark.read.format("org.apache.spark.sql.test").load()
+    spark.read.format("org.apache.spark.sql.test").load(dir)
+    spark.read.format("org.apache.spark.sql.test").load(dir, dir, dir)
+    spark.read.format("org.apache.spark.sql.test").load(Seq(dir, dir): _*)
+    Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
+  }
+
+  test("text - API and behavior regarding schema") {
+    // Writer
+    spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+    testRead(spark.read.text(dir), data, textSchema)
+
+    // Reader, without user specified schema
+    testRead(spark.read.text(), Seq.empty, textSchema)
+    testRead(spark.read.text(dir, dir, dir), data ++ data ++ data, textSchema)
+    testRead(spark.read.text(Seq(dir, dir): _*), data ++ data, textSchema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+    // Reader, with user specified schema, should just apply user schema on the file data
+    testRead(spark.read.schema(userSchema).text(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).text(dir), data, userSchema)
+    testRead(spark.read.schema(userSchema).text(dir, dir), data ++ data, userSchema)
+    testRead(spark.read.schema(userSchema).text(Seq(dir, dir): _*), data ++ data, userSchema)
+  }
+
+  test("textFile - API and behavior regarding schema") {
+    spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+
+    // Reader, without user specified schema
+    testRead(spark.read.textFile().toDF(), Seq.empty, textSchema)
+    testRead(spark.read.textFile(dir).toDF(), data, textSchema)
+    testRead(spark.read.textFile(dir, dir).toDF(), data ++ data, textSchema)
+    testRead(spark.read.textFile(Seq(dir, dir): _*).toDF(), data ++ data, textSchema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+    // Reader, with user specified schema, should fail as textFile does not support user schemas
+    val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() }
+    assert(e.getMessage.toLowerCase.contains("user specified schema not supported"))
+    intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) }
+    intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) }
+    intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) }
+  }
+
+  test("csv - API and behavior regarding schema") {
+    // Writer
+    spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).csv(dir)
+    val df = spark.read.csv(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[IllegalArgumentException] {
+      testRead(spark.read.csv(), Seq.empty, schema)
+    }
+    testRead(spark.read.csv(dir), data, schema)
+    testRead(spark.read.csv(dir, dir), data ++ data, schema)
+    testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.csv).get, data, schema)
+
+    // Reader, with user specified schema, should just apply user schema on the file data
+    testRead(spark.read.schema(userSchema).csv(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).csv(dir), data, userSchema)
+    testRead(spark.read.schema(userSchema).csv(dir, dir), data ++ data, userSchema)
+    testRead(spark.read.schema(userSchema).csv(Seq(dir, dir): _*), data ++ data, userSchema)
+  }
+
+  test("json - API and behavior regarding schema") {
+    // Writer
+    spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).json(dir)
+    val df = spark.read.json(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[AnalysisException] {
+      testRead(spark.read.json(), Seq.empty, schema)
+    }
+    testRead(spark.read.json(dir), data, schema)
+    testRead(spark.read.json(dir, dir), data ++ data, schema)
+    testRead(spark.read.json(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.json).get, data, schema)
+
+    // Reader, with user specified schema, data should be nulls as the schema in the file
+    // differs from the user schema
+    val expData = Seq[String](null, null, null)
+    testRead(spark.read.schema(userSchema).json(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).json(dir), expData, userSchema)
+    testRead(spark.read.schema(userSchema).json(dir, dir), expData ++ expData, userSchema)
+    testRead(spark.read.schema(userSchema).json(Seq(dir, dir): _*), expData ++ expData, userSchema)
+  }
+
+  test("parquet - API and behavior regarding schema") {
+    // Writer
+    spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).parquet(dir)
+    val df = spark.read.parquet(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[AnalysisException] {
+      testRead(spark.read.parquet(), Seq.empty, schema)
+    }
+    testRead(spark.read.parquet(dir), data, schema)
+    testRead(spark.read.parquet(dir, dir), data ++ data, schema)
+    testRead(spark.read.parquet(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.parquet).get, data, schema)
+
+    // Reader, with user specified schema, data should be nulls as the schema in the file
+    // differs from the user schema
+    val expData = Seq[String](null, null, null)
+    testRead(spark.read.schema(userSchema).parquet(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).parquet(dir), expData, userSchema)
+    testRead(spark.read.schema(userSchema).parquet(dir, dir), expData ++ expData, userSchema)
+    testRead(
+      spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ expData, userSchema)
+  }
+
+  /**
+   * This only tests whether the API compiles, but does not run it as orc()
+   * cannot be run without Hive classes.
+   */
+  ignore("orc - API") {
+    // Reader, with user specified schema
+    // Refer to csv-specific test suites for behavior without user specified schema
+    spark.read.schema(userSchema).orc()
+    spark.read.schema(userSchema).orc(dir)
+    spark.read.schema(userSchema).orc(dir, dir, dir)
+    spark.read.schema(userSchema).orc(Seq(dir, dir): _*)
+    Option(dir).map(spark.read.schema(userSchema).orc)
+
+    // Writer
+    spark.range(10).write.orc(dir)
+  }
+
+  private def testRead(
+      df: => DataFrame,
+      expectedResult: Seq[String],
+      expectedSchema: StructType): Unit = {
+    checkAnswer(df, spark.createDataset(expectedResult).toDF())
+    assert(df.schema === expectedSchema)
+  }
 }
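
Usage sketch (illustrative, not part of the patch): assuming a local SparkSession, and with the object name ReaderApiSketch and the input paths as hypothetical placeholders, the snippet below shows how the reader API behaves once this change is applied. Explicit single-path calls resolve to the new non-varargs overloads (SPARK-16009), the varargs overloads accept any number of paths, and textFile rejects a user-specified schema.

  import org.apache.spark.sql.{Dataset, SparkSession}

  object ReaderApiSketch {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("reader-api-sketch")
        .getOrCreate()

      // Single-path call: resolves to the new json(path: String) overload, which simply
      // forwards to the varargs overload (SPARK-16009).
      val single = spark.read.json("/tmp/events.json")

      // Varargs call: accepts any number of paths; with this patch an empty path list no
      // longer short-circuits to an empty DataFrame but goes through source resolution.
      val combined = spark.read.parquet("/tmp/part1", "/tmp/part2")

      // textFile always yields Dataset[String]; a user-specified schema is now rejected
      // with an AnalysisException by the check added in this patch.
      val lines: Dataset[String] = spark.read.textFile("/tmp/events.json")

      spark.stop()
    }
  }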