Skip to content
Snippets Groups Projects
Commit 72eed2b9 authored by Tathagata Das's avatar Tathagata Das
Browse files

Converted CheckpointState in RDDCheckpointData to use scala Enumeration.

parent 8e74fac2
No related branches found
No related tags found
No related merge requests found
......@@ -5,45 +5,41 @@ import rdd.CoalescedRDD
import scheduler.{ResultTask, ShuffleMapTask}
/**
* This class contains all the information of the regarding RDD checkpointing.
* Enumeration to manage state transitions of an RDD through checkpointing
* [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
*/
private[spark] object CheckpointState extends Enumeration {
type CheckpointState = Value
val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
}
/**
* This class contains all the information of the regarding RDD checkpointing.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
extends Logging with Serializable {
/**
* This class manages the state transition of an RDD through checkpointing
* [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
*/
class CheckpointState extends Serializable {
var state = 0
import CheckpointState._
def mark() { if (state == 0) state = 1 }
def start() { assert(state == 1); state = 2 }
def finish() { assert(state == 2); state = 3 }
def isMarked() = { state == 1 }
def isInProgress = { state == 2 }
def isCheckpointed = { state == 3 }
}
val cpState = new CheckpointState()
var cpState = Initialized
@transient var cpFile: Option[String] = None
@transient var cpRDD: Option[RDD[T]] = None
@transient var cpRDDSplits: Seq[Split] = Nil
// Mark the RDD for checkpointing
def markForCheckpoint() = {
RDDCheckpointData.synchronized { cpState.mark() }
def markForCheckpoint() {
RDDCheckpointData.synchronized {
if (cpState == Initialized) cpState = MarkedForCheckpoint
}
}
// Is the RDD already checkpointed
def isCheckpointed() = {
RDDCheckpointData.synchronized { cpState.isCheckpointed }
def isCheckpointed(): Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
}
// Get the file to which this RDD was checkpointed to as a Option
def getCheckpointFile() = {
// Get the file to which this RDD was checkpointed to as an Option
def getCheckpointFile(): Option[String] = {
RDDCheckpointData.synchronized { cpFile }
}
......@@ -52,8 +48,8 @@ extends Logging with Serializable {
// If it is marked for checkpointing AND checkpointing is not already in progress,
// then set it to be in progress, else return
RDDCheckpointData.synchronized {
if (cpState.isMarked && !cpState.isInProgress) {
cpState.start()
if (cpState == MarkedForCheckpoint) {
cpState = CheckpointingInProgress
} else {
return
}
......@@ -87,7 +83,7 @@ extends Logging with Serializable {
cpRDD = Some(newRDD)
cpRDDSplits = newRDD.splits
rdd.changeDependencies(newRDD)
cpState.finish()
cpState = Checkpointed
RDDCheckpointData.checkpointCompleted()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment