diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index aef71f9ca700104168d2b237e7ec2f682e13b801..7279173df6e4f2990d79ef6b249dc0893f8d66d6 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -98,6 +98,8 @@ class DataFrameReader(OptionUtils):
 
         :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
                        (For example ``col0 INT, col1 DOUBLE``).
+
+        >>> s = spark.read.schema("col0 INT, col1 DOUBLE")
         """
         from pyspark.sql import SparkSession
         spark = SparkSession.builder.getOrCreate()
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 58aa2468e006d0241609d88c42430bbba5d6f104..5bbd70cf0a789d04e9221d7f4b8cb0fabb5048bf 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -319,16 +319,21 @@ class DataStreamReader(OptionUtils):
 
         .. note:: Evolving.
 
-        :param schema: a :class:`pyspark.sql.types.StructType` object
+        :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string
+                       (For example ``col0 INT, col1 DOUBLE``).
 
         >>> s = spark.readStream.schema(sdf_schema)
+        >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
         """
         from pyspark.sql import SparkSession
-        if not isinstance(schema, StructType):
-            raise TypeError("schema should be StructType")
         spark = SparkSession.builder.getOrCreate()
-        jschema = spark._jsparkSession.parseDataType(schema.json())
-        self._jreader = self._jreader.schema(jschema)
+        if isinstance(schema, StructType):
+            jschema = spark._jsparkSession.parseDataType(schema.json())
+            self._jreader = self._jreader.schema(jschema)
+        elif isinstance(schema, basestring):
+            self._jreader = self._jreader.schema(schema)
+        else:
+            raise TypeError("schema should be StructType or string")
         return self
 
     @since(2.0)
@@ -372,7 +377,8 @@ class DataStreamReader(OptionUtils):
 
         :param path: optional string for file-system backed data sources.
         :param format: optional string for format of the data source. Default to 'parquet'.
-        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema
+                       or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param options: all other string options
 
         >>> json_sdf = spark.readStream.format("json") \\
@@ -415,7 +421,8 @@ class DataStreamReader(OptionUtils):
 
         :param path: string represents path to the JSON dataset,
                      or RDD of Strings storing JSON objects.
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+                       or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param primitivesAsString: infers all primitive values as a string type. If None is set,
             it uses the default value, ``false``.
         :param prefersDecimal: infers all floating-point values as a decimal type. If the values
@@ -542,7 +549,8 @@ class DataStreamReader(OptionUtils):
         .. note:: Evolving.
 
         :param path: string, or list of strings, for input path(s).
-        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
+        :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
+                       or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param sep: sets the single character as a separator for each field and value.
             If None is set, it uses the default value, ``,``.
         :param encoding: decodes the CSV files by the given encoding type. If None is set,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 7e8e6394b48625a2734f9452f3c0168d00e5ddb7..70ddfa8e9b83514af3e4c401c79632e8764a2ec1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     this
   }
 
+  /**
+   * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
+   * infer the input schema automatically from data. By specifying the schema here, the underlying
+   * data source can skip the schema inference step, and thus speed up data loading.
+   *
+   * @since 2.3.0
+   */
+  def schema(schemaString: String): DataStreamReader = {
+    this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
+    this
+  }
+
   /**
    * Adds an input option for the underlying data source.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index b5f1e28d7396a8c0a0082f3854561bdb0762b804..3de0ae67a3892be06271b341c785124b837e714c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
     }
     assert(fs.exists(checkpointDir))
   }
+
+  test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
+    spark.readStream
+      .format("org.apache.spark.sql.streaming.test")
+      .schema("aa INT")
+      .load()
+
+    assert(LastOptions.schema.isDefined)
+    assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil))
+
+    LastOptions.clear()
+  }
 }
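
For context, the patch above lets callers pass a DDL-formatted string wherever the streaming read path previously required a StructType. A minimal usage sketch, not part of the patch; the input path /tmp/people.json and the column names are hypothetical, and a running SparkSession is assumed:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # With this change, DataStreamReader.schema() accepts a DDL string in place of a
    # StructType, so the JSON source can skip schema inference for the stream.
    people_sdf = (spark.readStream
                  .schema("name STRING, age INT")   # DDL-formatted string
                  .json("/tmp/people.json"))        # hypothetical input directory/file

On Python 2 the new branch checks isinstance(schema, basestring), so both str and unicode DDL strings are accepted; any other type still raises the TypeError shown in the diff.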