diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index c400e4237abe3438f849dac1f8aad51c0797762b..14997c64d505ec99df1b77701249fea361167872 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -65,7 +65,7 @@ public class JavaStatefulNetworkWordCount {
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
     ssc.checkpoint(".");
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD input to mapWithState
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
         new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public class JavaStatefulNetworkWordCount {
         });
     // Update the cumulative count function
-    final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+    final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
           @Override
-          public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+          public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
            int sum = one.or(0) + (state.exists() ? state.get() : 0);
            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
            state.update(sum);
-            return Optional.of(output);
+            return output;
          }
        };
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
-        wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+    // DStream made of cumulative counts that get updated in every batch
+    JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+        wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));
     stateDstream.print();
     ssc.start();
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4f847f118b2cd7b27ed671fc759dc534003b404..2dce1820d9734db0ebe20526006728c3243690ef 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD for mapWithState operation
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
     // Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
     val words = lines.flatMap(_.split(" "))
     val wordDstream = words.map(x => (x, 1))
-    // Update the cumulative count using updateStateByKey
+    // Update the cumulative count using mapWithState
    // This will give a DStream made
of state (which is the cumulative count of the words) - val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => { + val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => { val sum = one.getOrElse(0) + state.getOption.getOrElse(0) val output = (word, sum) state.update(sum) - Some(output) + output } - val stateDstream = wordDstream.trackStateByKey( - StateSpec.function(trackStateFunc).initialState(initialRDD)) + val stateDstream = wordDstream.mapWithState( + StateSpec.function(mappingFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 4eee97bc89613e581dc66add0c7b5c01a4376d1c..89e0c7fdf7eecf42b016426679e5441fe5e0681d 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -32,12 +32,10 @@ import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.Function4; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaTrackStateDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; /** * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 @@ -863,12 +861,12 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ /** * This test is only for testing the APIs. It's not necessary to run it. 
*/ - public void testTrackStateByAPI() { + public void testMapWithStateAPI() { JavaPairRDD<String, Boolean> initialRDD = null; JavaPairDStream<String, Integer> wordsDstream = null; - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream = - wordsDstream.trackStateByKey( + JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream = + wordsDstream.mapWithState( StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> { // Use all State's methods here state.exists(); @@ -884,9 +882,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots(); - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 = - wordsDstream.trackStateByKey( - StateSpec.<String, Integer, Boolean, Double>function((value, state) -> { + JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 = + wordsDstream.mapWithState( + StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> { state.exists(); state.get(); state.isTimingOut(); @@ -898,6 +896,6 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots(); + JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots(); } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 604e64fc61630faef16062a254983bd7fabd251f..b47bdda2c2137f0953a1af23b83197e7e6078aaf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of - * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Abstract class for getting and updating the state in mapping function used in the `mapWithState` + * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) + * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
* * Scala example of using `State`: * {{{ - * // A tracking function that maintains an integer state and return a String - * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { + * // A mapping function that maintains an integer state and returns a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { * // Check if state exists * if (state.exists) { * val existingState = state.get // Get the existing state @@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental * * Java example of using `State`: * {{{ - * // A tracking function that maintains an integer state and return a String - * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc = - * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() { + * // A mapping function that maintains an integer state and returns a String + * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction = + * new Function3<String, Optional<Integer>, State<Integer>, String>() { * * @Override - * public Optional<String> call(Optional<Integer> one, State<Integer> state) { + * public String call(String key, Optional<Integer> value, State<Integer> state) { * if (state.exists()) { * int existingState = state.get(); // Get the existing state * boolean shouldRemove = ...; // Decide whether to remove the state @@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental * } * }; * }}} + * + * @tparam S Class of the state */ @Experimental sealed abstract class State[S] { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index bea5b9df20b530c3117c01ffee64dbf356ca6b1b..9f6f95223f6194766bfb87b29d6a96d4a374fb4f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming import com.google.common.base.Optional import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} +import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner import org.apache.spark.{HashPartitioner, Partitioner} @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner} /** * :: Experimental :: * Abstract class representing all the specifications of the DStream transformation - * `trackStateByKey` operation of a + * `mapWithState` operation of a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or @@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * Example in Scala: * {{{ - * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { - * ... 
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *    // A mapping function that maintains an integer state and returns a string
+ *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *        new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *            @Override
+ *            public String call(String key, Optional<Integer> value, State<Integer> state) {
+ *              // Use state.exists(), state.get(), state.update() and state.remove()
+ *              // to manage state, and return the necessary string
+ *            }
+ *        };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *        javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
  * }}}
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped elements
  */
 @Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(rdd: RDD[(KeyType, StateType)]): this.type
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
   /**
-   * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+   * Set the number of partitions by which the state RDDs generated by `mapWithState`
    * will be partitioned. Hash partitioning will be used.
    */
   def numPartitions(numPartitions: Int): this.type
   /**
-   * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+   * Set the partitioner by which the state RDDs generated by `mapWithState` will be
    * be partitioned.
    */
   def partitioner(partitioner: Partitioner): this.type
   /**
    * Set the duration after which the state of an idle key will be removed. A key and its state is
-   * considered idle if it has not received any data for at least the given duration. The state
-   * tracking function will be called one final time on the idle states that are going to be
+   * considered idle if it has not received any data for at least the given duration. The
+   * mapping function will be called one final time on the idle states that are going to be
    * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set
    * to `true` in that call.
    */
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
 /**
  * :: Experimental ::
  * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
+ *
  * that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Example in Scala:
  * {{{
- *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- *      ...
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- *        StateSpec.function(trackingFunction).numPartitions(10))
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *    // A mapping function that maintains an integer state and returns a string
+ *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *        new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *            @Override
+ *            public String call(String key, Optional<Integer> value, State<Integer> state) {
+ *              // Use state.exists(), state.get(), state.update() and state.remove()
+ *              // to manage state, and return the necessary string
+ *            }
+ *        };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *        javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
  */
 @Experimental
 object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
* - * @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data * @tparam KeyType Class of the keys * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] - ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - ClosureCleaner.clean(trackingFunction, checkSerializable = true) - new StateSpecImpl(trackingFunction) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType] + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + new StateSpecImpl(mappingFunction) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * of the `trackStateByKey` operation on a + * of the `mapWithState` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * - * @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (Option[ValueType], State[StateType]) => EmittedType - ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - ClosureCleaner.clean(trackingFunction, checkSerializable = true) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) val wrappedFunction = - (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { - Some(trackingFunction(value, state)) + (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => { + Some(mappingFunction(key, value, state)) } new StateSpecImpl(wrappedFunction) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all - * the specifications of the `trackStateByKey` operation on a + * the specifications of the `mapWithState` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. 
* - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data * @tparam KeyType Class of the keys * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: - JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]): - StateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { - val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + def function[KeyType, ValueType, StateType, MappedType](mappingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) Option(t.orNull) } - StateSpec.function(trackingFunc) + StateSpec.function(wrappedFunc) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * of the `trackStateByKey` operation on a + * of the `mapWithState` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. * - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]): - StateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { - javaTrackingFunction.call(Optional.fromNullable(v.get), s) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { + mappingFunction.call(k, Optional.fromNullable(v.get), s) } - StateSpec.function(trackingFunc) + StateSpec.function(wrappedFunc) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala similarity index 66% rename from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala rename to streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala index f459930d0660b782bde4062a0a69f44b82f665ac..16c0d6fff8229699d99853ac9b7cc07cb15c1f7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
+import org.apache.spark.streaming.dstream.MapWithStateDStream
 /**
  * :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
  * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
  *
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
  */
 @Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
-    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
-  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+    dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+  extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
   def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
     new JavaPairDStream(dstream.stateSnapshots())(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b383e4580de0d91e7808c581230ec788af3..42ddd63f0f06c30963b512415c35298b180632f3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
   /**
    * :: Experimental ::
-   * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
-   * with a continuously updated per-key state. The user-provided state tracking function is
-   * applied on each keyed data item along with its corresponding state. The function can choose to
-   * update/remove the state and return a transformed data, which forms the
-   * [[JavaTrackStateDStream]].
+   * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+   * `this` stream, while maintaining some state data for each unique key. The mapping function
+   * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+   * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+   * as a parameter of type [[State]] in the mapping function.
   *
-   * The specifications of this transformation is made through the
-   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
-   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
-   * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
-   *
-   * Example of using `trackStateByKey`:
+   * Example of using `mapWithState`:
   * {{{
-   *    // A tracking function that maintains an integer state and return a String
-   *    Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
-   *        new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
-   *
-   *            @Override
-   *            public Optional<String> call(Optional<Integer> one, State<Integer> state) {
-   *                // Check if state exists, accordingly update/remove state and return transformed data
-   *            }
+   *    // A mapping function that maintains an integer state and returns a string
+   *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+   *        new Function3<String, Optional<Integer>, State<Integer>, String>() {
+   *            @Override
+   *            public String call(String key, Optional<Integer> value, State<Integer> state) {
+   *                // Use state.exists(), state.get(), state.update() and state.remove()
+   *                // to manage state, and return the necessary string
+   *            }
   *        };
   *
-   *    JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
-   *        keyValueDStream.<Integer, String>trackStateByKey(
-   *            StateSpec.function(trackStateFunc).numPartitions(10));
-   * }}}
+   *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+   *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+   * }}}
   *
   * @param spec Specification of this transformation
-   * @tparam StateType Class type of the state
-   * @tparam EmittedType Class type of the tranformed data return by the tracking function
+   * @tparam StateType Class type of the state data
+   * @tparam MappedType Class type of the mapped data
   */
  @Experimental
-  def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
-    JavaTrackStateDStream[K, V, StateType, EmittedType] = {
-    new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+  def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+    JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+    new JavaMapWithStateDStream(dstream.mapWithState(spec)(
      JavaSparkContext.fakeClassTag, JavaSparkContext.fakeClassTag))
  }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
similarity index 72%
rename from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
index ea6213420e7abfa0ff7f1a48fe25b17d2dd91868..706465d4e25d76307e71aa1790a90e8c073d61e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
 /**
  * :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
 *
[[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * Additionally, it also gives access to the stream of state snapshots, that is, the state data of * all keys after a batch has updated them. * - * @tparam KeyType Class of the state key - * @tparam ValueType Class of the state value + * @tparam KeyType Class of the key + * @tparam ValueType Class of the value * @tparam StateType Class of the state data - * @tparam EmittedType Class of the emitted records + * @tparam MappedType Class of the mapped data */ @Experimental -sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag]( - ssc: StreamingContext) extends DStream[EmittedType](ssc) { +sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag]( + ssc: StreamingContext) extends DStream[MappedType](ssc) { /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ def stateSnapshots(): DStream[(KeyType, StateType)] } -/** Internal implementation of the [[TrackStateDStream]] */ -private[streaming] class TrackStateDStreamImpl[ - KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( +/** Internal implementation of the [[MapWithStateDStream]] */ +private[streaming] class MapWithStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag]( dataStream: DStream[(KeyType, ValueType)], - spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) - extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) { + spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType]) + extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) { private val internalStream = - new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) + new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec) override def slideDuration: Duration = internalStream.slideDuration override def dependencies: List[DStream[_]] = List(internalStream) - override def compute(validTime: Time): Option[RDD[EmittedType]] = { - internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } } + override def compute(validTime: Time): Option[RDD[MappedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } } } /** * Forward the checkpoint interval to the internal DStream that computes the state maps. This * to make sure that this DStream does not get checkpointed, only the internal stream. */ - override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = { + override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = { internalStream.checkpoint(checkpointInterval) this } @@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[ def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass - def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass + def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass } /** * A DStream that allows per-key state to be maintains, and arbitrary records to be generated - * based on updates to the state. This is the main DStream that implements the `trackStateByKey` + * based on updates to the state. This is the main DStream that implements the `mapWithState` * operation on DStreams. 
* * @param parent Parent (key, value) stream that is the source - * @param spec Specifications of the trackStateByKey operation + * @param spec Specifications of the mapWithState operation * @tparam K Key type * @tparam V Value type * @tparam S Type of the state maintained - * @tparam E Type of the emitted data + * @tparam E Type of the mapped data */ private[streaming] -class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( +class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) - extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) { + extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) { persist(StorageLevel.MEMORY_ONLY) private val partitioner = spec.getPartitioner().getOrElse( new HashPartitioner(ssc.sc.defaultParallelism)) - private val trackingFunction = spec.getFunction() + private val mappingFunction = spec.getFunction() override def slideDuration: Duration = parent.slideDuration @@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { + override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD val prevStateRDD = getOrCompute(validTime - slideDuration) match { case Some(rdd) => @@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT // If the RDD is not partitioned the right way, let us repartition it using the // partition index as the key. This is to ensure that state RDD is always partitioned // before creating another state RDD using it - TrackStateRDD.createFromRDD[K, V, S, E]( + MapWithStateRDD.createFromRDD[K, V, S, E]( rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) } else { rdd } case None => - TrackStateRDD.createFromPairRDD[K, V, S, E]( + MapWithStateRDD.createFromPairRDD[K, V, S, E]( spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), partitioner, validTime @@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => (validTime - interval).milliseconds } - Some(new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) + Some(new MapWithStateRDD( + prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime)) } } -private[streaming] object InternalTrackStateDStream { +private[streaming] object InternalMapWithStateDStream { private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 2762309134eb191d732f1c1bdc8544f5da702623..a64a1fe93f40dcb2b292e3b2227f1bbaffbd506b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * :: Experimental :: - * Return a new DStream of data generated by combining the key-value data in `this` stream - * with a continuously updated per-key state. 
The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+ * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
 *
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
 * {{{
- *    def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
- *      // Check if state exists, accordingly update/remove state and return transformed data
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
 *    }
 *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
 *
- *    val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
 * }}}
 *
 * @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
 */
 @Experimental
-  def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
-      spec: StateSpec[K, V, StateType, EmittedType]
-    ): TrackStateDStream[K, V, StateType, EmittedType] = {
-    new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+  def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+      spec: StateSpec[K, V, StateType, MappedType]
+    ): MapWithStateDStream[K, V, StateType, MappedType] = {
+    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
      self,
-      spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
    )
  }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
