diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 39c3a4996c327421f1088b28e8c06858c2950556..d29a1a9881cd4b23e63be178504abc482a8e2d96 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -29,7 +29,7 @@
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -618,6 +618,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
         attemptNumber)
       val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
       val format = outputFormatClass.newInstance
+      format match {
+        case c: Configurable => c.setConf(wrappedConf.value)
+        case _ => ()
+      }
       val committer = format.getOutputCommitter(hadoopContext)
       committer.setupTask(hadoopContext)
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index fa5c9b10fe059272f3fe7ac4d6762751c13f25e5..e3e23775f011d9ae43dc4f69eb64a58cc0feffd7 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -23,6 +23,8 @@
 import scala.util.Random
 
 import org.scalatest.FunSuite
 import com.google.common.io.Files
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.conf.{Configuration, Configurable}
 import org.apache.spark.SparkContext._
 import org.apache.spark.{Partitioner, SharedSparkContext}
@@ -330,4 +332,77 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       (1, ArrayBuffer(1)),
       (2, ArrayBuffer(1))))
   }
+
+  test("saveNewAPIHadoopFile should call setConf if format is configurable") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
+
+    // No error, non-configurable formats still work
+    pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored")
+
+    /*
+      Check that configurable formats get configured:
+      ConfigTestFormat throws an exception if we try to write
+      to it when setConf hasn't been called first.
+      Assertion is in ConfigTestFormat.getRecordWriter.
+    */
+    pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
+  }
 }
+
+/*
+  These classes are fakes for testing
+  "saveNewAPIHadoopFile should call setConf if format is configurable".
+  Unfortunately, they have to be top level classes, and not defined in
+  the test method, because otherwise Scala won't generate no-args constructors
+  and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile
+  tries to instantiate them with Class.newInstance.
+ */
+class FakeWriter extends RecordWriter[Integer, Integer] {
+
+  def close(p1: TaskAttemptContext) = ()
+
+  def write(p1: Integer, p2: Integer) = ()
+
+}
+
+class FakeCommitter extends OutputCommitter {
+  def setupJob(p1: JobContext) = ()
+
+  def needsTaskCommit(p1: TaskAttemptContext): Boolean = false
+
+  def setupTask(p1: TaskAttemptContext) = ()
+
+  def commitTask(p1: TaskAttemptContext) = ()
+
+  def abortTask(p1: TaskAttemptContext) = ()
+}
+
+class FakeFormat() extends OutputFormat[Integer, Integer]() {
+
+  def checkOutputSpecs(p1: JobContext) = ()
+
+  def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+    new FakeWriter()
+  }
+
+  def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = {
+    new FakeCommitter()
+  }
+}
+
+class ConfigTestFormat() extends FakeFormat() with Configurable {
+
+  var setConfCalled = false
+  def setConf(p1: Configuration) = {
+    setConfCalled = true
+    ()
+  }
+
+  def getConf: Configuration = null
+
+  override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+    assert(setConfCalled, "setConf was never called")
+    super.getRecordWriter(p1)
+  }
+}
+
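
Reviewer note (not part of the patch): the sketch below is a minimal illustration of the failure mode this change fixes. A format that implements Configurable may depend on its Configuration inside getRecordWriter; HBase's TableOutputFormat, for instance, sets up its table connection in setConf. Before this change, the save path instantiated the format with Class.newInstance and never called setConf, so such a format saw a null configuration on the executor. The class name ConfDependentFormat and the key "example.required.setting" are hypothetical, and the sketch reuses the FakeWriter and FakeCommitter fakes defined in the test file above.

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapreduce._

// Hypothetical format that cannot create a writer until it has been configured.
class ConfDependentFormat extends OutputFormat[Integer, Integer] with Configurable {
  private var conf: Configuration = null

  def setConf(c: Configuration): Unit = { conf = c }
  def getConf: Configuration = conf

  def checkOutputSpecs(ctx: JobContext): Unit = ()

  def getOutputCommitter(ctx: TaskAttemptContext): OutputCommitter = new FakeCommitter()

  def getRecordWriter(ctx: TaskAttemptContext): RecordWriter[Integer, Integer] = {
    // Without the setConf call added by this patch, `conf` is still null
    // here and this lookup throws a NullPointerException during the save.
    val target = conf.get("example.required.setting")
    require(target != null, "example.required.setting must be set")
    new FakeWriter()
  }
}

Driver side, a value set on the Configuration passed to saveAsNewAPIHadoopFile now reaches the format before writing starts (pairs is the RDD from the test above):

val conf = new Configuration(sc.hadoopConfiguration)
conf.set("example.required.setting", "some-value")
pairs.saveAsNewAPIHadoopFile("ignored", classOf[Integer], classOf[Integer],
  classOf[ConfDependentFormat], conf)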