Commit 6f9e598c authored by Aaditya Ramesh, committed by Shixiong Zhu

[SPARK-13027][STREAMING] Added batch time as a parameter to updateStateByKey

Added RDD batch time as an input parameter to the update function in updateStateByKey.
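As an illustration only (not part of the patch), here is a minimal sketch of how the new overload might be called from user code. The `words` stream, the `lastSeenByWord` helper, and the `HashPartitioner(2)` choice are assumptions made for the example:

import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.dstream.DStream

// Illustrative helper (not in the patch): track, for every word, the time in
// milliseconds of the most recent batch in which it appeared. The update
// function now receives the batch time as its first argument.
def lastSeenByWord(words: DStream[String]): DStream[(String, Long)] = {
  val updateFunc = (batchTime: Time, word: String, counts: Seq[Int], state: Option[Long]) =>
    // Only refresh the timestamp when the key actually received values this batch.
    if (counts.nonEmpty) Some(batchTime.milliseconds) else state

  words.map(word => (word, 1)).updateStateByKey[Long](
    updateFunc = updateFunc,
    partitioner = new HashPartitioner(2),
    rememberPartitioner = false)
}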

Author: Aaditya Ramesh <aramesh@conviva.com>

Closes #11122 from aramesh117/SPARK-13027.
parent 745ab8bc
@@ -453,9 +453,12 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
   def updateStateByKey[S: ClassTag](
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
-      rememberPartitioner: Boolean
-    ): DStream[(K, S)] = ssc.withScope {
-    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
+      rememberPartitioner: Boolean): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, None)
   }
 
   /**
@@ -499,10 +502,33 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
       rememberPartitioner: Boolean,
-      initialRDD: RDD[(K, S)]
-    ): DStream[(K, S)] = ssc.withScope {
-    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
-      rememberPartitioner, Some(initialRDD))
+      initialRDD: RDD[(K, S)]): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, Some(initialRDD))
   }
 
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](updateFunc: (Time, K, Seq[V], Option[S]) => Option[S],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean,
+      initialRDD: Option[RDD[(K, S)]] = None): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (time: Time, iterator: Iterator[(K, Seq[V], Option[S])]) => {
+      iterator.flatMap(t => cleanedFunc(time, t._1, t._2, t._3).map(s => (t._1, s)))
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, initialRDD)
+  }
+
   /**
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Duration, Time}
 private[streaming]
 class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
     parent: DStream[(K, V)],
-    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+    updateFunc: (Time, Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
     partitioner: Partitioner,
     preservePartitioning: Boolean,
     initialRDD: Option[RDD[(K, S)]]
@@ -41,8 +41,10 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
   override val mustCheckpoint = true
 
-  private [this] def computeUsingPreviousRDD (
-    parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+  private [this] def computeUsingPreviousRDD(
+      batchTime: Time,
+      parentRDD: RDD[(K, V)],
+      prevStateRDD: RDD[(K, S)]) = {
     // Define the function for the mapPartition operation on cogrouped RDD;
     // first map the cogrouped tuple to tuples of required type,
     // and then apply the update function
@@ -53,7 +55,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         val headOption = if (itr.hasNext) Some(itr.next()) else None
         (t._1, t._2._1.toSeq, headOption)
       }
-      updateFuncLocal(i)
+      updateFuncLocal(batchTime, i)
     }
     val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
     val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
@@ -68,15 +70,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
       case Some(prevStateRDD) => // If previous state RDD exists
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
-          case Some(parentRDD) => // If parent RDD exists, then compute as usual
-            computeUsingPreviousRDD(parentRDD, prevStateRDD)
-          case None => // If parent RDD does not exist
+          case Some(parentRDD) => // If parent RDD exists, then compute as usual
+            computeUsingPreviousRDD (validTime, parentRDD, prevStateRDD)
+          case None => // If parent RDD does not exist
             // Re-apply the update function to the old state RDD
             val updateFuncLocal = updateFunc
             val finalFunc = (iterator: Iterator[(K, S)]) => {
               val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
-              updateFuncLocal(i)
+              updateFuncLocal(validTime, i)
             }
             val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
             Some(stateRDD)
@@ -93,15 +94,16 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
                 // and then apply the update function
                 val updateFuncLocal = updateFunc
                 val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-                  updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                  updateFuncLocal (validTime,
+                    iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
                 }
 
                 val groupedRDD = parentRDD.groupByKey(partitioner)
                 val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
                 // logDebug("Generating state RDD for time " + validTime + " (first)")
-                Some(sessionRDD)
-              case Some(initialStateRDD) =>
-                computeUsingPreviousRDD(parentRDD, initialStateRDD)
+                Some (sessionRDD)
+              case Some (initialStateRDD) =>
+                computeUsingPreviousRDD(validTime, parentRDD, initialStateRDD)
             }
           case None => // If parent RDD does not exist, then nothing to do!
             // logDebug("Not generating state RDD (no previous state, no parent)")
@@ -471,6 +471,72 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(inputData, updateStateOperation, outputData, true)
   }
 
+  test("updateStateByKey - testing time stamps as input") {
+    type StreamingState = Long
+    val initial: Seq[(String, StreamingState)] = Seq(("a", 0L), ("c", 0L))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    // a -> 1000, 3000, 6000, 10000, 15000, 15000
+    // b -> 0, 2000, 5000, 9000, 9000, 9000
+    // c -> 1000, 1000, 3000, 3000, 3000, 3000
+    val outputData: Seq[Seq[(String, StreamingState)]] = Seq(
+      Seq(
+        ("a", 1000L),
+        ("c", 0L)), // t = 1000
+      Seq(
+        ("a", 3000L),
+        ("b", 2000L),
+        ("c", 0L)), // t = 2000
+      Seq(
+        ("a", 6000L),
+        ("b", 5000L),
+        ("c", 3000L)), // t = 3000
+      Seq(
+        ("a", 10000L),
+        ("b", 9000L),
+        ("c", 3000L)), // t = 4000
+      Seq(
+        ("a", 15000L),
+        ("b", 9000L),
+        ("c", 3000L)), // t = 5000
+      Seq(
+        ("a", 15000L),
+        ("b", 9000L),
+        ("c", 3000L)) // t = 6000
+    )
+
+    val updateStateOperation = (s: DStream[String]) => {
+      val initialRDD = s.context.sparkContext.makeRDD(initial)
+      val updateFunc = (time: Time,
+                        key: String,
+                        values: Seq[Int],
+                        state: Option[StreamingState]) => {
+        // Update only if we receive values for this key during the batch.
+        if (values.nonEmpty) {
+          Option(time.milliseconds + state.getOrElse(0L))
+        } else {
+          Option(state.getOrElse(0L))
+        }
+      }
+      s.map(x => (x, 1)).updateStateByKey[StreamingState](updateFunc = updateFunc,
+        partitioner = new HashPartitioner (numInputPartitions), rememberPartitioner = false,
+        initialRDD = Option(initialRDD))
+    }
+
+    testOperation(input = inputData, operation = updateStateOperation,
+      expectedOutput = outputData, useSet = true)
+  }
+
   test("updateStateByKey - with initial value RDD") {
     val initial = Seq(("a", 1), ("c", 2))
@@ -164,6 +164,10 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
   private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
     val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
     val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator }
+    val updateF3 = (_: Time, _: Int, _: Seq[Int], _: Option[Int]) => {
+      return
+      Option(1)
+    }
     val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) }
     expectCorrectException { ds.updateStateByKey(updateF1) }
     expectCorrectException { ds.updateStateByKey(updateF1, 5) }
@@ -177,6 +181,14 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
     expectCorrectException {
       ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD)
     }
+    expectCorrectException {
+      ds.updateStateByKey(
+        updateFunc = updateF3,
+        partitioner = new HashPartitioner(5),
+        rememberPartitioner = true,
+        initialRDD = Option(initialRDD)
+      )
+    }
   }
 
   private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException {
     ds.mapValues { _ => return; 1 }