Commit 8c826880 authored by Tathagata Das

[SPARK-13809][SQL] State store for streaming aggregations

## What changes were proposed in this pull request?

In this PR, I am implementing a new abstraction for managing streaming state data: the State Store. It is a key-value store for persisting the running aggregates of aggregation operations on streaming DataFrames. The motivation and design are discussed here:

https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit#
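
To make the intent concrete, here is a minimal sketch (not part of this patch) of the kind of per-partition update function that the new `StateStoreRDD` consumes: it folds incoming keys into running counts, commits the new store version, and returns the committed contents. The `UnsafeRow` conversion helpers (`stringToRow`, `intToRow`, `rowToInt`) are assumptions for illustration only; the test suites in this patch exercise the same pattern through `mapPartitionWithStateStore`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object RunningCountSketch {
  // Hypothetical helpers that encode a String key and an Int count as UnsafeRows.
  private val keyProj = UnsafeProjection.create(Array[DataType](StringType))
  private val valueProj = UnsafeProjection.create(Array[DataType](IntegerType))
  private def stringToRow(s: String): UnsafeRow =
    keyProj(InternalRow(UTF8String.fromString(s))).copy()
  private def intToRow(i: Int): UnsafeRow = valueProj(InternalRow(i)).copy()
  private def rowToInt(row: UnsafeRow): Int = row.getInt(0)

  // For each input key, increment its running count in the store, commit the
  // resulting version, and return an iterator over the committed key-value pairs.
  val increment: (StateStore, Iterator[String]) => Iterator[(UnsafeRow, UnsafeRow)] =
    (store, iter) => {
      iter.foreach { s =>
        store.update(stringToRow(s), oldRow => {
          val oldCount = oldRow.map(rowToInt).getOrElse(0)
          intToRow(oldCount + 1)
        })
      }
      store.commit()
      store.iterator()
    }
}
```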

## How was this patch tested?
- [x] Unit tests
- [x] Cluster tests

**Coverage from unit tests**

<img width="952" alt="screen shot 2016-03-21 at 3 09 40 pm" src="https://cloud.githubusercontent.com/assets/663212/13935872/fdc8ba86-ef76-11e5-93e8-9fa310472c7b.png">

## TODO
- [x] Fix updates() iterator to avoid duplicate updates for same key
- [x] Use Coordinator in ContinuousQueryManager
- [x] Plugging in hadoop conf and other confs
- [x] Unit tests
  - [x] StateStore object lifecycle and methods
  - [x] StateStoreCoordinator communication and logic
  - [x] StateStoreRDD fault-tolerance
  - [x] StateStoreRDD preferred location using StateStoreCoordinator
- [ ] Cluster tests
  - [ ] Whether preferred locations are set correctly
  - [ ] Whether recovery works correctly with distributed storage
  - [x] Basic performance tests
- [x] Docs

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #11645 from tdas/state-store.
parent 0a64294f
Showing 2052 additions and 0 deletions
...@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener
/**
...@@ -33,6 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener
@Experimental
class ContinuousQueryManager(sqlContext: SQLContext) {
private[sql] val stateStoreCoordinator =
StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env)
private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
private val activeQueriesLock = new Object
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.util.Timer
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils
/** Unique identifier for a [[StateStore]] */
case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)
/**
* Base trait for a versioned key-value store used for streaming aggregations
*/
trait StateStore {
/** Unique identifier of the store */
def id: StateStoreId
/** Version of the data in this store before committing updates. */
def version: Long
/**
* Update the value of a key using the value generated by the update function.
* @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
* versions of the store data.
*/
def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
/**
* Remove keys that match the given condition.
*/
def remove(condition: UnsafeRow => Boolean): Unit
/**
* Commit all the updates that have been made to the store, and return the new version.
*/
def commit(): Long
/** Cancel all the updates that have been made to the store. */
def cancel(): Unit
/**
* Iterator of store data after a set of updates have been committed.
* This can be called only after commit() has been called in the current thread.
*/
def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
/**
* Iterator of the updates that have been committed.
* This can be called only after commit() has been called in the current thread.
*/
def updates(): Iterator[StoreUpdate]
/**
* Whether all updates have been committed
*/
def hasCommitted: Boolean
}
/** Trait representing a provider of a specific version of a [[StateStore]]. */
trait StateStoreProvider {
/** Get the store with the existing version. */
def getStore(version: Long): StateStore
/** Optional method for providers to allow for background maintenance */
def doMaintenance(): Unit = { }
}
/** Trait representing updates made to a [[StateStore]]. */
sealed trait StoreUpdate
case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
/**
* Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
* by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
* it also runs a periodic background task to do maintenance on the loaded stores. For each
* store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
* the store is the active instance. Accordingly, it either keeps the store loaded and performs
* maintenance, or unloads the store.
*/
private[state] object StateStore extends Logging {
val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval"
val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
private val maintenanceTaskExecutor =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
@volatile private var maintenanceTask: ScheduledFuture[_] = null
@volatile private var _coordRef: StateStoreCoordinatorRef = null
/** Get or create a store associated with the id. */
def get(
storeId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
version: Long,
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStore = {
require(version >= 0)
val storeProvider = loadedProviders.synchronized {
startMaintenanceIfNeeded()
val provider = loadedProviders.getOrElseUpdate(
storeId,
new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
reportActiveStoreInstance(storeId)
provider
}
storeProvider.getStore(version)
}
/** Unload a state store provider */
def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeId)
}
/** Whether a state store provider is loaded or not */
def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
loadedProviders.contains(storeId)
}
/** Unload and stop all state store providers */
def stop(): Unit = loadedProviders.synchronized {
loadedProviders.clear()
_coordRef = null
if (maintenanceTask != null) {
maintenanceTask.cancel(false)
maintenanceTask = null
}
logInfo("StateStore stopped")
}
/** Start the periodic maintenance task if not already started and if Spark is active */
private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
val env = SparkEnv.get
if (maintenanceTask == null && env != null) {
val periodMs = env.conf.getTimeAsMs(
MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
val runnable = new Runnable {
override def run(): Unit = { doMaintenance() }
}
maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
logInfo("State Store maintenance task started")
}
}
/**
* Execute background maintenance task in all the loaded store providers if they are still
* the active instances according to the coordinator.
*/
private def doMaintenance(): Unit = {
logDebug("Doing maintenance")
loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
try {
if (verifyIfStoreInstanceActive(id)) {
provider.doMaintenance()
} else {
unload(id)
logInfo(s"Unloaded $provider")
}
} catch {
case NonFatal(e) =>
logWarning(s"Error managing $provider")
}
}
}
private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
try {
val host = SparkEnv.get.blockManager.blockManagerId.host
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
logDebug(s"Reported that the loaded instance $storeId is active")
} catch {
case NonFatal(e) =>
logWarning(s"Error reporting active instance of $storeId")
}
}
private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
try {
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
val verified =
coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" )
verified
} catch {
case NonFatal(e) =>
logWarning(s"Error verifying active instance of $storeId")
false
}
}
private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
val env = SparkEnv.get
if (env != null) {
if (_coordRef == null) {
_coordRef = StateStoreCoordinatorRef.forExecutor(env)
}
logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
Some(_coordRef)
} else {
_coordRef = null
None
}
}
}
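
A hypothetical, self-contained sketch of the executor-side lifecycle that the companion object above supports: load a store at a given version, update a key, and commit to produce the next version. The checkpoint path, schemas, and values are placeholders, and the sketch is placed in the `state` package only so it can reach the `private[state]` members; it is not part of this patch.

```scala
package org.apache.spark.sql.execution.streaming.state

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object StateStoreLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val keySchema = StructType(Seq(StructField("key", StringType)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType)))
    // Placeholder checkpoint location; in a real query this comes from the query's checkpoint dir.
    val storeId = StateStoreId("/tmp/state-store-sketch", operatorId = 0, partitionId = 0)

    // Load (or create) the store at version 0 for this operator/partition.
    val store = StateStore.get(
      storeId, keySchema, valueSchema, 0, StateStoreConf.empty, new Configuration())

    val keyProj = UnsafeProjection.create(keySchema)
    val valueProj = UnsafeProjection.create(valueSchema)
    val key = keyProj(InternalRow(UTF8String.fromString("a"))).copy()

    // Increment the running count for key "a" and commit, producing the next version.
    store.update(key, old => valueProj(InternalRow(old.map(_.getInt(0)).getOrElse(0) + 1)).copy())
    val newVersion = store.commit()
    assert(store.hasCommitted && newVersion == 1)

    // Unload all providers and stop the maintenance task.
    StateStore.stop()
  }
}
```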
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.sql.internal.SQLConf
/** A class that contains configuration parameters for [[StateStore]]s. */
private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
def this() = this(new SQLConf)
import SQLConf._
val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
}
private[state] object StateStoreConf {
val empty = new StateStoreConf()
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils
/** Trait representing all messages to [[StateStoreCoordinator]] */
private sealed trait StateStoreCoordinatorMessage extends Serializable
/** Classes representing messages */
private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String)
extends StateStoreCoordinatorMessage
private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String)
extends StateStoreCoordinatorMessage
private case class GetLocation(storeId: StateStoreId)
extends StateStoreCoordinatorMessage
private case class DeactivateInstances(storeRootLocation: String)
extends StateStoreCoordinatorMessage
private object StopCoordinator
extends StateStoreCoordinatorMessage
/** Helper object used to create a reference to [[StateStoreCoordinator]]. */
private[sql] object StateStoreCoordinatorRef extends Logging {
private val endpointName = "StateStoreCoordinator"
/**
* Create a reference to a [[StateStoreCoordinator]]. This can be called from the driver as well as
* the executors.
*/
def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
try {
val coordinator = new StateStoreCoordinator(env.rpcEnv)
val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
logInfo("Registered StateStoreCoordinator endpoint")
new StateStoreCoordinatorRef(coordinatorRef)
} catch {
case e: IllegalArgumentException =>
val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
logDebug("Retrieved existing StateStoreCoordinator endpoint")
new StateStoreCoordinatorRef(rpcEndpointRef)
}
}
def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
logDebug("Retrieved existing StateStoreCoordinator endpoint")
new StateStoreCoordinatorRef(rpcEndpointRef)
}
}
/**
* Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
* [[StateStore]]s across all the executors, and to get their locations for job scheduling.
*/
private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
private[state] def reportActiveInstance(
storeId: StateStoreId,
host: String,
executorId: String): Unit = {
rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
}
/** Verify whether the given executor has the active instance of a state store */
private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = {
rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId))
}
/** Get the location of the state store */
private[state] def getLocation(storeId: StateStoreId): Option[String] = {
rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId))
}
/** Deactivate all instances related to a given checkpoint location */
private[state] def deactivateInstances(storeRootLocation: String): Unit = {
rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation))
}
private[state] def stop(): Unit = {
rpcEndpointRef.askWithRetry[Boolean](StopCoordinator)
}
}
/**
* Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
* and for getting their locations for job scheduling.
*/
private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint {
private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
override def receive: PartialFunction[Any, Unit] = {
case ReportActiveInstance(id, host, executorId) =>
instances.put(id, ExecutorCacheTaskLocation(host, executorId))
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case VerifyIfInstanceActive(id, execId) =>
val response = instances.get(id) match {
case Some(location) => location.executorId == execId
case None => false
}
context.reply(response)
case GetLocation(id) =>
context.reply(instances.get(id).map(_.toString))
case DeactivateInstances(loc) =>
val storeIdsToRemove =
instances.keys.filter(_.checkpointLocation == loc).toSeq
instances --= storeIdsToRemove
context.reply(true)
case StopCoordinator =>
stop() // Stop before replying to ensure that endpoint name has been deregistered
context.reply(true)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
* An RDD that allows computations to be executed against [[StateStore]]s. It
* uses the [[StateStoreCoordinator]] to get the locations of loaded state stores and uses
* them as preferred locations.
*/
class StateStoreRDD[T: ClassTag, U: ClassTag](
dataRDD: RDD[T],
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
storeConf: StateStoreConf,
@transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
extends RDD[U](dataRDD) {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
private val confBroadcast = dataRDD.context.broadcast(
new SerializableConfiguration(dataRDD.context.hadoopConfiguration))
override protected def getPartitions: Array[Partition] = dataRDD.partitions
override def getPreferredLocations(partition: Partition): Seq[String] = {
val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
}
override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
var store: StateStore = null
Utils.tryWithSafeFinally {
val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
store = StateStore.get(
storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
val outputIter = storeUpdateFunction(store, inputIter)
assert(store.hasCommitted)
outputIter
} {
if (store != null) store.cancel()
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType
package object state {
implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
/** Map each partition of an RDD along with data in a [[StateStore]]. */
def mapPartitionWithStateStore[U: ClassTag](
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType
)(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
mapPartitionWithStateStore(
storeUpdateFunction,
checkpointLocation,
operatorId,
storeVersion,
keySchema,
valueSchema,
new StateStoreConf(sqlContext.conf),
Some(sqlContext.streams.stateStoreCoordinator))
}
/** Map each partition of an RDD along with data in a [[StateStore]]. */
private[state] def mapPartitionWithStateStore[U: ClassTag](
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
storeConf: StateStoreConf,
storeCoordinator: Option[StateStoreCoordinatorRef]
): StateStoreRDD[T, U] = {
val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
new StateStoreRDD(
dataRDD,
cleanedF,
checkpointLocation,
operatorId,
storeVersion,
keySchema,
valueSchema,
storeConf,
storeCoordinator)
}
}
}
...@@ -524,6 +524,19 @@ object SQLConf {
doc = "When true, the planner will try to find out duplicated exchanges and re-use them.",
isPublic = false)
val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf(
"spark.sql.streaming.stateStore.minDeltasForSnapshot",
defaultValue = Some(10),
doc = "Minimum number of state store delta files that needs to be generated before they " +
"consolidated into snapshots.",
isPublic = false)
val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf(
"spark.sql.streaming.stateStore.minBatchesToRetain",
defaultValue = Some(2),
doc = "Minimum number of versions of a state store's data to retain after cleaning.",
isPublic = false)
val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation",
defaultValue = None,
doc = "The default location for storing checkpoint data for continuously executing queries.",
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
import StateStoreCoordinatorSuite._
test("report, verify, getLocation") {
withCoordinatorRef(sc) { coordinatorRef =>
val id = StateStoreId("x", 0, 0)
assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
assert(coordinatorRef.getLocation(id) === None)
coordinatorRef.reportActiveInstance(id, "hostX", "exec1")
eventually(timeout(5 seconds)) {
assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true)
assert(
coordinatorRef.getLocation(id) ===
Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
}
coordinatorRef.reportActiveInstance(id, "hostX", "exec2")
eventually(timeout(5 seconds)) {
assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true)
assert(
coordinatorRef.getLocation(id) ===
Some(ExecutorCacheTaskLocation("hostX", "exec2").toString))
}
}
}
test("make inactive") {
withCoordinatorRef(sc) { coordinatorRef =>
val id1 = StateStoreId("x", 0, 0)
val id2 = StateStoreId("y", 1, 0)
val id3 = StateStoreId("x", 0, 1)
val host = "hostX"
val exec = "exec1"
coordinatorRef.reportActiveInstance(id1, host, exec)
coordinatorRef.reportActiveInstance(id2, host, exec)
coordinatorRef.reportActiveInstance(id3, host, exec)
eventually(timeout(5 seconds)) {
assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true)
assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true)
}
coordinatorRef.deactivateInstances("x")
assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false)
assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false)
assert(coordinatorRef.getLocation(id1) === None)
assert(
coordinatorRef.getLocation(id2) ===
Some(ExecutorCacheTaskLocation(host, exec).toString))
assert(coordinatorRef.getLocation(id3) === None)
coordinatorRef.deactivateInstances("y")
assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false)
assert(coordinatorRef.getLocation(id2) === None)
}
}
test("multiple references have same underlying coordinator") {
withCoordinatorRef(sc) { coordRef1 =>
val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env)
val id = StateStoreId("x", 0, 0)
coordRef1.reportActiveInstance(id, "hostX", "exec1")
eventually(timeout(5 seconds)) {
assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true)
assert(
coordRef2.getLocation(id) ===
Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
}
}
}
}
object StateStoreCoordinatorSuite {
def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = {
var coordinatorRef: StateStoreCoordinatorRef = null
try {
coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env)
body(coordinatorRef)
} finally {
if (coordinatorRef != null) coordinatorRef.stop()
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.io.File
import java.nio.file.Files
import scala.util.Random
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.LocalSparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.Utils
class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
private val keySchema = StructType(Seq(StructField("key", StringType, true)))
private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
import StateStoreSuite._
after {
StateStore.stop()
}
override def afterAll(): Unit = {
super.afterAll()
Utils.deleteRecursively(new File(tempDir))
}
test("versioning and immutability") {
quietly {
withSpark(new SparkContext(sparkConf)) { sc =>
implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
}
store.commit()
store.iterator().map(rowsToStringInt)
}
val opId = 0
val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 1, keySchema, valueSchema)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
}
}
}
test("recovering from files") {
quietly {
val opId = 0
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
def makeStoreRDD(
sc: SparkContext,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
implicit val sqlContext = new SQLContext(sc)
makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion, keySchema, valueSchema)
}
// Generate RDDs and state store data
withSpark(new SparkContext(sparkConf)) { sc =>
for (i <- 1 to 20) {
require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}
// With a new context, try using the earlier state store data
withSpark(new SparkContext(sparkConf)) { sc =>
assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
}
}
}
test("preferred locations using StateStoreCoordinator") {
quietly {
val opId = 0
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
withSpark(new SparkContext(sparkConf)) { sc =>
implicit val sqlContext = new SQLContext(sc)
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
assert(
coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
require(rdd.partitions.length === 2)
assert(
rdd.preferredLocations(rdd.partitions(0)) ===
Seq(ExecutorCacheTaskLocation("host1", "exec1").toString))
assert(
rdd.preferredLocations(rdd.partitions(1)) ===
Seq(ExecutorCacheTaskLocation("host2", "exec2").toString))
rdd.collect()
}
}
}
test("distributed test") {
quietly {
withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
}
store.commit()
store.iterator().map(rowsToStringInt)
}
val opId = 0
val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 0, keySchema, valueSchema)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
increment, path, opId, storeVersion = 1, keySchema, valueSchema)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
}
}
}
private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = {
sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
}
private val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
store.update(
stringToRow(s), oldRow => {
val oldValue = oldRow.map(rowToInt).getOrElse(0)
intToRow(oldValue + 1)
})
}
store.commit()
store.iterator().map(rowsToStringInt)
}
}