Skip to content
Snippets Groups Projects
Commit 66f05f38 authored by Charles Reiss's avatar Charles Reiss
Browse files

Add new Hadoop API reading support.

parent 02d43e69
No related branches found
No related tags found
No related merge requests found
package spark
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.InputFormat
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.JobID
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.TaskAttemptID
import java.util.Date
import java.text.SimpleDateFormat
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
extends Split {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
override def hashCode(): Int = (41 * (41 + rddId) + index).toInt
}
class NewHadoopRDD[K, V](
sc: SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K], valueClass: Class[V],
@transient conf: Configuration)
extends RDD[(K, V)](sc) {
private val serializableConf = new SerializableWritable(conf)
private val jobtrackerId: String = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
formatter.format(new Date())
}
@transient private val jobId = new JobID(jobtrackerId, id)
@transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = new JobContext(serializableConf.value, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Split](rawSplits.size)
for (i <- 0 until rawSplits.size)
result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
result
}
override def splits = splits_
override def compute(theSplit: Split) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
val conf = serializableConf.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val context = new TaskAttemptContext(serializableConf.value, attemptId)
val format = inputFormatClass.newInstance
val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
reader.initialize(split.serializableHadoopSplit.value, context)
var havePair = false
var finished = false
override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !reader.nextKeyValue
havePair = !finished
if (finished) {
reader.close
}
}
!finished
}
override def next: (K, V) = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
return (reader.getCurrentKey, reader.getCurrentValue)
}
}
override def preferredLocations(split: Split) = {
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
override val dependencies: List[Dependency[_]] = Nil
}
......@@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.io.Writable
......@@ -22,6 +24,10 @@ import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import spark.broadcast._
class SparkContext(
......@@ -123,6 +129,29 @@ extends Logging {
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
val job = new NewHadoopJob
NewFileInputFormat.addInputPath(job, new Path(path))
val conf = job.getConfiguration
newAPIHadoopFile(path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
conf)
}
/** Get an RDD for a given Hadoop file with an arbitrary new API InputFormat and extra
* configuration options to pass to the input format.
*/
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
conf: Configuration): RDD[(K, V)] =
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
/** Get an RDD for a Hadoop SequenceFile with given key and value types */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
......
......@@ -128,4 +128,17 @@ class FileSuite extends FunSuite {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
sc.stop()
}
test("read SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
val sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
nums.saveAsSequenceFile(outputDir)
val output =
sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
sc.stop()
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment