Skip to content
Snippets Groups Projects
Commit 9e68f486 authored by Stephen Haberman's avatar Stephen Haberman
Browse files

More quickly call close in HadoopRDD.

This also refactors out the common "gotNext" iterator pattern into
a shared utility class.
parent cbf8f0d4
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
import spark.util.NextIterator
/**
......@@ -62,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
var reader: RecordReader[K, V] = null
......@@ -75,34 +76,18 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
var gotNext = false
var finished = false
override def hasNext: Boolean = {
if (!gotNext) {
try {
finished = !reader.next(key, value)
} catch {
case eof: EOFException =>
finished = true
}
gotNext = true
}
!finished
}
override def next: (K, V) = {
if (!gotNext) {
override def getNext() = {
try {
finished = !reader.next(key, value)
} catch {
case eof: EOFException =>
finished = true
}
if (finished) {
throw new NoSuchElementException("End of stream")
}
gotNext = false
(key, value)
}
private def close() {
override def close() {
try {
reader.close()
} catch {
......
......@@ -72,40 +72,14 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
*/
def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] {
  /**
   * Deserializes and returns the next object from the stream.
   *
   * An EOFException from the underlying stream marks the end of input:
   * we set `finished` so NextIterator stops iteration (the value returned
   * on that path is ignored by NextIterator's contract).
   */
  override protected def getNext() = {
    try {
      readObject[Any]()
    } catch {
      case eof: EOFException =>
        finished = true
    }
  }
}
package spark.util

/** Provides a basic/boilerplate Iterator implementation. */
private[spark] abstract class NextIterator[U] extends Iterator[U] {

  // True once `cached` holds an element that next() has not yet returned.
  private var hasCached = false
  // The most recent element produced by getNext(); only valid while hasCached.
  private var cached: U = _
  // Set to true by subclasses (inside getNext) once the source is exhausted.
  protected var finished = false

  /**
   * Method for subclasses to implement to provide the next element.
   *
   * If no next element is available, the subclass should set `finished`
   * to `true` and may return any value (it will be ignored).
   *
   * This convention is required because `null` may be a valid value,
   * and using `Option` seems like it might create unnecessary Some/None
   * instances, given some iterators might be called in a tight loop.
   *
   * @return U, or set 'finished' when done
   */
  protected def getNext(): U

  /**
   * Method for subclasses to optionally implement when all elements
   * have been successfully iterated, and the iteration is done.
   *
   * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be
   * called because it has no control over what happens when an exception
   * happens in the user code that is calling hasNext/next.
   *
   * Ideally you should have another try/catch, as in HadoopRDD, that
   * ensures any resources are closed should iteration fail.
   */
  protected def close() {
  }

  override def hasNext: Boolean = {
    // Eagerly fetch one element so we can tell whether the source is done;
    // re-entrant calls just report on the cached state.
    if (!finished && !hasCached) {
      cached = getNext()
      if (finished) {
        // Source just ran dry: release resources exactly once.
        close()
      } else {
        hasCached = true
      }
    }
    !finished
  }

  override def next(): U = {
    // hasNext both refills the cache and detects exhaustion.
    if (!hasNext) {
      throw new NoSuchElementException("End of stream")
    }
    hasCached = false
    cached
  }
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ package spark.streaming.dstream
import spark.streaming.StreamingContext
import spark.storage.StorageLevel
import spark.util.NextIterator
import java.io._
import java.net.Socket
......@@ -59,45 +60,18 @@ object SocketReceiver {
*/
def bytesToLines(inputStream: InputStream): Iterator[String] = {
  val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))
  new NextIterator[String] {
    // NOTE: `= {` (expression body) is required so the String line is actually
    // returned; procedure syntax `{ ... }` would infer Unit and fail to
    // override NextIterator.getNext(): U.
    protected override def getNext() = {
      // BufferedReader.readLine returns null at end of stream.
      val nextValue = dataInputStream.readLine()
      if (nextValue == null) {
        finished = true
      }
      nextValue
    }

    // Invoked by NextIterator once the stream is exhausted.
    protected override def close() {
      dataInputStream.close()
    }
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment