diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f7a46876c0366ea8dd5d8c85a4841..be2ae0b47336372be57b77fa60141233ec1d7d15 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,18 +44,6 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) @@ -71,9 +59,16 @@ object StatefulNetworkWordCount { val wordDstream = words.map(x => (x, 1)) // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + // This will give a DStream made of state (which is the cumulative count of the words) + val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOption.getOrElse(0) + val output = (word, sum) + state.update(sum) + Some(output) + } + + val stateDstream = wordDstream.trackStateByKey( + StateSpec.function(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 0000000000000000000000000000000000000000..7dd1b72f804997db4aba35d6dbefd983c307c7cf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of + * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Scala example of using `State`: + * {{{ + * // A tracking function that maintains an integer state and return a String + * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example: + * {{{ + * TODO(@zsxwing) + * }}} + */ +@Experimental +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. + * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. + */ + def get(): S + + /** + * Update the state with a new value. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed + */ + def update(newState: S): Unit + + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + */ + def remove(): Unit + + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeout can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data + * for that timeout duration. + */ + def isTimingOut(): Boolean + + /** + * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. 
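The State contract above is easiest to see through a concrete tracking function. A minimal sketch (the word-count names and types are assumptions for illustration, not part of this patch), written against the four-argument signature that `StateSpec.function` accepts later in this patch:

{{{
// Hypothetical per-word running count.
def runningCount(
    batchTime: Time,
    word: String,
    one: Option[Int],
    state: State[Int]): Option[(String, Int)] = {
  if (state.isTimingOut()) {
    // The system is about to drop this key; update()/remove() would throw here.
    Some((word, state.get()))
  } else {
    val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
    state.update(sum)    // persist the new running count
    Some((word, sum))    // emit the updated count downstream
  }
}
}}}

Note that the timing-out branch only reads the state: `StateImpl` below rejects both `update()` and `remove()` once `isTimingOut()` is true.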
+ */ + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("<state not set>") + } +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = false + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + override def exists(): Boolean = { + defined + } + + override def get(): S = { + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } + } + + override def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + defined = true + updated = true + } + + override def isTimingOut(): Boolean = { + timingOut + } + + override def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Update the internal data and flags in `this` to the given state option. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Update the internal data and flags in `this` to the given state that is going to be timed out. + * This method allows `this` object to be reused across many state records. + */ + def wrapTiminoutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala new file mode 100644 index 0000000000000000000000000000000000000000..c9fe35e74c1c7b01976933ffacc0fffa6053ac0f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.util.ClosureCleaner +import org.apache.spark.{HashPartitioner, Partitioner} + + +/** + * :: Experimental :: + * Abstract class representing all the specifications of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or + * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of + * this class. + * + * Example in Scala: + * {{{ + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... + * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * }}} + * + * Example in Java: + * {{{ + * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); + * + * JavaDStream[EmittedDataType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { + + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(rdd: RDD[(KeyType, StateType)]): this.type + + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type + + /** + * Set the number of partitions by which the state RDDs generated by `trackStateByKey` + * will be partitioned. Hash partitioning will be used. + */ + def numPartitions(numPartitions: Int): this.type + + /** + * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be + * be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type + + /** + * Set the duration after which the state of an idle key will be removed. A key and its state is + * considered idle if it has not received any data for at least the given duration. The state + * tracking function will be called one final time on the idle states that are going to be + * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set + * to `true` in that call. + */ + def timeout(idleDuration: Duration): this.type +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] + * that is used for specifying the parameters of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Example in Scala: + * {{{ + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... 
+ * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * }}} + * + * Example in Java: + * {{{ + * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); + * + * JavaDStream[EmittedDataType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +object StateSpec { + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + new StateSpecImpl(trackingFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (Option[ValueType], State[StateType]) => EmittedType + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + val wrappedFunction = + (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { + Some(trackingFunction(value, state)) + } + new StateSpecImpl(wrappedFunction) + } +} + + +/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. 
*/ +private[streaming] +case class StateSpecImpl[K, V, S, T]( + function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + override def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + + override def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + override def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + override def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8db48305f68f6d9020c844f4178f7..fb691eed27e327b68c22555450e53eed51ead7b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.streaming._ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} +import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -350,6 +350,44 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + /** + * :: Experimental :: + * Return a new DStream of data generated by combining the key-value data in `this` stream + * with a continuously updated per-key state. The user-provided state tracking function is + * applied on each keyed data item along with its corresponding state. The function can choose to + * update/remove the state and return a transformed data, which forms the + * [[org.apache.spark.streaming.dstream.TrackStateDStream]]. + * + * The specifications of this transformation is made through the + * [[org.apache.spark.streaming.StateSpec StateSpec]] class. 
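A sketch of a fully configured spec, chaining the optional settings defined above (the identifiers `runningCount`, `initialRDD` and `wordStream` are assumed for illustration):

{{{
val spec = StateSpec.function(runningCount)
  .initialState(initialRDD)      // seed the state from an RDD[(String, Int)]
  .numPartitions(10)             // hash-partition the state RDDs
  .timeout(Minutes(30))          // drop keys that stay idle for 30 minutes

val tracked = wordStream.trackStateByKey(spec)
}}}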
Besides the tracking function, there + * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. + * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details. + * + * Example of using `trackStateByKey`: + * {{{ + * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = { + * // Check if state exists, accordingly update/remove state and return transformed data + * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * + * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec) + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state + * @tparam EmittedType Class type of the tranformed data return by the tracking function + */ + @Experimental + def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( + spec: StateSpec[K, V, StateType, EmittedType] + ): TrackStateDStream[K, V, StateType, EmittedType] = { + new TrackStateDStreamImpl[K, V, StateType, EmittedType]( + self, + spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] + ) + } + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..58d89c93bcbefecbd07c6af5d2134c4242cc5a26 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} + +/** + * :: Experimental :: + * DStream representing the stream of records emitted by the tracking function in the + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. 
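For example (a sketch; `runningCount` is the hypothetical tracking function from the earlier sketches):

{{{
val tracked = keyValueDStream.trackStateByKey(StateSpec.function(runningCount))

tracked.print()                    // records emitted by the tracking function each batch
tracked.stateSnapshots().print()   // snapshot of the state of every key after each batch
}}}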
+ * + * @tparam KeyType Class of the state key + * @tparam StateType Class of the state data + * @tparam EmittedType Class of the emitted records + */ +@Experimental +sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag]( + ssc: StreamingContext) extends DStream[EmittedType](ssc) { + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] +} + +/** Internal implementation of the [[TrackStateDStream]] */ +private[streaming] class TrackStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( + dataStream: DStream[(KeyType, ValueType)], + spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) + extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) { + + private val internalStream = + new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) + + override def slideDuration: Duration = internalStream.slideDuration + + override def dependencies: List[DStream[_]] = List(internalStream) + + override def compute(validTime: Time): Option[RDD[EmittedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } } + } + + /** + * Forward the checkpoint interval to the internal DStream that computes the state maps. This + * to make sure that this DStream does not get checkpointed, only the internal stream. + */ + override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = { + internalStream.checkpoint(checkpointInterval) + this + } + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] = { + internalStream.flatMap { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } + + def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass + + def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass + + def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass + + def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass +} + +/** + * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * based on updates to the state. This is the main DStream that implements the `trackStateByKey` + * operation on DStreams. 
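One consequence of the `checkpoint()` forwarding in `TrackStateDStreamImpl` above: a user-set interval applies to the internal stream of state maps, not to the returned stream of emitted records (a sketch, identifiers assumed):

{{{
tracked.checkpoint(Seconds(10))   // checkpoints the internal state stream every 10 seconds
}}}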
+ * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the trackStateByKey operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam E Type of the emitted data + */ +private[streaming] +class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) + extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + /** Enable automatic checkpointing */ + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { + // Get the previous state or create a new empty state RDD + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, validTime + ) + } + + // Compute the new state RDD with previous state RDD and partitioned data RDD + parent.getOrCompute(validTime).map { dataRDD => + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed7cea26d06086a650571aa83f55b44473564f6e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.streaming.{Time, StateImpl, State} +import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} +import org.apache.spark.util.Utils +import org.apache.spark._ + +/** + * Record storing the keyed-state [[TrackStateRDD]]. 
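To make the timeout threshold computed in `InternalTrackStateDStream.compute` above concrete (assumed numbers, for illustration only):

{{{
// With spec.timeout(Seconds(30)) and a batch at validTime = Time(100000):
val timeoutThresholdTime = Some((Time(100000) - Seconds(30)).milliseconds)   // Some(70000)
// Keys whose state was last updated before 70000 ms become timeout candidates
// when the resulting TrackStateRDD does a full scan (see TrackStateRDD.compute below).
}}}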
Each record contains a [[StateMap]] and a + * sequence of records returned by the tracking function of `trackStateByKey`. + */ +private[streaming] case class TrackStateRDDRecord[K, S, T]( + var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) + +/** + * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state + * RDD, and a partitioned keyed-data RDD + */ +private[streaming] class TrackStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[rdd] var previousSessionRDDPartition: Partition = null + private[rdd] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + + +/** + * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records. + * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking + * function of `trackStateByKey`. + * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created + * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps + * in the `prevStateRDD` to create `this` RDD + * @param trackingFunction The function that will be used to update state and return new data + * @param batchTime The time of the batch to which this RDD belongs to. 
Use to update + */ +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], + private var partitionedDataRDD: RDD[(K, V)], + trackingFunction: (Time, K, Option[V], State[S]) => Option[T], + batchTime: Time, timeoutThresholdTime: Option[Long] + ) extends RDD[TrackStateRDDRecord[K, S, T]]( + partitionedDataRDD.sparkContext, + List( + new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(prevStateRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { + + val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one + val newStateMap = if (prevStateRDDIterator.hasNext) { + prevStateRDDIterator.next().stateMap.copy() + } else { + new EmptyStateMap[K, S]() + } + + val emittedRecords = new ArrayBuffer[T] + val wrappedState = new StateImpl[S]() + + // Call the tracking function on each record in the data RDD partition, and accordingly + // update the states touched, and the data returned by the tracking function. + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated) { + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) + } + emittedRecords ++= emittedRecord + } + + // If the RDD is expected to be doing a full scan of all the data in the StateMap, + // then use this opportunity to filter out those keys that have timed out. + // For each of them call the tracking function. 
+ if (doFullScan && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTiminoutState(state) + val emittedRecord = trackingFunction(batchTime, key, None, wrappedState) + emittedRecords ++= emittedRecord + newStateMap.remove(key) + } + } + + Iterator(TrackStateRDDRecord(newStateMap, emittedRecords)) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } + + def setFullScan(): Unit = { + doFullScan = true + } +} + +private[streaming] object TrackStateRDD { + + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, T] = { + + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } +} + +private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) { + override protected def getPartitions: Array[Partition] = parent.partitions + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + parent.compute(partition, context).flatMap { _.emittedRecords } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed622ef7bf7007c45b3f02ecf76091b8835785ec --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the given threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. + */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + } +} + +/** Implementation of StateMap interface representing an empty map */ +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = { + throw new NotImplementedError("put() should not be called on an EmptyStateMap") + } + override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = this + override def remove(key: K): Unit = { } + override def toDebugString(): String = "" +} + +/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile var parentStateMap: StateMap[K, S], + initialCapacity: Int = 64, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) = this( + initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + + @transient @volatile private var deltaMap = + new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + if (!stateInfo.deleted) { + Some(stateInfo.data) + } else { + None + } + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = 
parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) => + !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime + }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. + */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + /** Whether the delta chain lenght is long enough that it should be compacted */ + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + /** Length of the delta chains of this map */ + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + /** + * Approximate number of keys in the map. This is an overestimation that is mainly used to + * reserve capacity in a new map at delta compaction time. + */ + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + /** Get all the data of this map as string formatted as a tree based on the delta depth */ + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) + "+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + override def toString(): String = { + s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" + } + + /** + * Serialize the map data. Besides serialization, this method actually compact the deltas + * (if needed) in a single pass over all the data in the map. + */ + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. 
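+    // (For reference, the byte layout produced below: the default serialized fields, then
+    // [delta count][(key, StateInfo)...], then [parent size hint][(key, state, updateTime)...],
+    // terminated by a LimitMarker carrying the number of parent records written. If the delta
+    // chain is too long, the parent entries are also folded into a new compacted parent map as
+    // they are written out.)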
+ outputStream.defaultWriteObject() + + // Write the data in the delta of this state map + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the data in the parent state map while copying the data into a new parent map for + // compaction (if needed) + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + // First write the approximate size of the data to be written, so that readObject can + // allocate appropriately sized OpenHashMap. + outputStream.writeInt(approxSize) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) + + if (doCompaction) { + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + + // Write the final limit marking object with the correct count of records written. + val limiterObj = new LimitMarker(parentSessionCount) + outputStream.writeObject(limiterObj) + if (doCompaction) { + parentStateMap = newParentSessionStore + } + } + + /** Deserialize the map data. */ + private def readObject(inputStream: ObjectInputStream): Unit = { + + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + + // Read the data of the delta + val deltaMapSize = inputStream.readInt() + deltaMap = new OpenHashMap[K, StateInfo[S]]() + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + + + // Read the data of the parent map. 
Keep reading records, until the limiter is reached + // First read the approximate number of records to expect and allocate properly size + // OpenHashMap + val parentSessionStoreSizeHint = inputStream.readInt() + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( + initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + + // Read the records until the limit marking object has been reached + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + if (obj.isInstanceOf[LimitMarker]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[LimitMarker].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + parentStateMap = newParentSessionStore + } +} + +/** + * Companion object of [[OpenHashMapBasedStateMap]] having associated helper + * classes and methods + */ +private[streaming] object OpenHashMapBasedStateMap { + + /** Internal class to represent the state information */ + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var deleted: Boolean = false) { + + def markDeleted(): Unit = { + deleted = true + } + + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + /** + * Internal class to represent a marker the demarkate the the end of all state data in the + * serialized bytes. + */ + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..48d3b41b66cbfec8c24854ca08d30d067530fa00 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.collection.{immutable, mutable, Map} +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} +import org.apache.spark.util.Utils + +class StateMapSuite extends SparkFunSuite { + + test("EmptyStateMap") { + val map = new EmptyStateMap[Int, Int] + intercept[scala.NotImplementedError] { + map.put(1, 1, 1) + } + assert(map.get(1) === None) + assert(map.getByTime(10000).isEmpty) + assert(map.getAll().isEmpty) + map.remove(1) // no exception + assert(map.copy().eq(map)) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10))) + + map.put(2, 200, 20) + assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20))) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + // Create child map and make changes + val map = parentMap.copy() + assert(map.get(1) === None) + assert(map.get(2) === Some(200)) + assert(map.getByTime(10).toSet === Set((2, 200, 2))) + assert(map.getByTime(2).toSet === Set.empty) + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add new items + map.put(3, 300, 3) + assert(map.get(3) === Some(300)) + map.put(4, 400, 4) + assert(map.get(4) === Some(400)) + assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3))) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) + assert(map.get(4) === None) // item added in this map, then removed in this map + map.remove(2) + assert(map.get(2) === None) // item removed in parent map, then added in this map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) + assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map + map.put(2, 2000, 200) + assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map + map.put(3, 3000, 300) + assert(map.get(3) === Some(3000)) // item added + updated in this map + map.put(4, 4000, 400) + assert(map.get(4) === Some(4000)) // item removed + updated in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + map.remove(2) // remove item present in parent map, so that its not visible in child map + + // Create child map and see availability of items + val childMap = map.copy() + assert(childMap.getAll().toSet === 
map.getAll().toSet) + assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map + assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map + assert(childMap.get(3) === Some(3000)) // item added and updated in parent map + + childMap.put(2, 20000, 200) + assert(childMap.get(2) === Some(20000)) // item map + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + map1.put(1, 100, 1) + map1.put(2, 200, 2) + + val map2 = map1.copy() + map2.put(3, 300, 3) + map2.put(4, 400, 4) + + val map3 = map2.copy() + map3.put(3, 600, 3) + map3.remove(2) + + // Do not test compaction + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + + val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( + Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + // Make large delta chain with length more than deltaChainThreshold + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), 1) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") + } + + test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { + /* + * This tests the map using all permutations of sequences operations, across multiple map + * copies as well as between copies. It is to ensure complete coverage, though it is + * kind of hard to debug this. It is set up as follows. + * + * - For any key, there can be 2 types of update ops on a state map - put or remove + * + * - These operations are done on a test map in "sets". After each set, the map is "copied" + * to create a new map, and the next set of operations are done on the new one. This tests + * whether the map data persistes correctly across copies. + * + * - Within each set, there are a number of operations to test whether the map correctly + * updates and removes data without affecting the parent state map. + * + * - Overall this creates (numSets * numOpsPerSet) operations, each of which that can 2 types + * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence + * of operations, which we will test with different keys. + * + * Example: With numSets = 2, and numOpsPerSet = 2 give numTotalOps = 4. This means that + * 2 ^ 4 = 16 possible permutations needs to be tested using 16 keys. 
+ * _______________________________________________ + * | | Set1 | Set2 | + * | |-----------------|-----------------| + * | | Op1 Op2 |c| Op3 Op4 | + * |---------|----------------|o|----------------| + * | key 0 | put put |p| put put | + * | key 1 | put put |y| put rem | + * | key 2 | put put | | rem put | + * | key 3 | put put |t| rem rem | + * | key 4 | put rem |h| put put | + * | key 5 | put rem |e| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |s| rem rem | + * | key 8 | rem put |t| put put | + * | key 9 | rem put |a| put rem | + * | key 10 | rem put |t| rem put | + * | key 11 | rem put |e| rem rem | + * | key 12 | rem rem | | put put | + * | key 13 | rem rem |m| put rem | + * | key 14 | rem rem |a| rem put | + * | key 15 | rem rem |p| rem rem | + * |_________|________________|_|________________| + */ + + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numSets = 3 + val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set + val numTotalOps = numOpsPerSet * numSets + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + val refMap = new mutable.HashMap[Int, (Int, Long)]() + var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + var prevSetStateMap: StateMap[Int, Int] = null + + var time = 1L + + for (setId <- 0 until numSets) { + for (opInSetId <- 0 until numOpsPerSet) { + val opId = setId * numOpsPerSet + opInSetId + for (keyId <- 0 until numKeys) { + time += 1 + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = + (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, time) + refMap.put(keyId, (value, time)) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + + // Test whether the current state map after all key updates is correct + assertMap(stateMap, refMap, time, "State map does not match reference map") + + // Test whether the previous map before copy has not changed + if (prevSetStateMap != null && prevSetRefMap != null) { + assertMap(prevSetStateMap, prevSetRefMap, time, + "Parent state map somehow got modified, does not match corresponding reference map") + } + } + + // Copy the map and remember the previous maps for future tests + prevSetStateMap = stateMap + prevSetRefMap = refMap.toMap + stateMap = stateMap.copy() + + // Assert that the copied map has the same data + assertMap(stateMap, prevSetRefMap, time, + "State map does not match reference map after copying") + } + assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") + } + + // Assert whether all the data and operations on a state map matches that of a reference state map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: StateMap[Int, Int], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.getAll().map { _._1 }) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId)) + } + + // Assert that every time threshold returns the correct data + for 
(t <- 0L to (time + 1)) { + assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet) + } + } + } + + // Assert whether all the data and operations on a state map matches that of a reference map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: Map[Int, (Int, Long)], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === + refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.keys) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 }) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + val expectedRecords = + refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) } + assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet) + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e3072b44428402f87e70d2f84dcde191446ec5f8 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.reflect.ClassTag + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + private var ssc: StreamingContext = null + private var checkpointDir: File = null + private val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } + checkpointDir = Utils.createTempDir("checkpoint") + + ssc = new StreamingContext(sc, batchDuration) + ssc.checkpoint(checkpointDir.toString) + } + + after { + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } + } + + override def beforeAll(): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + test("state - get, exists, update, remove, ") { + var state: StateImpl[Int] = null + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOption.getOrElse(-1) === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOption.getOrElse(-1) === -1) + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + state = new StateImpl[Int]() + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state.remove() + testState(None, shouldBeRemoved = true) + + state.wrapTiminoutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + test("trackStateByKey - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val trackStateFunc = (value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + 
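+ // The updated running count is returned as the emitted record for this key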
sum + } + + testOperation[String, Int, Int]( + inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - basic operations with advanced API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and emitted type as Long + val simpleFunc = (value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // emitted type as Long + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.emittedClass === classOf[Long]) + } + + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.trackStateByKey( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.trackStateByKey( + StateSpec.function(advancedFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Separately defining StateSpec with advanced function and explicitly specifying types also works + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + + test("trackStateByKey - states as emitted records") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), +
Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - initial states, with nothing emitted") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, trackStateSpec, outputData, stateData) + } + + test("trackStateByKey - state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + Some(key) + } else { + state.update(value.get) + None + } + } + + testOperation( + inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData) + } + + test("trackStateByKey - state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20) + + // b and c should be emitted once each, when they were marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + 
assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) + + // Finally, state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + trackStateSpec: StateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ): Unit = { + require(expectedOutputs.size == expectedStateSnapshots.size) + + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, trackStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + trackStateSpec: StateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) + val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val outputStream = new TestOutputStream(trackStateStream, collectedOutputs) + val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + (collectedOutputs, collectedStateSnapshots) + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String): Unit = { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) + expected.zip(collected).foreach { case (e, c) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected $debugString" + ) + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..fc5f26607ef98f74e4cdfadbbdda8922580d4093 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Time, State} +import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} + +class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + + override def afterAll(): Unit = { + sc.stop() + } + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, Time(123)) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("states generated by TrackStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, Time(initTime)).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: TrackStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = { + + // Persist the test TrackStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only track which state keys are being touched in the next op. 
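+ // (count() forces evaluation so the state partitions are actually cached before the tracking function below runs)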
+ testStateRDD.persist().count() + + // To track which keys are being touched + TrackStateRDDSuite.touchedStateKeys.clear() + + val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + TrackStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size, + "More keys were touched than expected") + assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state and create k3's state as 0 + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( // should remove k3's state, leaving no state at all + rdd7, Seq(("k3", 2)), Set()) + } + + /** Assert whether the `trackStateByKey` operation generates expected results */ + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: TrackStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + trackStateFunc: (Time, K, Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T], + doFullScan: Boolean = false + ):
TrackStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new TrackStateRDD[K, V, S, T]( + testStateRDD, partitionedNewDataRDD, trackStateFunc, Time(currentTime), None) + if (doFullScan) newStateRDD.setFullScan() + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computation touched + newStateRDD.persist() + assertRDD(newStateRDD, expectedStates, expectedEmittedRecords) + newStateRDD + } + + /** Assert whether the [[TrackStateRDD]] has the expected states and emitted records */ + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackStateRDD: TrackStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T]): Unit = { + val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet + assert(states === expectedStates, "states after track state operation were not as expected") + assert(emittedRecords === expectedEmittedRecords, + "emitted records after track state operation were not as expected") + } +} + +object TrackStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +}
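The combinatorial StateMapSuite test above decodes each keyId into one put/remove decision per operation by reading off the keyId's digits in base numTypeMapOps, much like reading the bits of a binary number. Below is a minimal standalone sketch of that decoding; the names OpDecodingSketch and decodeOps are illustrative only and not part of this patch.

object OpDecodingSketch {
  // Decode a keyId into its per-operation op codes (0 = put, 1 = remove), most significant first.
  // This mirrors the expression used in StateMapSuite:
  //   (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps
  def decodeOps(keyId: Int, numTotalOps: Int, numTypeMapOps: Int = 2): Seq[Int] = {
    (0 until numTotalOps).map { opId =>
      (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps
    }
  }

  def main(args: Array[String]): Unit = {
    // keyId 5 with 4 total ops (binary 0101) decodes to put, remove, put, remove
    assert(decodeOps(5, 4) == Seq(0, 1, 0, 1))
    // Enumerating keyIds 0 until numTypeMapOps^numTotalOps yields every possible op sequence exactly once
    assert((0 until 16).map(decodeOps(_, 4)).distinct.size == 16)
  }
}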