similarity index 64%
rename from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 30aafcf1460e3acd305c0eddade52e199f29fcbc..ed95171f73ee117be341d7df6b0978e74c67b783 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -29,60 +29,60 @@ import
org.apache.spark._ /** - * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a - * sequence of records returned by the tracking function of `trackStateByKey`. + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the mapping function of `mapWithState`. */ -private[streaming] case class TrackStateRDDRecord[K, S, E]( - var stateMap: StateMap[K, S], var emittedRecords: Seq[E]) +private[streaming] case class MapWithStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var mappedData: Seq[E]) -private[streaming] object TrackStateRDDRecord { +private[streaming] object MapWithStateRDDRecord { def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( - prevRecord: Option[TrackStateRDDRecord[K, S, E]], + prevRecord: Option[MapWithStateRDDRecord[K, S, E]], dataIterator: Iterator[(K, V)], - updateFunction: (Time, K, Option[V], State[S]) => Option[E], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long], removeTimedoutData: Boolean - ): TrackStateRDDRecord[K, S, E] = { + ): MapWithStateRDDRecord[K, S, E] = { // Create a new state map by cloning the previous one (if it exists) or by creating an empty one val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } - val emittedRecords = new ArrayBuffer[E] + val mappedData = new ArrayBuffer[E] val wrappedState = new StateImpl[S]() - // Call the tracking function on each record in the data iterator, and accordingly - // update the states touched, and collect the data returned by the tracking function + // Call the mapping function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the mapping function dataIterator.foreach { case (key, value) => wrappedState.wrap(newStateMap.get(key)) - val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState) + val returned = mappingFunction(batchTime, key, Some(value), wrappedState) if (wrappedState.isRemoved) { newStateMap.remove(key) } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) } - emittedRecords ++= emittedRecord + mappedData ++= returned } - // Get the timed out state records, call the tracking function on each and collect the + // Get the timed out state records, call the mapping function on each and collect the // data returned if (removeTimedoutData && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => wrappedState.wrapTiminoutState(state) - val emittedRecord = updateFunction(batchTime, key, None, wrappedState) - emittedRecords ++= emittedRecord + val returned = mappingFunction(batchTime, key, None, wrappedState) + mappedData ++= returned newStateMap.remove(key) } } - TrackStateRDDRecord(newStateMap, emittedRecords) + MapWithStateRDDRecord(newStateMap, mappedData) } } /** - * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state + * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state * RDD, and a partitioned keyed-data RDD */ -private[streaming] class TrackStateRDDPartition( +private[streaming] class MapWithStateRDDPartition( idx: Int, @transient private var prevStateRDD: RDD[_], @transient private var partitionedDataRDD: RDD[_]) extends Partition { @@ -104,27 
+104,28 @@ private[streaming] class TrackStateRDDPartition( /** - * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records. - * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking - * function of `trackStateByKey`. - * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created + * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. + * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * function of `mapWithState`. + * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD + * will be created * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps * in the `prevStateRDD` to create `this` RDD - * @param trackingFunction The function that will be used to update state and return new data + * @param mappingFunction The function that will be used to update state and return new data * @param batchTime The time of the batch to which this RDD belongs to. Use to update * @param timeoutThresholdTime The time to indicate which keys are timeout */ -private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( - private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]], +private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]], private var partitionedDataRDD: RDD[(K, V)], - trackingFunction: (Time, K, Option[V], State[S]) => Option[E], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long] - ) extends RDD[TrackStateRDDRecord[K, S, E]]( + ) extends RDD[MapWithStateRDDRecord[K, S, E]]( partitionedDataRDD.sparkContext, List( - new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD), + new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD), new OneToOneDependency(partitionedDataRDD)) ) { @@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: } override def compute( - partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = { + partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = { - val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition] val prevStateRDDIterator = prevStateRDD.iterator( stateRDDPartition.previousSessionRDDPartition, context) val dataIterator = partitionedDataRDD.iterator( stateRDDPartition.partitionedDataRDDPartition, context) val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None - val newRecord = TrackStateRDDRecord.updateRecordWithData( + val newRecord = MapWithStateRDDRecord.updateRecordWithData( prevRecord, dataIterator, - trackingFunction, + mappingFunction, batchTime, timeoutThresholdTime, removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled @@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: override protected def getPartitions: Array[Partition] = { 
Array.tabulate(prevStateRDD.partitions.length) { i => - new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} } override def clearDependencies(): Unit = { @@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: } } -private[streaming] object TrackStateRDD { +private[streaming] object MapWithStateRDD { def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, E] = { + updateTime: Time): MapWithStateRDD[K, V, S, E] = { - val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) } def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( rdd: RDD[(K, S, Long)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, E] = { + updateTime: Time): MapWithStateRDD[K, V, S, E] = { val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } - val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, (state, updateTime)) => stateMap.put(key, state, updateTime) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) - } -} - -private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) { - override protected def getPartitions: Array[Partition] = parent.partitions - override def compute(partition: Partition, context: TaskContext): Iterator[T] = { - parent.compute(partition, context).flatMap { _.emittedRecords } + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java similarity index 80% rename from streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java rename to streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index 89d0bb7b617e46b31f2f1a012dea5ea391cb2845..bc4bc2eb42231d5bcd5eb4d28b656f1189829c3b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ 
b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -37,12 +37,12 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.Function3; import org.apache.spark.api.java.function.Function4; import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaTrackStateDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; -public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable { +public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable { /** * This test is only for testing the APIs. It's not necessary to run it. @@ -52,7 +52,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen JavaPairDStream<String, Integer> wordsDstream = null; final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>> - trackStateFunc = + mappingFunc = new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() { @Override @@ -68,21 +68,21 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen } }; - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream = - wordsDstream.trackStateByKey( - StateSpec.function(trackStateFunc) + JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream = + wordsDstream.mapWithState( + StateSpec.function(mappingFunc) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots(); + JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots(); - final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 = - new Function2<Optional<Integer>, State<Boolean>, Double>() { + final Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 = + new Function3<String, Optional<Integer>, State<Boolean>, Double>() { @Override - public Double call(Optional<Integer> one, State<Boolean> state) { + public Double call(String key, Optional<Integer> one, State<Boolean> state) { // Use all State's methods here state.exists(); state.get(); @@ -93,15 +93,15 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen } }; - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 = - wordsDstream.trackStateByKey( - StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc2) + JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 = + wordsDstream.mapWithState( + StateSpec.<String, Integer, Boolean, Double>function(mappingFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots(); + JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots(); } @Test @@ -148,11 +148,11 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen new Tuple2<String, Integer>("c", 1)) ); - Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc = - new Function2<Optional<Integer>, State<Integer>, Integer>() { + Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc = + new Function3<String, Optional<Integer>, 
State<Integer>, Integer>() { @Override - public Integer call(Optional<Integer> value, State<Integer> state) throws Exception { + public Integer call(String key, Optional<Integer> value, State<Integer> state) throws Exception { int sum = value.or(0) + (state.exists() ? state.get() : 0); state.update(sum); return sum; @@ -160,29 +160,29 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen }; testOperation( inputData, - StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc), + StateSpec.<String, Integer, Integer, Integer>function(mappingFunc), outputData, stateData); } private <K, S, T> void testOperation( List<List<K>> input, - StateSpec<K, Integer, S, T> trackStateSpec, + StateSpec<K, Integer, S, T> mapWithStateSpec, List<Set<T>> expectedOutputs, List<Set<Tuple2<K, S>>> expectedStateSnapshots) { int numBatches = expectedOutputs.size(); JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); - JavaTrackStateDStream<K, Integer, S, T> trackeStateStream = + JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream = JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() { @Override public Tuple2<K, Integer> call(K x) throws Exception { return new Tuple2<K, Integer>(x, 1); } - })).trackStateByKey(trackStateSpec); + })).mapWithState(mapWithStateSpec); final List<Set<T>> collectedOutputs = Collections.synchronizedList(Lists.<Set<T>>newArrayList()); - trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() { + mapWithStateDStream.foreachRDD(new Function<JavaRDD<T>, Void>() { @Override public Void call(JavaRDD<T> rdd) throws Exception { collectedOutputs.add(Sets.newHashSet(rdd.collect())); @@ -191,7 +191,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen }); final List<Set<Tuple2<K, S>>> collectedStateSnapshots = Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList()); - trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() { + mapWithStateDStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() { @Override public Void call(JavaPairRDD<K, S> rdd) throws Exception { collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala similarity index 77% rename from streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 1fc320d31b18b186a3e4d2e21aa1a4732fe60d82..4b08085e09b1f05b487b1fbd8a56b789150f25ca 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -25,11 +25,11 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite +class MapWithStateSuite extends SparkFunSuite with DStreamCheckpointTester with 
BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null @@ -49,7 +49,7 @@ class TrackStateByKeySuite extends SparkFunSuite } override def beforeAll(): Unit = { - val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite") + val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) sc = new SparkContext(conf) } @@ -129,7 +129,7 @@ class TrackStateByKeySuite extends SparkFunSuite testState(Some(3), shouldBeTimingOut = true) } - test("trackStateByKey - basic operations with simple API") { + test("mapWithState - basic operations with simple API") { val inputData = Seq( Seq(), @@ -164,17 +164,17 @@ class TrackStateByKeySuite extends SparkFunSuite ) // state maintains running count, and updated count is returned - val trackStateFunc = (value: Option[Int], state: State[Int]) => { + val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => { val sum = value.getOrElse(0) + state.getOption.getOrElse(0) state.update(sum) sum } testOperation[String, Int, Int]( - inputData, StateSpec.function(trackStateFunc), outputData, stateData) + inputData, StateSpec.function(mappingFunc), outputData, stateData) } - test("trackStateByKey - basic operations with advanced API") { + test("mapWithState - basic operations with advanced API") { val inputData = Seq( Seq(), @@ -209,65 +209,65 @@ class TrackStateByKeySuite extends SparkFunSuite ) // state maintains running count, key string doubled and returned - val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { val sum = value.getOrElse(0) + state.getOption.getOrElse(0) state.update(sum) Some(key * 2) } - testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) } - test("trackStateByKey - type inferencing and class tags") { + test("mapWithState - type inferencing and class tags") { - // Simple track state function with value as Int, state as Double and emitted type as Double - val simpleFunc = (value: Option[Int], state: State[Double]) => { + // Simple track state function with value as Int, state as Double and mapped type as Double + val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => { 0L } // Advanced track state function with key as String, value as Int, state as Double and - // emitted type as Double + // mapped type as Double val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { Some(0L) } - def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = { - val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]] + def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]] assert(dstreamImpl.keyClass === classOf[String]) assert(dstreamImpl.valueClass === classOf[Int]) assert(dstreamImpl.stateClass === classOf[Double]) - assert(dstreamImpl.emittedClass === classOf[Long]) + assert(dstreamImpl.mappedClass === classOf[Long]) } val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) - // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types - val simpleFunctionStateStream1 = inputStream.trackStateByKey( + 
// Defining StateSpec inline with mapWithState and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.mapWithState( StateSpec.function(simpleFunc).numPartitions(1)) testTypes(simpleFunctionStateStream1) // Separately defining StateSpec with simple function requires explicitly specifying types val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) - val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec) + val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec) testTypes(simpleFunctionStateStream2) // Separately defining StateSpec with advanced function implicitly gets the types val advFuncSpec1 = StateSpec.function(advancedFunc) - val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1) + val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1) testTypes(advFunctionStateStream1) - // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types - val advFunctionStateStream2 = inputStream.trackStateByKey( + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.mapWithState( StateSpec.function(simpleFunc).numPartitions(1)) testTypes(advFunctionStateStream2) - // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) - val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2) + val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2) testTypes(advFunctionStateStream3) } - test("trackStateByKey - states as emitted records") { + test("mapWithState - states as mapped data") { val inputData = Seq( Seq(), @@ -301,17 +301,17 @@ class TrackStateByKeySuite extends SparkFunSuite Seq(("a", 5), ("b", 3), ("c", 1)) ) - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { val sum = value.getOrElse(0) + state.getOption.getOrElse(0) val output = (key, sum) state.update(sum) Some(output) } - testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) } - test("trackStateByKey - initial states, with nothing emitted") { + test("mapWithState - initial states, with nothing returned from the mapping function") { val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) @@ -339,18 +339,18 @@ class TrackStateByKeySuite extends SparkFunSuite Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) ) - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { val sum = value.getOrElse(0) + state.getOption.getOrElse(0) val output = (key, sum) state.update(sum) None.asInstanceOf[Option[Int]] } - val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState)) - testOperation(inputData, trackStateSpec, outputData, stateData) + val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, mapWithStateSpec, outputData, stateData) } - test("trackStateByKey - state removing") { + 
test("mapWithState - state removing") { val inputData = Seq( Seq(), @@ -388,7 +388,7 @@ class TrackStateByKeySuite extends SparkFunSuite Seq() ) - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { if (state.exists) { state.remove() Some(key) @@ -399,10 +399,10 @@ class TrackStateByKeySuite extends SparkFunSuite } testOperation( - inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData) + inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData) } - test("trackStateByKey - state timing out") { + test("mapWithState - state timing out") { val inputData = Seq( Seq("a", "b", "c"), @@ -413,7 +413,7 @@ class TrackStateByKeySuite extends SparkFunSuite Seq("a") // a will not time out ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { if (value.isDefined) { state.update(1) } @@ -425,9 +425,9 @@ class TrackStateByKeySuite extends SparkFunSuite } val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( - inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20) + inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20) - // b and c should be emitted once each, when they were marked as expired + // b and c should be returned once each, when they were marked as expired assert(collectedOutputs.flatten.sorted === Seq("b", "c")) // States for a, b, c should be defined at one point of time @@ -439,8 +439,8 @@ class TrackStateByKeySuite extends SparkFunSuite assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } - test("trackStateByKey - checkpoint durations") { - val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + test("mapWithState - checkpoint durations") { + val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream) def testCheckpointDuration( batchDuration: Duration, @@ -451,18 +451,18 @@ class TrackStateByKeySuite extends SparkFunSuite try { val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) - val dummyFunc = (value: Option[Int], state: State[Int]) => 0 - val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) - val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0 + val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc)) + val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod() explicitCheckpointDuration.foreach { d => - trackStateStream.checkpoint(d) + mapWithStateStream.checkpoint(d) } - trackStateStream.register() + mapWithStateStream.register() ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations - assert(trackStateStream.checkpointDuration === null) - assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + assert(mapWithStateStream.checkpointDuration === null) + assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration) } finally { ssc.stop(stopSparkContext = false) } @@ -478,7 +478,7 @@ class TrackStateByKeySuite extends SparkFunSuite } - test("trackStateByKey - driver failure recovery") { + 
test("mapWithState - driver failure recovery") { val inputData = Seq( Seq(), @@ -505,16 +505,16 @@ class TrackStateByKeySuite extends SparkFunSuite val checkpointDuration = batchDuration * (stateData.size / 2) - val runningCount = (value: Option[Int], state: State[Int]) => { + val runningCount = (key: String, value: Option[Int], state: State[Int]) => { state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) state.get() } - val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState( StateSpec.function(runningCount)) // Set internval make sure there is one RDD checkpointing - trackStateStream.checkpoint(checkpointDuration) - trackStateStream.stateSnapshots() + mapWithStateStream.checkpoint(checkpointDuration) + mapWithStateStream.stateSnapshots() } testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, @@ -523,28 +523,28 @@ class TrackStateByKeySuite extends SparkFunSuite private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], - trackStateSpec: StateSpec[K, Int, S, T], + mapWithStateSpec: StateSpec[K, Int, S, T], expectedOutputs: Seq[Seq[T]], expectedStateSnapshots: Seq[Seq[(K, S)]] ): Unit = { require(expectedOutputs.size == expectedStateSnapshots.size) val (collectedOutputs, collectedStateSnapshots) = - getOperationOutput(input, trackStateSpec, expectedOutputs.size) + getOperationOutput(input, mapWithStateSpec, expectedOutputs.size) assert(expectedOutputs, collectedOutputs, "outputs") assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") } private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], - trackStateSpec: StateSpec[K, Int, S, T], + mapWithStateSpec: StateSpec[K, Int, S, T], numBatches: Int ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) - val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) + val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala similarity index 76% rename from streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index 3b2d43f2ce5816c890aa4bae9c56d521869f9ac9..aa95bd33dda9f381d3ce993ce94f1bfc9e7eb2cc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -30,14 +30,14 @@ import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.streaming.{State, Time} import org.apache.spark.util.Utils -class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { +class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { private var sc: SparkContext = null private var checkpointDir: File = _ override def beforeAll(): Unit = { 
sc = new SparkContext( - new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) checkpointDir = Utils.createTempDir() sc.setCheckpointDir(checkpointDir.toString) } @@ -54,7 +54,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef test("creation from pair RDD") { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) - val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int]( + val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int]( sc.parallelize(data), partitioner, Time(123)) assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) assert(rdd.partitions.size === partitioner.numPartitions) @@ -62,7 +62,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef assert(rdd.partitioner === Some(partitioner)) } - test("updating state and generating emitted data in TrackStateRecord") { + test("updating state and generating mapped data in MapWithStateRDDRecord") { val initialTime = 1000L val updatedTime = 2000L @@ -71,7 +71,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef /** * Assert that applying given data on a prior record generates correct updated record, with - * correct state map and emitted data + * correct state map and mapped data */ def assertRecordUpdate( initStates: Iterable[Int], @@ -86,18 +86,18 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() initStates.foreach { s => initialStateMap.put("key", s, initialTime) } functionCalled = false - val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) val dataIterator = data.map { v => ("key", v) }.iterator val removedStates = new ArrayBuffer[Int] val timingOutStates = new ArrayBuffer[Int] /** - * Tracking function that updates/removes state based on instructions in the data, and + * Mapping function that updates/removes state based on instructions in the data, and * return state (when instructed or when state is timing out). 
*/ def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { functionCalled = true - assert(t.milliseconds === updatedTime, "tracking func called with wrong time") + assert(t.milliseconds === updatedTime, "mapping func called with wrong time") data match { case Some("noop") => @@ -120,22 +120,22 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef } } - val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int]( Some(record), dataIterator, testFunc, Time(updatedTime), timeoutThreshold, removeTimedoutData) val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } assert(updatedStateData.toSet === expectedStates.toSet, - "states do not match after updating the TrackStateRecord") + "states do not match after updating the MapWithStateRDDRecord") - assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet, - "emitted data do not match after updating the TrackStateRecord") + assert(updatedRecord.mappedData.toSet === expectedOutput.toSet, + "mapped data do not match after updating the MapWithStateRDDRecord") assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + - "match those that were expected to do so while updating the TrackStateRecord") + "match those that were expected to do so while updating the MapWithStateRDDRecord") assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + - "match those that were expected to do so while updating the TrackStateRecord") + "match those that were expected to do so while updating the MapWithStateRDDRecord") } @@ -187,12 +187,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef } - test("states generated by TrackStateRDD") { + test("states generated by MapWithStateRDD") { val initStates = Seq(("k1", 0), ("k2", 0)) val initTime = 123 val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet val partitioner = new HashPartitioner(2) - val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( + val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int]( sc.parallelize(initStates), partitioner, Time(initTime)).persist() assertRDD(initStateRDD, initStateWthTime, Set.empty) @@ -203,21 +203,21 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef * creates a new state RDD with expected states */ def testStateUpdates( - testStateRDD: TrackStateRDD[String, Int, Int, Int], + testStateRDD: MapWithStateRDD[String, Int, Int, Int], testData: Seq[(String, Int)], - expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = { + expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = { - // Persist the test TrackStateRDD so that its not recomputed while doing the next operation. - // This is to make sure that we only track which state keys are being touched in the next op. + // Persist the test MapWithStateRDD so that it's not recomputed while doing the next operation. + // This is to make sure that we only track which state keys are being touched in the next op. 
testStateRDD.persist().count() // To track which keys are being touched - TrackStateRDDSuite.touchedStateKeys.clear() + MapWithStateRDDSuite.touchedStateKeys.clear() - val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => { // Track the key that has been touched - TrackStateRDDSuite.touchedStateKeys += key + MapWithStateRDDSuite.touchedStateKeys += key // If the data is 0, do not do anything with the state // else if the data is 1, increment the state if it exists, or set new state to 0 @@ -236,12 +236,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef // Assert that the new state RDD has expected state data val newStateRDD = assertOperation( - testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty) + testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty) // Assert that the function was called only for the keys present in the data - assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size, + assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size, "More number of keys are being touched than that is expected") - assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, "Keys not in the data are being touched unexpectedly") // Assert that the test RDD's data has not changed @@ -289,19 +289,19 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef test("checkpointing") { /** - * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs - - * the data RDD and the parent TrackStateRDD. + * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs + * - the data RDD and the parent MapWithStateRDD. 
*/ - def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]]) + def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]]) : Set[(List[(Int, Int, Long)], List[Int])] = { - rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) } + rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) } .collect.toSet } - /** Generate TrackStateRDD with data RDD having a long lineage */ + /** Generate MapWithStateRDD with data RDD having a long lineage */ def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) - : TrackStateRDD[Int, Int, Int, Int] = { - TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + : MapWithStateRDD[Int, Int, Int, Int] = { + MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) } testRDD( @@ -309,15 +309,15 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef testRDDPartitions( makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) - /** Generate TrackStateRDD with parent state RDD having a long lineage */ + /** Generate MapWithStateRDD with parent state RDD having a long lineage */ def makeStateRDDWithLongLineageParenttateRDD( - longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = { + longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = { - // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage + // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) - // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent - new TrackStateRDD[Int, Int, Int, Int]( + // Create a new MapWithStateRDD, with the long lineage MapWithStateRDD as the parent + new MapWithStateRDD[Int, Int, Int, Int]( stateRDDWithLongLineage, stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, @@ -333,25 +333,25 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef } test("checkpointing empty state RDD") { - val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int]( sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) emptyStateRDD.checkpoint() assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) - val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]]( emptyStateRDD.getCheckpointFile.get) assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) } - /** Assert whether the `trackStateByKey` operation generates expected results */ + /** Assert whether the `mapWithState` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - testStateRDD: TrackStateRDD[K, V, S, T], + testStateRDD: MapWithStateRDD[K, V, S, T], newDataRDD: RDD[(K, V)], - trackStateFunc: (Time, K, Option[V], State[S]) => Option[T], + mappingFunction: (Time, K, Option[V], State[S]) => Option[T], currentTime: Long, expectedStates: Set[(K, S, Int)], - expectedEmittedRecords: Set[T], + expectedMappedData: Set[T], doFullScan: Boolean = false - ): TrackStateRDD[K, V, S, T] = { + ): MapWithStateRDD[K, V, S, T] = { val partitionedNewDataRDD = if (newDataRDD.partitioner != 
testStateRDD.partitioner) { newDataRDD.partitionBy(testStateRDD.partitioner.get) @@ -359,31 +359,31 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef newDataRDD } - val newStateRDD = new TrackStateRDD[K, V, S, T]( - testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None) + val newStateRDD = new MapWithStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None) if (doFullScan) newStateRDD.setFullScan() // Persist to make sure that it gets computed only once and we can track precisely how many // state keys the computing touched newStateRDD.persist().count() - assertRDD(newStateRDD, expectedStates, expectedEmittedRecords) + assertRDD(newStateRDD, expectedStates, expectedMappedData) newStateRDD } - /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */ + /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */ private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - trackStateRDD: TrackStateRDD[K, V, S, T], + stateRDD: MapWithStateRDD[K, V, S, T], expectedStates: Set[(K, S, Int)], - expectedEmittedRecords: Set[T]): Unit = { - val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet - val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet + expectedMappedData: Set[T]): Unit = { + val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet assert(states === expectedStates, - "states after track state operation were not as expected") - assert(emittedRecords === expectedEmittedRecords, - "emitted records after track state operation were not as expected") + "states after mapWithState operation were not as expected") + assert(mappedData === expectedMappedData, + "mapped data after mapWithState operation were not as expected") } } -object TrackStateRDDSuite { +object MapWithStateRDDSuite { private val touchedStateKeys = new ArrayBuffer[String]() }