Commit 9812a24a authored by hyukjinkwon, committed by Reynold Xin

[SPARK-13503][SQL] Support to specify the (writing) option for compression codec for TEXT

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13503
This PR lets the TEXT datasource compress its output via a writer option instead of requiring Hadoop configurations to be set manually.
Codec names are resolved the same way as in https://github.com/apache/spark/pull/10805 and https://github.com/apache/spark/pull/10858.
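
For example (a minimal sketch; the input/output paths and the `sqlContext` value are placeholders), writing gzip-compressed text now only requires the writer option, which `TextRelation` resolves to the full codec class via `CompressionCodecs.getCodecClassName`:

```scala
// Assumes an existing SQLContext `sqlContext` (the API used by this patch).
val df = sqlContext.read.text("/path/to/input.txt")

// Compress the output by option instead of setting Hadoop configurations manually.
// The short codec name is case-insensitive; "codec" is accepted as an alternative key.
df.write
  .option("compression", "gzip")
  .text("/tmp/text-gzip-output")
```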

## How was this patch tested?

This was tested with unit tests and with `dev/run_tests` for coding style.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #11384 from HyukjinKwon/SPARK-13503.
parent 26ac6080
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution.datasources

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}

 import org.apache.spark.util.Utils
@@ -44,4 +46,16 @@ private[datasources] object CompressionCodecs {
           s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
     }
   }
+
+  /**
+   * Set compression configurations to Hadoop `Configuration`.
+   * `codec` should be a full class path
+   */
+  def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
+    conf.set("mapreduce.output.fileoutputformat.compress", "true")
+    conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+    conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
+    conf.set("mapreduce.map.output.compress", "true")
+    conf.set("mapreduce.map.output.compress.codec", codec)
+  }
 }
@@ -24,7 +24,6 @@ import scala.util.control.NonFatal
 import com.google.common.base.Objects
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
-import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.RecordWriter
@@ -34,6 +33,7 @@ import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.CompressionCodecs
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -50,16 +50,16 @@ private[sql] class CSVRelation(
     case None => inferSchema(paths)
   }

-  private val params = new CSVOptions(parameters)
+  private val options = new CSVOptions(parameters)

   @transient
   private var cachedRDD: Option[RDD[String]] = None

   private def readText(location: String): RDD[String] = {
-    if (Charset.forName(params.charset) == Charset.forName("UTF-8")) {
+    if (Charset.forName(options.charset) == Charset.forName("UTF-8")) {
       sqlContext.sparkContext.textFile(location)
     } else {
-      val charset = params.charset
+      val charset = options.charset
       sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
         .mapPartitions { _.map { pair =>
           new String(pair._2.getBytes, 0, pair._2.getLength, charset)
@@ -81,8 +81,8 @@ private[sql] class CSVRelation(
   private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = {
     val rdd = baseRdd(inputPaths)
     // Make sure firstLine is materialized before sending to executors
-    val firstLine = if (params.headerFlag) findFirstLine(rdd) else null
-    CSVRelation.univocityTokenizer(rdd, header, firstLine, params)
+    val firstLine = if (options.headerFlag) findFirstLine(rdd) else null
+    CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
   }

   /**
@@ -96,20 +96,16 @@ private[sql] class CSVRelation(
     val pathsString = inputs.map(_.getPath.toUri.toString)
     val header = schema.fields.map(_.name)
     val tokenizedRdd = tokenRdd(header, pathsString)
-    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params)
+    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, options)
   }

   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
     val conf = job.getConfiguration
-    params.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+    options.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }
-    new CSVOutputWriterFactory(params)
+    new CSVOutputWriterFactory(options)
   }

   override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns)
@@ -129,17 +125,17 @@ private[sql] class CSVRelation(
   private def inferSchema(paths: Array[String]): StructType = {
     val rdd = baseRdd(paths)
     val firstLine = findFirstLine(rdd)
-    val firstRow = new LineCsvReader(params).parseLine(firstLine)
+    val firstRow = new LineCsvReader(options).parseLine(firstLine)

-    val header = if (params.headerFlag) {
+    val header = if (options.headerFlag) {
       firstRow
     } else {
       firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
     }

     val parsedRdd = tokenRdd(header, paths)
-    if (params.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, params.nullValue)
+    if (options.inferSchemaFlag) {
+      CSVInferSchema.infer(parsedRdd, header, options.nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
@@ -153,8 +149,8 @@ private[sql] class CSVRelation(
    * Returns the first line of the first non-empty file in path
    */
   private def findFirstLine(rdd: RDD[String]): String = {
-    if (params.isCommentSet) {
-      val comment = params.comment.toString
+    if (options.isCommentSet) {
+      val comment = options.comment.toString
       rdd.filter { line =>
         line.trim.nonEmpty && !line.startsWith(comment)
       }.first()
...
@@ -165,11 +165,7 @@ private[sql] class JSONRelation(
   override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
     val conf = job.getConfiguration
     options.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }

     new BucketedOutputWriterFactory {
...
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, PartitionSpec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -48,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation = {
     dataSchema.foreach(verifySchema)
-    new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext)
+    new TextRelation(None, dataSchema, partitionColumns, paths, parameters)(sqlContext)
   }

   override def shortName(): String = "text"
@@ -114,6 +114,15 @@ private[sql] class TextRelation(
   /** Write path. */
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+    val conf = job.getConfiguration
+    val compressionCodec = {
+      val name = parameters.get("compression").orElse(parameters.get("codec"))
+      name.map(CompressionCodecs.getCodecClassName)
+    }
+    compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
+    }
+
     new OutputWriterFactory {
       override def newInstance(
           path: String,
...
@@ -57,6 +57,21 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }

+  test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
+    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
+    val tempFile = Utils.createTempDir()
+    tempFile.delete()
+    df.write
+      .option("compression", "gZiP")
+      .text(tempFile.getCanonicalPath)
+
+    val compressedFiles = tempFile.listFiles()
+    assert(compressedFiles.exists(_.getName.endsWith(".gz")))
+
+    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
+    Utils.deleteRecursively(tempFile)
+  }
+
   private def testFile: String = {
     Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
   }
...