Commit 04901dd0 authored by Takeshi Yamamuro, committed by Xiao Li

[SPARK-20431][SQL] Specify a schema by using a DDL-formatted string

## What changes were proposed in this pull request?
This PR adds support for specifying a schema as a DDL-formatted string in `DataFrameReader.schema`.
This lets users define a schema easily, without importing `o.a.spark.sql.types._`.
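For example, a minimal usage sketch (the file path and column names here are illustrative, not from the patch):

```scala
// Before this change: building a schema requires the types API.
import org.apache.spark.sql.types._
val schema = new StructType().add("col0", IntegerType).add("col1", DoubleType)
val df1 = spark.read.schema(schema).csv("/path/to/data")

// After this change: the same schema as a DDL-formatted string, no imports needed.
val df2 = spark.read.schema("col0 INT, col1 DOUBLE").csv("/path/to/data")
```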

## How was this patch tested?
Added tests in `DataFrameReaderWriterSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17719 from maropu/SPARK-20431.
@@ -96,14 +96,18 @@ class DataFrameReader(OptionUtils):
         By specifying the schema here, the underlying data source can skip the schema
         inference step, and thus speed up data loading.
 
-        :param schema: a :class:`pyspark.sql.types.StructType` object
+        :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
+                       (For example ``col0 INT, col1 DOUBLE``).
         """
         from pyspark.sql import SparkSession
-        if not isinstance(schema, StructType):
-            raise TypeError("schema should be StructType")
         spark = SparkSession.builder.getOrCreate()
-        jschema = spark._jsparkSession.parseDataType(schema.json())
-        self._jreader = self._jreader.schema(jschema)
+        if isinstance(schema, StructType):
+            jschema = spark._jsparkSession.parseDataType(schema.json())
+            self._jreader = self._jreader.schema(jschema)
+        elif isinstance(schema, basestring):
+            self._jreader = self._jreader.schema(schema)
+        else:
+            raise TypeError("schema should be StructType or string")
         return self
 
     @since(1.5)
@@ -137,7 +141,8 @@ class DataFrameReader(OptionUtils):
 
         :param path: optional string or a list of string for file-system backed data sources.
         :param format: optional string for format of the data source. Default to 'parquet'.
-        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
+                       or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param options: all other string options
 
         >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
@@ -181,7 +186,8 @@ class DataFrameReader(OptionUtils):
 
         :param path: string represents path to the JSON dataset, or a list of paths,
                      or RDD of Strings storing JSON objects.
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or
+                       a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param primitivesAsString: infers all primitive values as a string type. If None is set,
                                    it uses the default value, ``false``.
         :param prefersDecimal: infers all floating-point values as a decimal type. If the values
@@ -324,7 +330,8 @@ class DataFrameReader(OptionUtils):
         ``inferSchema`` option or specify the schema explicitly using ``schema``.
 
         :param path: string, or list of strings, for input path(s).
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+                       or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param sep: sets the single character as a separator for each field and value.
                     If None is set, it uses the default value, ``,``.
         :param encoding: decodes the CSV files by the given encoding type. If None is set,
@@ -67,6 +67,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     this
   }
 
+  /**
+   * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
+   * infer the input schema automatically from data. By specifying the schema here, the underlying
+   * data source can skip the schema inference step, and thus speed up data loading.
+   *
+   * @since 2.3.0
+   */
+  def schema(schemaString: String): DataFrameReader = {
+    this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
+    this
+  }
+
   /**
    * Adds an input option for the underlying data source.
    *
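The new overload delegates to `StructType.fromDDL`, which parses the string with Spark's SQL parser, so complex types such as arrays and nested structs should also be accepted. A small sketch of the expected behavior (column names and paths are illustrative):

```scala
import org.apache.spark.sql.types.StructType

// Parse a DDL-formatted string into a StructType, including a complex type.
val schema = StructType.fromDDL("id BIGINT, name STRING, scores ARRAY<DOUBLE>")

// The same string can now be passed directly to DataFrameReader.schema.
val df = spark.read.schema("id BIGINT, name STRING, scores ARRAY<DOUBLE>").json("/path/to/json")
```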
@@ -128,6 +128,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
   import testImplicits._
 
   private val userSchema = new StructType().add("s", StringType)
+  private val userSchemaString = "s STRING"
   private val textSchema = new StructType().add("value", StringType)
   private val data = Seq("1", "2", "3")
   private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
@@ -678,4 +679,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     assert(e.contains("User specified schema not supported with `table`"))
   }
 
+  test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
+    spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+    testRead(spark.read.schema(userSchemaString).text(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchemaString).text(dir), data, userSchema)
+    testRead(spark.read.schema(userSchemaString).text(dir, dir), data ++ data, userSchema)
+    testRead(spark.read.schema(userSchemaString).text(Seq(dir, dir): _*), data ++ data, userSchema)
+  }
 }