Commit e1a897b6 authored by Reynold Xin, committed by Yin Huai

[SPARK-11274] [SQL] Text data source support for Spark SQL.

This adds an API for reading and writing text files, similar to SparkContext.textFile and RDD.saveAsTextFile.
```
SQLContext.read.text("/path/to/something.txt")
DataFrame.write.text("/path/to/write.txt")
```

Using the new Dataset API, this also supports:
```
val ds: Dataset[String] = SQLContext.read.text("/path/to/something.txt").as[String]
```
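A minimal end-to-end sketch of the new API (the paths and the filter expression below are illustrative, not part of this patch); it relies on the single "text" column produced by the reader:
```
// Sketch only: read lines, drop empty ones, and write them back out.
// The column produced by read.text is named "text".
val lines = sqlContext.read.text("/path/to/input.txt")
val nonEmpty = lines.filter("length(text) > 0")
nonEmpty.write.text("/path/to/output")
```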

Author: Reynold Xin <rxin@databricks.com>

Closes #9240 from rxin/SPARK-11274.
parent 4e38defa
org.apache.spark.sql.execution.datasources.jdbc.DefaultSource
org.apache.spark.sql.execution.datasources.json.DefaultSource
org.apache.spark.sql.execution.datasources.parquet.DefaultSource
org.apache.spark.sql.execution.datasources.text.DefaultSource
@@ -302,6 +302,22 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName)))
}
/**
* Loads a text file and returns a [[DataFrame]] with a single string column named "text".
* Each line in the text file is a new row in the resulting DataFrame. For example:
* {{{
* // Scala:
* sqlContext.read.text("/path/to/spark/README.md")
*
* // Java:
* sqlContext.read().text("/path/to/spark/README.md")
* }}}
*
* @param path input path
* @since 1.6.0
*/
def text(path: String): DataFrame = format("text").load(path)
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
......
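Combined with the Dataset API mentioned in the commit message, the lines loaded by `text()` can be processed as typed strings. A minimal sketch, assuming `import sqlContext.implicits._` is in scope for the String encoder and that the README path exists:
```
// Sketch only: count the lines that mention "Spark".
import sqlContext.implicits._
val ds = sqlContext.read.text("/path/to/spark/README.md").as[String]
val sparkLines = ds.filter(_.contains("Spark")).count()
```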
@@ -244,6 +244,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included.
*
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
val props = new Properties()
@@ -317,6 +319,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
def orc(path: String): Unit = format("orc").save(path)
/**
* Saves the content of the [[DataFrame]] in a text file at the specified path.
* The DataFrame must have only one column that is of string type.
* Each row becomes a new line in the output file. For example:
* {{{
* // Scala:
* df.write.text("/path/to/output")
*
* // Java:
* df.write().text("/path/to/output")
* }}}
*
* @since 1.6.0
*/
def text(path: String): Unit = format("text").save(path)
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
......
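Because the write path accepts only a single string column, a wider DataFrame has to be projected (and, if necessary, cast) down first. A minimal sketch; the "people" table and its columns are hypothetical:
```
// Sketch only: "people" with columns name (string) and age (int) is hypothetical.
val people = sqlContext.table("people")
people.select(people("name")).write.text("/path/to/names")                   // OK: one string column
// people.select(people("age")).write.text("/path/to/ages")                  // would fail: not a string column
// people.select(people("age").cast("string")).write.text("/path/to/ages")   // cast first to make it writable
```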
@@ -161,11 +161,10 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with SparkHadoopMapRedUtil with Logging {
- val writer = new CharArrayWriter()
+ private[this] val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
- val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
- val result = new Text()
+ private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ private[this] val result = new Text()
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.text
import com.google.common.base.Objects
import org.apache.hadoop.fs.{Path, FileStatus}
import org.apache.hadoop.io.{NullWritable, Text, LongWritable}
import org.apache.hadoop.mapred.{TextInputFormat, JobConf}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.execution.datasources.PartitionSpec
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
/**
* A data source for reading text files.
*/
class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
override def createRelation(
sqlContext: SQLContext,
paths: Array[String],
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
dataSchema.foreach(verifySchema)
new TextRelation(None, partitionColumns, paths)(sqlContext)
}
override def shortName(): String = "text"
private def verifySchema(schema: StructType): Unit = {
if (schema.size != 1) {
throw new AnalysisException(
s"Text data source supports only a single column, and you have ${schema.size} columns.")
}
val tpe = schema(0).dataType
if (tpe != StringType) {
throw new AnalysisException(
s"Text data source supports only a string column, but you have ${tpe.simpleString}.")
}
}
}
private[sql] class TextRelation(
val maybePartitionSpec: Option[PartitionSpec],
override val userDefinedPartitionColumns: Option[StructType],
override val paths: Array[String] = Array.empty[String])
(@transient val sqlContext: SQLContext)
extends HadoopFsRelation(maybePartitionSpec) {
/** Data schema is always a single column, named "text". */
override def dataSchema: StructType = new StructType().add("text", StringType)
/** This is an internal data source that outputs internal row format. */
override val needConversion: Boolean = false
/** Read path. */
override def buildScan(inputPaths: Array[FileStatus]): RDD[Row] = {
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job)
val paths = inputPaths.map(_.getPath).sortBy(_.toUri)
if (paths.nonEmpty) {
FileInputFormat.setInputPaths(job, paths: _*)
}
sqlContext.sparkContext.hadoopRDD(
conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
.mapPartitions { iter =>
var buffer = new Array[Byte](1024)
val row = new GenericMutableRow(1)
iter.map { case (_, line) =>
if (line.getLength > buffer.length) {
buffer = new Array[Byte](line.getLength)
}
System.arraycopy(line.getBytes, 0, buffer, 0, line.getLength)
row.update(0, UTF8String.fromBytes(buffer, 0, line.getLength))
row
}
}.asInstanceOf[RDD[Row]]
}
/** Write path. */
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new TextOutputWriter(path, dataSchema, context)
}
}
}
override def equals(other: Any): Boolean = other match {
case that: TextRelation =>
paths.toSet == that.paths.toSet && partitionColumns == that.partitionColumns
case _ => false
}
override def hashCode(): Int = {
Objects.hashCode(paths.toSet, partitionColumns)
}
}
class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
extends OutputWriter
with SparkHadoopMapRedUtil {
private[this] val buffer = new Text()
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context)
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
}
}.getRecordWriter(context)
}
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
val utf8string = row.getUTF8String(0)
buffer.set(utf8string.getBytes)
recordWriter.write(NullWritable.get(), buffer)
}
override def close(): Unit = {
recordWriter.close(context)
}
}
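Since DefaultSource registers the short name "text" (see the DataSourceRegister services entry near the top of this diff), the relation can be reached either through the convenience method or through format(); both yield the fixed single-column schema. A minimal sketch with an illustrative path:
```
// Sketch only: both calls resolve to the same TextRelation.
val a = sqlContext.read.text("/path/to/data.txt")
val b = sqlContext.read.format("text").load("/path/to/data.txt")
a.printSchema()
// root
//  |-- text: string (nullable = true)
```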
This is a test file for the text data source
1+1
数据砖头
"doh"
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.text
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.util.Utils
class TextSuite extends QueryTest with SharedSQLContext {
test("reading text file") {
verifyFrame(sqlContext.read.format("text").load(testFile))
}
test("SQLContext.read.text() API") {
verifyFrame(sqlContext.read.text(testFile))
}
test("writing") {
val df = sqlContext.read.text(testFile)
val tempFile = Utils.createTempDir()
tempFile.delete()
df.write.text(tempFile.getCanonicalPath)
verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
Utils.deleteRecursively(tempFile)
}
test("error handling for invalid schema") {
val tempFile = Utils.createTempDir()
tempFile.delete()
val df = sqlContext.range(2)
intercept[AnalysisException] {
df.write.text(tempFile.getCanonicalPath)
}
intercept[AnalysisException] {
sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
}
}
private def testFile: String = {
Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
}
/** Verifies data and schema. */
private def verifyFrame(df: DataFrame): Unit = {
// schema
assert(df.schema == new StructType().add("text", StringType))
// verify content
val data = df.collect()
assert(data(0) == Row("This is a test file for the text data source"))
assert(data(1) == Row("1+1"))
// Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
// scalastyle:off
assert(data(2) == Row("数据砖头"))
// scalastyle:on
assert(data(3) == Row("\"doh\""))
assert(data.length == 4)
}
}
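The commit message also advertises `as[String]` on the loaded DataFrame; a test along the following lines could be added to TextSuite to exercise that path. This is a sketch, not part of the patch; it assumes the shared test harness exposes implicit encoders (e.g. via testImplicits):
```
  test("reading as Dataset[String] (sketch)") {
    // Sketch only: testImplicits is assumed to provide the implicit String encoder.
    import testImplicits._
    val ds = sqlContext.read.text(testFile).as[String]
    val data = ds.collect()
    assert(data.length == 4)
    assert(data(0) == "This is a test file for the text data source")
  }
```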