diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c5d69c204642ede6e840a60e650a1708e9f29aea..c6f5cf641b8d535a552f7712b364efc16c4adb54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -552,6 +552,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = + buildConf("spark.sql.streaming.stateStore.providerClass") + .internal() + .doc( + "The class used to manage state data in stateful streaming queries. This class must " + + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") + .stringConf + .createOptional + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") .internal() @@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 3ceb4cf84a413ae7942ae9dfd73a3da2991866a1..2aad8701a4eca7a799a0507f058bcf2f510381b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, groupingAttributes.toStructType, stateAttributes.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) @@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.filter { case (_, stateRow) => - val timeoutTimestamp = getTimeoutTimestamp(stateRow) + val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) + timingOutKeys.flatMap { rowPair => + callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty } @@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec( * Call the user function on a key's data, update the state store, and return the return data * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. 
+ * + * @param keyRow Row representing the key, cannot be null + * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty + * @param prevStateRow Row representing the previous state, can be null + * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow], - prevStateRowOption: Option[UnsafeRow], + prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = getStateObj(prevStateRowOption) + val stateObj = getStateObj(prevStateRow) val keyedState = GroupStateImpl.createForStreaming( - stateObjOption, + Option(stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec( numUpdatedStateRows += 1 } else { - val previousTimeoutTimestamp = prevStateRowOption match { - case Some(row) => getTimeoutTimestamp(row) - case None => NO_TIMESTAMP - } + val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { - prevStateRowOption.orNull + prevStateRow } val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp @@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException("Attempting to write empty state") } setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow.copy(), stateRowToWrite.copy()) + store.put(keyRow, stateRowToWrite) numUpdatedStateRows += 1 } } @@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec( } /** Returns the state as Java object if defined */ - def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { - stateRowOption.map(getStateObjFromRow) + def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null } /** Returns the row for an updated state */ def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) getStateRowFromObj(obj) } /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP + if (isTimeoutEnabled && stateRow != null) { + stateRow.getLong(timeoutTimestampIndex) + } else NO_TIMESTAMP } /** Set the timestamp in a state row */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index fb2bf47d6e83bfaab6648d7575c62b15ed9d042a..67d86daf10812e668d0163b65d3af75110305962 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -67,13 +67,7 @@ import org.apache.spark.util.Utils * to ensure re-executed RDD operations re-apply updates on the correct past version of the * store. 
*/ -private[state] class HDFSBackedStateStoreProvider( - val id: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - storeConf: StateStoreConf, - hadoopConf: Configuration - ) extends StateStoreProvider with Logging { +private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging { // ConcurrentHashMap is used because it generates fail-safe iterators on filtering // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in @@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() - @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - override def get(key: UnsafeRow): Option[UnsafeRow] = { - Option(mapToUpdate.get(key)) - } - - override def filter( - condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - mapToUpdate - .entrySet - .asScala - .iterator - .filter { entry => condition(entry.getKey, entry.getValue) } - .map { entry => (entry.getKey, entry.getValue) } + override def get(key: UnsafeRow): UnsafeRow = { + mapToUpdate.get(key) } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") - - val isNewKey = !mapToUpdate.containsKey(key) - mapToUpdate.put(key, value) - - Option(allUpdates.get(key)) match { - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added already, keep it marked as added - allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => - // Value existed in previous version and updated/removed, mark it as updated - allUpdates.put(key, ValueUpdated(key, value)) - case None => - // There was no prior update, so mark this as added or updated according to its presence - // in previous version. 
- val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) - allUpdates.put(key, update) - } - writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + val keyCopy = key.copy() + val valueCopy = value.copy() + mapToUpdate.put(keyCopy, valueCopy) + writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) } - /** Remove keys that match the following condition */ - override def remove(condition: UnsafeRow => Boolean): Unit = { + override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - val entryIter = mapToUpdate.entrySet().iterator() - while (entryIter.hasNext) { - val entry = entryIter.next - if (condition(entry.getKey)) { - val value = entry.getValue - val key = entry.getKey - entryIter.remove() - - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + val prevValue = mapToUpdate.remove(key) + if (prevValue != null) { + writeRemoveToDeltaFile(tempDeltaFileStream, key) } } - /** Remove a single key. */ - override def remove(key: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") - if (mapToUpdate.containsKey(key)) { - val value = mapToUpdate.remove(key) - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + override def getRange( + start: Option[UnsafeRow], + end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + verify(state == UPDATING, "Cannot getRange after already committed or aborted") + iterator() } /** Commit all the updates that have been made to the store, and return the new version. */ @@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider( * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. */ - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, - "Cannot get iterator of store data before committing or after aborting") - HDFSBackedStateStoreProvider.this.iterator(newVersion) - } - - /** - * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing all the updates made in the current thread. 
- */ - override def updates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, - "Cannot get iterator of updates before committing or after aborting") - allUpdates.values().asScala.toIterator + override def iterator(): Iterator[UnsafeRowPair] = { + val unsafeRowPair = new UnsafeRowPair() + mapToUpdate.entrySet.asScala.iterator.map { entry => + unsafeRowPair.withRows(entry.getKey, entry.getValue) + } } override def numKeys(): Long = mapToUpdate.size() @@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider( store } + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): Unit = { + this.stateStoreId = stateStoreId + this.keySchema = keySchema + this.valueSchema = valueSchema + this.storeConf = storeConf + this.hadoopConf = hadoopConf + fs.mkdirs(baseDir) + } + + override def id: StateStoreId = stateStoreId + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { try { @@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider( } } + override def close(): Unit = { + loadedMaps.values.foreach(_.clear()) + } + override def toString(): String = { s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } - /* Internal classes and methods */ + /* Internal fields and methods */ - private val loadedMaps = new mutable.HashMap[Long, MapType] - private val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") - private val fs = baseDir.getFileSystem(hadoopConf) - private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + @volatile private var stateStoreId: StateStoreId = _ + @volatile private var keySchema: StructType = _ + @volatile private var valueSchema: StructType = _ + @volatile private var storeConf: StateStoreConf = _ + @volatile private var hadoopConf: Configuration = _ - initialize() + private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) @@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider( * Get iterator of all the data of the latest version of the store. * Note that this will look up the files to determined the latest known version. 
*/ - private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet val versionsLoaded = loadedMaps.keySet val allKnownVersions = versionsInFiles ++ versionsLoaded + val unsafeRowTuple = new UnsafeRowPair() if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry => + unsafeRowTuple.withRows(entry.getKey, entry.getValue) } } else Iterator.empty } - /** Get iterator of a specific version of the store */ - private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { - loadMap(version).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) - } - } - - /** Initialize the store provider */ - private def initialize(): Unit = { - try { - fs.mkdirs(baseDir) - } catch { - case e: IOException => - throw new IllegalStateException( - s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e) - } - } - /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { if (version <= 0) return new MapType @@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider( } } - private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { - - def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - val valueBytes = value.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(valueBytes.size) - output.write(valueBytes) - } - - def writeRemove(key: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(-1) - } + private def writeUpdateToDeltaFile( + output: DataOutputStream, + key: UnsafeRow, + value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } - update match { - case ValueAdded(key, value) => - writeUpdate(key, value) - case ValueUpdated(key, value) => - writeUpdate(key, value) - case ValueRemoved(key, value) => - writeRemove(key) - } + private def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) } private def finalizeDeltaFile(output: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index eaa558eb6d0ed4263b9d315c9ed1504ba52f77b2..29c456f86e1edcfd65236092bd4b3b592910d26f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ThreadUtils - - -/** Unique identifier for a [[StateStore]] */ -case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) +import 
org.apache.spark.util.{ThreadUtils, Utils} /** - * Base trait for a versioned key-value store used for streaming aggregations + * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific + * version of state data, and such instances are created through a [[StateStoreProvider]]. */ trait StateStore { @@ -47,50 +44,54 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** Get the current value of a key. */ - def get(key: UnsafeRow): Option[UnsafeRow] - /** - * Return an iterator of key-value pairs that satisfy a certain condition. - * Note that the iterator must be fail-safe towards modification to the store, that is, - * it must be based on the snapshot of store the time of this call, and any change made to the - * store while iterating through iterator should not cause the iterator to fail or have - * any affect on the values in the iterator. + * Get the current value of a non-null key. + * @return a non-null row if the key exists in the store, otherwise null. */ - def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + def get(key: UnsafeRow): UnsafeRow - /** Put a new value for a key. */ + /** + * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in + * the params can be reused, and must make copies of the data as needed for persistence. + */ def put(key: UnsafeRow, value: UnsafeRow): Unit /** - * Remove keys that match the following condition. + * Remove a single non-null key. */ - def remove(condition: UnsafeRow => Boolean): Unit + def remove(key: UnsafeRow): Unit /** - * Remove a single key. + * Get key-value pairs with optional approximate `start` and `end` extents. + * If the State Store implementation maintains indices for the data based on the optional + * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use + * `start` and `end` to make a best-effort scan over the data. Default implementation returns + * the full data scan iterator, which is correct but inefficient. Custom implementations must + * ensure that updates (puts, removes) can be made while iterating over this iterator. + * + * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value. + * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value. + * @return An iterator of key-value pairs that is guaranteed not to miss any key between start and + * end, both inclusive. */ - def remove(key: UnsafeRow): Unit + def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + iterator() + } /** * Commit all the updates that have been made to the store, and return the new version. + * Implementations should ensure that no more updates (puts, removes) can be made after a commit in + * order to avoid incorrect usage. */ def commit(): Long - /** Abort all the updates that have been made to the store. */ - def abort(): Unit - /** - * Iterator of store data after a set of updates have been committed. - * This can be called only after committing all the updates made in the current thread. + * Abort all the updates that have been made to the store. Implementations should ensure that + * no more updates (puts, removes) can be made after an abort in order to avoid incorrect usage. */ - def iterator(): Iterator[(UnsafeRow, UnsafeRow)] + def abort(): Unit - /** - * Iterator of the updates that have been committed. 
- * This can be called only after committing all the updates made in the current thread. - */ - def updates(): Iterator[StoreUpdate] + def iterator(): Iterator[UnsafeRowPair] /** Number of keys in the state store */ def numKeys(): Long @@ -102,28 +103,98 @@ trait StateStore { } -/** Trait representing a provider of a specific version of a [[StateStore]]. */ +/** + * Trait representing a provider that provides [[StateStore]] instances representing + * versions of state data. + * + * The life cycle of a provider and its provided stores is as follows. + * + * - A StateStoreProvider is created in an executor for each unique [[StateStoreId]] when + * the first batch of a streaming query is executed on the executor. All subsequent batches reuse + * this provider instance until the query is stopped. + * + * - Every batch of streaming data requests a specific version of the state data by invoking + * `getStore(version)` which returns an instance of [[StateStore]] through which the required + * version of the data can be accessed. It is the responsibility of the provider to populate + * this store with context information like the schema of keys and values, etc. + * + * - After the streaming query is stopped, the created provider instances are lazily disposed of. + */ trait StateStoreProvider { - /** Get the store with the existing version. */ + /** + * Initialize the provider with contextual information from the SQL operator. + * This method will be called first after creating an instance of the StateStoreProvider by + * reflection. + * + * @param stateStoreId Id of the versioned StateStores that this provider will generate + * @param keySchema Schema of keys to be stored + * @param valueSchema Schema of value to be stored + * @param keyIndexOrdinal Optional column (represented as the ordinal of the field in keySchema) by + * which the StateStore implementation could index the data. + * @param storeConfs Configurations used by the StateStores + * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data + */ + def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyIndexOrdinal: Option[Int], // for sorting the data by their keys + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit + + /** + * Return the id of the StateStores this provider will generate. + * Should be the same as the one passed in init(). + */ + def id: StateStoreId + + /** Called when the provider instance is unloaded from the executor */ + def close(): Unit + + /** Return an instance of [[StateStore]] representing state data of the given version */ + def getStore(version: Long): StateStore - /** Optional method for providers to allow for background maintenance */ + /** Optional method for providers to allow for background maintenance (e.g. compactions) */ def doMaintenance(): Unit = { } } - -/** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate { - def key: UnsafeRow - def value: UnsafeRow +object StateStoreProvider { + /** + * Return a provider instance of the given provider class. + * The instance will already be initialized. 
+ */ + def instantiate( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStoreProvider = { + val providerClass = storeConf.providerClass.map(Utils.classForName) + .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] + provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + provider + } } -case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Unique identifier for a bunch of keyed state data. */ +case class StateStoreId( + checkpointLocation: String, + operatorId: Long, + partitionId: Int, + name: String = "") -case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Mutable, and reusable class for representing a pair of UnsafeRows. */ +class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { + def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = { + this.key = key + this.value = value + this + } +} /** @@ -185,6 +256,7 @@ object StateStore extends Logging { storeId: StateStoreId, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], version: Long, storeConf: StateStoreConf, hadoopConf: Configuration): StateStore = { @@ -193,7 +265,9 @@ object StateStore extends Logging { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( storeId, - new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + StateStoreProvider.instantiate( + storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + ) reportActiveStoreInstance(storeId) provider } @@ -202,7 +276,7 @@ object StateStore extends Logging { /** Unload a state store provider */ def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId) + loadedProviders.remove(storeId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ @@ -216,6 +290,7 @@ object StateStore extends Logging { /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { + loadedProviders.keySet.foreach { key => unload(key) } loadedProviders.clear() _coordRef = null if (maintenanceTask != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index acfaa8e5eb3c4bcf2bf61e3a771c5efe082c874f..bab297c7df594e26479e4b1965af7ba87d94eaf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. 
*/ -private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +class StateStoreConf(@transient private val sqlConf: SQLConf) + extends Serializable { def this() = this(new SQLConf) - val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot - - val minVersionsToRetain = conf.minBatchesToRetain + /** + * Minimum number of delta files in a chain after which HDFSBackedStateStore will + * consider generating a snapshot. + */ + val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot + + /** Minimum versions a State Store implementation should retain to allow rollbacks */ + val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + + /** + * Optional fully qualified name of the subclass of [[StateStoreProvider]] + * managing state data. That is, the implementation of the State Store to use. + */ + val providerClass: Option[String] = sqlConf.stateStoreProviderClass + + /** + * Additional configurations related to state store. This will capture all configs in + * SQLConf that start with `spark.sql.streaming.stateStore.` */ + val confs: Map[String, String] = + sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) } -private[streaming] object StateStoreConf { +object StateStoreConf { val empty = new StateStoreConf() def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index e16dda8a5b5640f9da9ae2dddf3298eb5ca3e47d..b744c25dc97a81e4ac741447cffb41c5ebd5e54a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) extends RDD[U](dataRDD) { @@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( private val storeConf = new StateStoreConf(sessionState.conf) // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = dataRDD.context.broadcast( + private val hadoopConfBroadcast = dataRDD.context.broadcast( new SerializableConfiguration(sessionState.newHadoopConf())) override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeConf, hadoopConfBroadcast.value.value) 
val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 589042afb1e52ce646cd136675154e93c8d4dcd5..228fe86d59940416509f41e342a08fb36aeb0220 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -34,17 +34,21 @@ package object state { sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, - valueSchema: StructType)( + valueSchema: StructType, + indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator))( storeUpdateFunction) @@ -54,9 +58,11 @@ package object state { private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { @@ -69,14 +75,17 @@ package object state { }) cleanedF(store, iter) } + new StateStoreRDD( dataRDD, wrappedF, checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sessionState, storeCoordinator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8dbda298c87bcae220dd33b830614beaa6e5c6e3..3e57f3fbada32ecd1cd93efe323a043aef38573b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.TimeUnit._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ -import org.apache.spark.util.CompletionIterator +import 
org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ @@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { +trait StateStoreWriter extends StatefulOperator { self: SparkPlan => + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), + "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"), + "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"), + "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes") + ) + + /** Records the duration of running `body` for the next query progress update. */ + protected def timeTakenMs(body: => Unit): Long = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + math.max(NANOSECONDS.toMillis(endTime - startTime), 0) + } } /** An operator that supports watermark. */ @@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode { /** Predicate based on the child output that matches data older than the watermark. */ lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) + + protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + store.getRange(None, None).foreach { rowPair => + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + } + } + } + } } /** @@ -126,9 +150,11 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, + storeName = "default", storeVersion = getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -136,7 +162,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: savedState.toSeq + row +: Option(savedState).toSeq } } } @@ -165,54 +191,88 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") outputMode match { // Update and output all rows in the StateStore. 
case Some(Complete) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } + } + allRemovalsTimeMs += 0 + commitTimeMs += timeTakenMs { + store.commit() } - store.commit() numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => + store.iterator().map { rowPair => numOutputRows += 1 - v.asInstanceOf[InternalRow] + rowPair.value } // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) + while (filteredIter.hasNext) { + val row = filteredIter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } } - // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicateForKeys.get.eval _) - store.commit() + val removalStartTimeNs = System.nanoTime + val rangeIter = store.getRange(None, None) + + new NextIterator[InternalRow] { + override protected def getNext(): InternalRow = { + var removedValueRow: InternalRow = null + while(rangeIter.hasNext && removedValueRow == null) { + val rowPair = rangeIter.next() + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + removedValueRow = rowPair.value + } + } + if (removedValueRow == null) { + finished = true + null + } else { + removedValueRow + } + } - numTotalStateRows += store.numKeys() - store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => - numOutputRows += 1 - removed.value.asInstanceOf[InternalRow] + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + numTotalStateRows += store.numKeys() + } } // Update and output modified rows from the StateStore. 
case Some(Update) => + val updatesStartTimeNs = System.nanoTime + new Iterator[InternalRow] { // Filter late date using watermark if specified @@ -223,11 +283,11 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + // Remove old aggregates if watermark specified - if (watermarkPredicateForKeys.nonEmpty) { - store.remove(watermarkPredicateForKeys.get.eval _) - } - store.commit() + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() false } else { @@ -238,7 +298,7 @@ case class StateStoreSaveExec( override def next(): InternalRow = { val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key.copy(), row.copy()) + store.put(key, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ -273,27 +333,34 @@ case class StreamingDeduplicateExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } + val updatesStartTimeNs = System.nanoTime + val result = baseIterator.filter { r => val row = r.asInstanceOf[UnsafeRow] val key = getKey(row) val value = store.get(key) - if (value.isEmpty) { - store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + if (value == null) { + store.put(key, StreamingDeduplicateExec.EMPTY_ROW) numUpdatedStateRows += 1 numOutputRows += 1 true @@ -304,8 +371,9 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) - store.commit() + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index bd197be655d5862337c8305b13e0e615dbd10d97..4a1a089af54c2178c233fb9f0f413bfd59c05be2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + import StateStoreTestsHelper._ + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) - private var tempDir = 
Files.createTempDirectory("StateStoreRDDSuite").toString + private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) - import StateStoreSuite._ - after { StateStore.stop() } @@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { val resIterator = iter.map { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) val newValue = oldValue + 1 store.put(key, intToRow(newValue)) (s, newValue) @@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn iter: Iterator[String]): Iterator[(String, Option[Int])] = { iter.map { s => val key = stringToRow(s) - val value = store.get(key).map(rowToInt) + val value = Option(store.get(key)).map(rowToInt) (s, value) } } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -152,15 +156,16 
@@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + increment) require(rdd.partitions.length === 2) assert( @@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) store.put(key, intToRow(oldValue + 1)) } store.commit() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index cc09b2d5b77638afc3ed2b7ae647174d4b676671..af2b9f1c11fb6a6bd09f70cf0e6fe877fe2c3057 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -40,15 +40,15 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { +class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] + with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ - import StateStoreSuite._ + import StateStoreTestsHelper._ - private val tempDir = Utils.createTempDir().toString - private val keySchema = StructType(Seq(StructField("key", StringType, true))) - private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + val keySchema = StructType(Seq(StructField("key", StringType, true))) + val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) before { StateStore.stop() @@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth require(!StateStore.isMaintenanceRunning) } - test("get, put, remove, commit, and all data iterator") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator().isEmpty) - - val store = provider.getStore(0) - assert(!store.hasCommitted) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - - // Verify state after updating - put(store, "a", 1) - assert(store.numKeys() === 1) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - assert(provider.latestIterator().isEmpty) - - // Make updates, commit and then verify state - put(store, "b", 2) - put(store, "aa", 3) - assert(store.numKeys() === 3) - remove(store, _.startsWith("a")) - assert(store.numKeys() === 1) - assert(store.commit() === 1) - - assert(store.hasCommitted) - assert(rowsToSet(store.iterator()) === Set("b" -> 2)) - assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) - assert(fileExists(provider, version = 1, isSnapshot = false)) - - assert(getDataFromFiles(provider) === Set("b" -> 2)) - - // Trying to get newer versions should fail - intercept[Exception] { - provider.getStore(2) - } - intercept[Exception] { - getDataFromFiles(provider, 2) - } - - // New updates to the reloaded store with new version, and does not change old version - val reloadedProvider = new HDFSBackedStateStoreProvider( - store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) - val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.numKeys() === 
1) - put(reloadedStore, "c", 4) - assert(reloadedStore.numKeys() === 2) - assert(reloadedStore.commit() === 2) - assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) - assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) - } - - test("filter and concurrent updates") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator.isEmpty) - val store = provider.getStore(0) - put(store, "a", 1) - put(store, "b", 2) - - // Updates should work while iterating of filtered entries - val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } - filtered.foreach { case (keyRow, valueRow) => - store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) - } - assert(get(store, "a") === Some(2)) - - // Removes should work while iterating of filtered entries - val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } - filtered2.foreach { case (keyRow, _) => - store.remove(keyRow) - } - assert(get(store, "b") === None) - } - - test("updates iterator with all combos of updates and removes") { - val provider = newStoreProvider() - var currentVersion: Int = 0 - - def withStore(body: StateStore => Unit): Unit = { - val store = provider.getStore(currentVersion) - body(store) - currentVersion += 1 - } - - // New data should be seen in updates as value added, even if they had multiple updates - withStore { store => - put(store, "a", 1) - put(store, "aa", 1) - put(store, "aa", 2) - store.commit() - assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) - assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) - } - - // Multiple updates to same key should be collapsed in the updates as a single value update - // Keys that have not been updated should not appear in the updates - withStore { store => - put(store, "a", 4) - put(store, "a", 6) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Keys added, updated and finally removed before commit should not appear in updates - withStore { store => - put(store, "b", 4) // Added, finally removed - put(store, "bb", 5) // Added, updated, finally removed - put(store, "bb", 6) - remove(store, _.startsWith("b")) - store.commit() - assert(updatesToSet(store.updates()) === Set.empty) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Removed data should be seen in updates as a key removed - // Removed, but re-added data should be seen in updates as a value update - withStore { store => - remove(store, _.startsWith("a")) - put(store, "a", 10) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) - assert(rowsToSet(store.iterator()) === Set("a" -> 10)) - } - } - - test("cancel") { - val provider = newStoreProvider() - val store = provider.getStore(0) - put(store, "a", 1) - store.commit() - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - // cancelUpdates should not change the data in the files - val store1 = provider.getStore(1) - put(store1, "b", 1) - store1.abort() - assert(getDataFromFiles(provider) === Set("a" -> 1)) - } - - test("getStore with unexpected versions") { - val provider = newStoreProvider() - - intercept[IllegalArgumentException] { - provider.getStore(-1) - } - - // Prepare 
some data in the store - val store = provider.getStore(0) - put(store, "a", 1) - assert(store.commit() === 1) - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - intercept[IllegalStateException] { - provider.getStore(2) - } - - // Update store version with some data - val store1 = provider.getStore(1) - put(store1, "b", 1) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) - assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) - } - test("snapshotting") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) var currentVersion = 0 def updateVersionTo(targetVersion: Int): Unit = { @@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } updateVersionTo(2) - require(getDataFromFiles(provider) === Set("a" -> 2)) + require(getData(provider) === Set("a" -> 2)) provider.doMaintenance() // should not generate snapshot files - assert(getDataFromFiles(provider) === Set("a" -> 2)) + assert(getData(provider) === Set("a" -> 2)) for (i <- 1 to currentVersion) { assert(fileExists(provider, i, isSnapshot = false)) // all delta files present @@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // After version 6, snapshotting should generate one snapshot file updateVersionTo(6) - require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + require(getData(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) assert(snapshotVersion.nonEmpty, "snapshot file not generated") deleteFilesEarlierThanVersion(provider, snapshotVersion.get) assert( - getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), "snapshotting messed up the data of the snapshotted version") assert( - getDataFromFiles(provider) === Set("a" -> 6), + getData(provider) === Set("a" -> 6), "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) - require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot val latestSnapshotVersion = (0 to 20).filter(version => @@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) - assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data") } test("cleaning") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) for (i <- 1 to 20) { val store = provider.getStore(i - 1) @@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted // last couple of versions should be retrievable - 
assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) - assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + assert(getData(provider, 20) === Set("a" -> 20)) + assert(getData(provider, 19) === Set("a" -> 19)) } test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { @@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") - val provider = newStoreProvider(hadoopConf = conf) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf) provider.getStore(0).commit() provider.getStore(0).commit() @@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("corrupted file handling") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) put(store, "a", i) @@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) // Corrupt snapshot file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion)) corruptFile(provider, snapshotVersion, isSnapshot = true) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion) + getData(provider, snapshotVersion) } // Corrupt delta file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) corruptFile(provider, snapshotVersion - 1, isSnapshot = false) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } // Delete delta file and verify that it throws error deleteFilesEarlierThanVersion(provider, snapshotVersion) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } } test("StateStore.get") { quietly { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, 0, 0) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() - // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { - StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) } - // Increase version of the store - val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + // Increase version of the store and try to get again + val store0 = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) assert(store0.version === 0) put(store0, "a", 1) store0.commit() - assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 
1) - assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + val store1 = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + assert(store1.version === 1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + + // Verify that you can also load older version + val store0reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) + assert(store0reloaded.version === 0) + assert(rowsToSet(store0reloaded.iterator()) === Set.empty) // Verify that you can remove the store and still reload and use it StateStore.unload(storeId) assert(!StateStore.isLoaded(storeId)) - val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + val store1reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - put(store1, "a", 2) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + assert(store1reloaded.version === 1) + put(store1reloaded, "a", 2) + assert(store1reloaded.commit() === 2) + assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2)) } } @@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, opId, 0) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = new HDFSBackedStateStoreProvider( - storeId, keySchema, valueSchema, storeConf, hadoopConf) + val provider = newStoreProvider(storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + val store = StateStore.get(storeId, keySchema, valueSchema, None, + latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() latestStoreVersion += 1 @@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) } } @@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ - val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath + val dir = scheme + "://" + newDir() val conf = new Configuration() conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName) - val 
provider = newStoreProvider(dir = dir, hadoopConf = conf) + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf) val store = provider.getStore(0) put(store, "a", 0) val e = intercept[IllegalStateException](store.commit()) @@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("SPARK-18416: do not create temp delta file until the store is updated") { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, 0, 0) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Getting the store should not create temp file val store0 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf) } // Put should create a temp file @@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Remove should create a temp file val store1 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf) } remove(store1, _ == "a") assert(numTempFiles === 1) @@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Commit without any updates should create a delta file val store2 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf) } store2.commit() assert(numTempFiles === 0) assert(numDeltaFiles === 3) } - def getDataFromFiles( - provider: HDFSBackedStateStoreProvider, + override def newStoreProvider(): HDFSBackedStateStoreProvider = { + newStoreProvider(opId = Random.nextInt(), partition = 0) + } + + override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + } + + override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { + getData(storeProvider) + } + + override def getData( + provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = new HDFSBackedStateStoreProvider( - provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedProvider = newStoreProvider(provider.id) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { - reloadedProvider.iterator(version).map(rowsToStringInt).toSet + reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet } } - def assertMap( - testMapOption: Option[MapType], - expectedMap: Map[String, Int]): Unit = { - assert(testMapOption.nonEmpty, "no map present") - val convertedMap = testMapOption.get.map(rowsToStringInt) - assert(convertedMap === expectedMap) + def newStoreProvider( + opId: Long, + partition: Int, + dir: String = newDir(), + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { + val sqlConf = new SQLConf() + 
sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + val provider = new HDFSBackedStateStoreProvider() + provider.init( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + indexOrdinal = None, + new StateStoreConf(sqlConf), + hadoopConf) + provider } def fileExists( @@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth filePath.delete() filePath.createNewFile() } +} - def storeLoaded(storeId: StateStoreId): Boolean = { - val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) - val loadedStores = StateStore invokePrivate method() - loadedStores.contains(storeId) - } +abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] + extends SparkFunSuite { + import StateStoreTestsHelper._ - def unloadStore(storeId: StateStoreId): Boolean = { - val method = PrivateMethod('remove) - StateStore invokePrivate method(storeId) - } + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() - def newStoreProvider( - opId: Long = Random.nextLong, - partition: Int = 0, - minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString, - hadoopConf: Configuration = new Configuration() - ): HDFSBackedStateStoreProvider = { - val sqlConf = new SQLConf() - sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) - sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) - new HDFSBackedStateStoreProvider( - StateStoreId(dir, opId, partition), - keySchema, - valueSchema, - new StateStoreConf(sqlConf), - hadoopConf) + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a") === None) + assert(store.iterator().isEmpty) + assert(store.numKeys() === 0) + + // Verify state after updating + put(store, "a", 1) + assert(get(store, "a") === Some(1)) + assert(store.numKeys() === 1) + + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + put(store, "aa", 3) + assert(store.numKeys() === 3) + remove(store, _.startsWith("a")) + assert(store.numKeys() === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(getLatestData(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = newStoreProvider(store.id) + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.numKeys() === 1) + put(reloadedStore, "c", 4) + assert(reloadedStore.numKeys() === 2) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) + assert(getData(provider, version = 1) === Set("b" -> 2)) } - def remove(store: StateStore, condition: String => Boolean): Unit = { - store.remove(row => condition(rowToString(row))) + test("removing while iterating") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + 
assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" } + filtered.foreach { tuple => + store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1)) + } + assert(get(store, "a") === Some(2)) + + // Removes should work while iterating of filtered entries + val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" } + filtered2.foreach { tuple => store.remove(tuple.key) } + assert(get(store, "b") === None) } - private def put(store: StateStore, key: String, value: Int): Unit = { - store.put(stringToRow(key), intToRow(value)) + test("abort") { + val provider = newStoreProvider() + val store = provider.getStore(0) + put(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + put(store1, "b", 1) + store1.abort() } - private def get(store: StateStore, key: String): Option[Int] = { - store.get(stringToRow(key)).map(rowToInt) + test("getStore with invalid versions") { + val provider = newStoreProvider() + + def checkInvalidVersion(version: Int): Unit = { + intercept[Exception] { + provider.getStore(version) + } + } + + checkInvalidVersion(-1) + checkInvalidVersion(1) + + val store = provider.getStore(0) + put(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + val store1_ = provider.getStore(1) + assert(rowsToSet(store1_.iterator()) === Set("a" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(2) + + // Update store version with some data + val store1 = provider.getStore(1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + put(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(3) } -} -private[state] object StateStoreSuite { + /** Return a new provider with a random id */ + def newStoreProvider(): ProviderClass + + /** Return a new provider with the given id */ + def newStoreProvider(storeId: StateStoreId): ProviderClass + + /** Get the latest data referred to by the given provider but not using this provider */ + def getLatestData(storeProvider: ProviderClass): Set[(String, Int)] + + /** + * Get a specific version of data referred to by the given provider but not using + * this provider + */ + def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] +} - /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ - trait TestUpdate - case class Added(key: String, value: Int) extends TestUpdate - case class Updated(key: String, value: Int) extends TestUpdate - case class Removed(key: String) extends TestUpdate +object StateStoreTestsHelper { val strProj = UnsafeProjection.create(Array[DataType](StringType)) val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) @@ -692,26 +649,29 @@ private[state] object StateStoreSuite { row.getInt(0) } - def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { - (rowToInt(row._1), rowToInt(row._2)) + def rowsToStringInt(row: UnsafeRowPair): (String, Int) = { + (rowToString(row.key), rowToInt(row.value)) } + def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } - def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): 
(String, Int) = { - (rowToString(row._1), rowToInt(row._2)) + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.getRange(None, None).foreach { rowPair => + if (condition(rowToString(rowPair.key))) store.remove(rowPair.key) + } } - def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { - iterator.map(rowsToStringInt).toSet + def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) } - def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { - case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) - case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case ValueRemoved(key, _) => Removed(rowToString(key)) - }.toSet + def get(store: StateStore, key: String): Option[Int] = { + Option(store.get(stringToRow(key))).map(rowToInt) } + + def newDir(): String = Utils.createTempDir().toString } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6bb9408ce99ed28f0a6e3b7df4ad119393743258..0d9ca81349be5f26e1636b910ad79df74e63d914 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("StateStoreUpdater - rows are cloned before writing to StateStore") { - // function for running count - val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { - state.update(state.getOption.getOrElse(0) + values.size) - Iterator.empty - } - val store = newStateStore() - val plan = newFlatMapGroupsWithStateExec(func) - val updater = new plan.StateStoreUpdater(store) - val data = Seq(1, 1, 2) - val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) - returnIter.size // consume the iterator to force store updates - val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet - assert(storeData === Set((1, 2), (2, 1))) - } - test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything @@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf callFunction() val updatedStateRow = store.get(key) assert( - updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + 
Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow.nonEmpty) { + if (updatedStateRow != null) { assert( - updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, "final timeout timestamp not as expected") } } @@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite { import scala.collection.JavaConverters._ private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } } - override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - iterator.filter { case (k, v) => c(k, v) } + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { + map.put(key.copy(), newValue.copy()) } - - override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def remove(condition: (UnsafeRow) => Boolean): Unit = { - iterator.map(_._1).filter(condition).foreach(map.remove) - } override def commit(): Long = version + 1 override def abort(): Unit = { } override def id: StateStoreId = null override def version: Long = 0 - override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } override def numKeys(): Long = map.size override def hasCommitted: Boolean = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 1fc062974e185a9518d53ec896abd27a151ee34f..280f2dc27b4a7d7df4d38a914558ec3a9f88dd52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -31,6 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider @@ -614,6 +616,30 @@ class StreamSuite extends StreamTest { assertDescContainsQueryNameAnd(batch = 2) query.stop() } + + testQuietly("specify custom state store provider") { + val queryName = "memStream" + val providerClassName = classOf[TestStateStoreProvider].getCanonicalName + withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { + val input = MemoryStream[Int] + val query = input + .toDS() + .groupBy() + .count() + .writeStream + .outputMode("complete") + .format("memory") + 
.queryName(queryName) + .start() + input.addData(1, 2, 3) + val e = intercept[Exception] { + query.awaitTermination() + } + + assert(e.getMessage.contains(providerClassName)) + assert(e.getMessage.contains("instantiated")) + } + } } abstract class FakeSource extends StreamSourceProvider { @@ -719,3 +745,22 @@ object ThrowingInterruptedIOException { */ @volatile var createSourceLatch: CountDownLatch = null } + +class TestStateStoreProvider extends StateStoreProvider { + + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit = { + throw new Exception("Successfully instantiated") + } + + override def id: StateStoreId = null + + override def close(): Unit = { } + + override def getStore(version: Long): StateStore = null +}
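
For reference, here is how the new spark.sql.streaming.stateStore.providerClass conf exercised by the StreamSuite test above would be set on a session before starting a stateful query. This sketch is illustrative only and is not part of the patch: com.example.MyStateStoreProvider is a placeholder for a user-supplied StateStoreProvider subclass that, like the providers in this diff, is constructed with no arguments and then initialized through init(...) (the same pattern the suite's newStoreProvider helper follows).

    // Sketch only: assumes a local SparkSession and a hypothetical provider class
    // com.example.MyStateStoreProvider implementing init/getStore/id/close as in
    // the StateStoreProvider interface shown in this diff.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.streaming.MemoryStream

    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("custom-state-store-provider")
      .getOrCreate()
    import spark.implicits._
    implicit val sqlContext = spark.sqlContext

    // Point all stateful operators of this session at the custom provider class.
    spark.conf.set(
      "spark.sql.streaming.stateStore.providerClass",
      "com.example.MyStateStoreProvider")

    // Run a simple stateful aggregation, mirroring the test above.
    val input = MemoryStream[Int]
    val query = input.toDS()
      .groupBy()
      .count()
      .writeStream
      .outputMode("complete")
      .format("memory")
      .queryName("counts")
      .start()

    input.addData(1, 2, 3)
    query.processAllAvailable()

If the configured class cannot be instantiated or fails during init, the error surfaces through query.awaitTermination() with the provider class name in the message, which is exactly what the "specify custom state store provider" test asserts via the exception thrown by TestStateStoreProvider.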