diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index c1e8e65464fc10b60997d3271b56a80803af8315..b345276b08ba35533fc7a89128e6e93e8e0e0c08 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -70,6 +70,10 @@
       <artifactId>scalatest_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+    </dependency>
     <dependency>
       <!--
         Netty explicitly added in test as it has been excluded from
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index 7da8eb3e35912347354933061daf5d393eb47fd0..e77cf7bfa54d0e91b9b99d92911cef86d161330e 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -19,6 +19,8 @@ package org.apache.spark.streaming.flume.sink
 import java.util.concurrent.{ConcurrentHashMap, Executors}
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.JavaConversions._
+
 import org.apache.flume.Channel
 import org.apache.commons.lang.RandomStringUtils
 import com.google.common.util.concurrent.ThreadFactoryBuilder
@@ -45,7 +47,8 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
   val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads,
     new ThreadFactoryBuilder().setDaemon(true)
       .setNameFormat("Spark Sink Processor Thread - %d").build()))
-  private val processorMap = new ConcurrentHashMap[CharSequence, TransactionProcessor]()
+  private val sequenceNumberToProcessor =
+    new ConcurrentHashMap[CharSequence, TransactionProcessor]()
   // This sink will not persist sequence numbers and reuses them if it gets restarted.
   // So it is possible to commit a transaction which may have been meant for the sink before the
   // restart.
@@ -55,6 +58,8 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
   private val seqBase = RandomStringUtils.randomAlphanumeric(8)
   private val seqCounter = new AtomicLong(0)
 
+  @volatile private var stopped = false
+
   /**
    * Returns a bunch of events to Spark over Avro RPC.
    * @param n Maximum number of events to return in a batch
@@ -63,18 +68,33 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
   override def getEventBatch(n: Int): EventBatch = {
     logDebug("Got getEventBatch call from Spark.")
     val sequenceNumber = seqBase + seqCounter.incrementAndGet()
-    val processor = new TransactionProcessor(channel, sequenceNumber,
-      n, transactionTimeout, backOffInterval, this)
-    transactionExecutorOpt.foreach(executor => {
-      executor.submit(processor)
-    })
-    // Wait until a batch is available - will be an error if error message is non-empty
-    val batch = processor.getEventBatch
-    if (!SparkSinkUtils.isErrorBatch(batch)) {
-      processorMap.put(sequenceNumber.toString, processor)
-      logDebug("Sending event batch with sequence number: " + sequenceNumber)
+    createProcessor(sequenceNumber, n) match {
+      case Some(processor) =>
+        transactionExecutorOpt.foreach(_.submit(processor))
+        // Wait until a batch is available - will be an error if error message is non-empty
+        val batch = processor.getEventBatch
+        if (SparkSinkUtils.isErrorBatch(batch)) {
+          // Remove the processor if it is an error batch since no ACK is sent.
+          removeAndGetProcessor(sequenceNumber)
+          logWarning("Received an error batch - no events were received from channel! ")
+        }
+        batch
+      case None =>
+        new EventBatch("Spark sink has been stopped!", "", java.util.Collections.emptyList())
+    }
+  }
+
+  private def createProcessor(seq: String, n: Int): Option[TransactionProcessor] = {
+    sequenceNumberToProcessor.synchronized {
+      if (!stopped) {
+        val processor = new TransactionProcessor(
+          channel, seq, n, transactionTimeout, backOffInterval, this)
+        sequenceNumberToProcessor.put(seq, processor)
+        Some(processor)
+      } else {
+        None
+      }
     }
-    batch
   }
 
   /**
@@ -116,7 +136,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
    *         longer tracked and the caller is responsible for that txn processor.
    */
   private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): TransactionProcessor = {
-    processorMap.remove(sequenceNumber.toString) // The toString is required!
+    sequenceNumberToProcessor.synchronized {
+      sequenceNumberToProcessor.remove(sequenceNumber.toString)
+    }
   }
 
   /**
@@ -124,8 +146,10 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
    */
   def shutdown() {
     logInfo("Shutting down Spark Avro Callback Handler")
-    transactionExecutorOpt.foreach(executor => {
-      executor.shutdownNow()
-    })
+    sequenceNumberToProcessor.synchronized {
+      stopped = true
+      sequenceNumberToProcessor.values().foreach(_.shutdown())
+    }
+    transactionExecutorOpt.foreach(_.shutdownNow())
   }
 }
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
index b9e3c786ebb3b9e9dc5aabde3675a2f18d52bc75..13f3aa94be414b9905c84370113b3ff9f50de943 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
@@ -60,6 +60,8 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
   // succeeded.
   @volatile private var batchSuccess = false
 
+  @volatile private var stopped = false
+
   // The transaction that this processor would handle
   var txOpt: Option[Transaction] = None
 
@@ -88,6 +90,11 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
     batchAckLatch.countDown()
   }
 
+  private[flume] def shutdown(): Unit = {
+    logDebug("Shutting down transaction processor")
+    stopped = true
+  }
+
   /**
    * Populates events into the event batch. If the batch cannot be populated,
    * this method will not set the events into the event batch, but it sets an error message.
@@ -106,7 +113,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
         var gotEventsInThisTxn = false
         var loopCounter: Int = 0
         loop.breakable {
-          while (events.size() < maxBatchSize
+          while (!stopped && events.size() < maxBatchSize
             && loopCounter < totalAttemptsToRemoveFromChannel) {
             loopCounter += 1
             Option(channel.take()) match {
@@ -115,7 +122,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
                   ByteBuffer.wrap(event.getBody)))
                 gotEventsInThisTxn = true
               case None =>
-                if (!gotEventsInThisTxn) {
+                if (!gotEventsInThisTxn && !stopped) {
                   logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" +
                     " the current transaction")
                   TimeUnit.MILLISECONDS.sleep(backOffInterval)
@@ -125,7 +132,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
             }
           }
         }
-        if (!gotEventsInThisTxn) {
+        if (!gotEventsInThisTxn && !stopped) {
           val msg = "Tried several times, " +
             "but did not get any events from the channel!"
           logWarning(msg)
@@ -136,6 +143,11 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
         }
       })
     } catch {
+      case interrupted: InterruptedException =>
+        // Don't pollute logs if the InterruptedException came from this being stopped
+        if (!stopped) {
+          logWarning("Error while processing transaction.", interrupted)
+        }
       case e: Exception =>
         logWarning("Error while processing transaction.", e)
         eventBatch.setErrorMsg(e.getMessage)
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala
new file mode 100644
index 0000000000000000000000000000000000000000..88cc2aa3bf0220116e52d6d7f7eff808e492f0cc
--- /dev/null
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.flume
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.flume.sink._
+
+/**
+ * This class implements the core functionality of [[FlumePollingReceiver]]. When started it
+ * pulls data from Flume, stores it to Spark and then sends an Ack or Nack. This class should be
+ * run via an [[java.util.concurrent.Executor]] as this implements [[Runnable]]
+ *
+ * @param receiver The receiver that owns this instance.
+ */
+
+private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends Runnable with
+  Logging {
+
+  def run(): Unit = {
+    while (!receiver.isStopped()) {
+      val connection = receiver.getConnections.poll()
+      val client = connection.client
+      var batchReceived = false
+      var seq: CharSequence = null
+      try {
+        getBatch(client) match {
+          case Some(eventBatch) =>
+            batchReceived = true
+            seq = eventBatch.getSequenceNumber
+            val events = toSparkFlumeEvents(eventBatch.getEvents)
+            if (store(events)) {
+              sendAck(client, seq)
+            } else {
+              sendNack(batchReceived, client, seq)
+            }
+          case None =>
+        }
+      } catch {
+        case e: Exception =>
+          Throwables.getRootCause(e) match {
+            // If the cause was an InterruptedException, then check if the receiver is stopped -
+            // if yes, just break out of the loop. Else send a Nack and log a warning.
+            // In the unlikely case, the cause was not an Exception,
+            // then just throw it out and exit.
+            case interrupted: InterruptedException =>
+              if (!receiver.isStopped()) {
+                logWarning("Interrupted while receiving data from Flume", interrupted)
+                sendNack(batchReceived, client, seq)
+              }
+            case exception: Exception =>
+              logWarning("Error while receiving data from Flume", exception)
+              sendNack(batchReceived, client, seq)
+          }
+      } finally {
+        receiver.getConnections.add(connection)
+      }
+    }
+  }
+
+  /**
+   * Gets a batch of events from the specified client. This method does not handle any exceptions
+   * which will be propogated to the caller.
+   * @param client Client to get events from
+   * @return [[Some]] which contains the event batch if Flume sent any events back, else [[None]]
+   */
+  private def getBatch(client: SparkFlumeProtocol.Callback): Option[EventBatch] = {
+    val eventBatch = client.getEventBatch(receiver.getMaxBatchSize)
+    if (!SparkSinkUtils.isErrorBatch(eventBatch)) {
+      // No error, proceed with processing data
+      logDebug(s"Received batch of ${eventBatch.getEvents.size} events with sequence " +
+        s"number: ${eventBatch.getSequenceNumber}")
+      Some(eventBatch)
+    } else {
+      logWarning("Did not receive events from Flume agent due to error on the Flume agent: " +
+        eventBatch.getErrorMsg)
+      None
+    }
+  }
+
+  /**
+   * Store the events in the buffer to Spark. This method will not propogate any exceptions,
+   * but will propogate any other errors.
+   * @param buffer The buffer to store
+   * @return true if the data was stored without any exception being thrown, else false
+   */
+  private def store(buffer: ArrayBuffer[SparkFlumeEvent]): Boolean = {
+    try {
+      receiver.store(buffer)
+      true
+    } catch {
+      case e: Exception =>
+        logWarning("Error while attempting to store data received from Flume", e)
+        false
+    }
+  }
+
+  /**
+   * Send an ack to the client for the sequence number. This method does not handle any exceptions
+   * which will be propagated to the caller.
+   * @param client client to send the ack to
+   * @param seq sequence number of the batch to be ack-ed.
+   * @return
+   */
+  private def sendAck(client: SparkFlumeProtocol.Callback, seq: CharSequence): Unit = {
+    logDebug("Sending ack for sequence number: " + seq)
+    client.ack(seq)
+    logDebug("Ack sent for sequence number: " + seq)
+  }
+
+  /**
+   * This method sends a Nack if a batch was received to the client with the given sequence
+   * number. Any exceptions thrown by the RPC call is simply thrown out as is - no effort is made
+   * to handle it.
+   * @param batchReceived true if a batch was received. If this is false, no nack is sent
+   * @param client The client to which the nack should be sent
+   * @param seq The sequence number of the batch that is being nack-ed.
+   */
+  private def sendNack(batchReceived: Boolean, client: SparkFlumeProtocol.Callback,
+    seq: CharSequence): Unit = {
+    if (batchReceived) {
+      // Let Flume know that the events need to be pushed back into the channel.
+      logDebug("Sending nack for sequence number: " + seq)
+      client.nack(seq) // If the agent is down, even this could fail and throw
+      logDebug("Nack sent for sequence number: " + seq)
+    }
+  }
+
+  /**
+   * Utility method to convert [[SparkSinkEvent]]s to [[SparkFlumeEvent]]s
+   * @param events - Events to convert to SparkFlumeEvents
+   * @return - The SparkFlumeEvent generated from SparkSinkEvent
+   */
+  private def toSparkFlumeEvents(events: java.util.List[SparkSinkEvent]):
+    ArrayBuffer[SparkFlumeEvent] = {
+    // Convert each Flume event to a serializable SparkFlumeEvent
+    val buffer = new ArrayBuffer[SparkFlumeEvent](events.size())
+    var j = 0
+    while (j < events.size()) {
+      val event = events(j)
+      val sparkFlumeEvent = new SparkFlumeEvent()
+      sparkFlumeEvent.event.setBody(event.getBody)
+      sparkFlumeEvent.event.setHeaders(event.getHeaders)
+      buffer += sparkFlumeEvent
+      j += 1
+    }
+    buffer
+  }
+}
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
index 148262bb6771e0b2ab5feee281627e3417f2826c..92fa5b41be89e7d25de26b1604fd905e779d400e 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
@@ -18,10 +18,9 @@ package org.apache.spark.streaming.flume
 
 
 import java.net.InetSocketAddress
-import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, Executors}
+import java.util.concurrent.{LinkedBlockingQueue, Executors}
 
 import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder
@@ -86,61 +85,9 @@ private[streaming] class FlumePollingReceiver(
       connections.add(new FlumeConnection(transceiver, client))
     })
     for (i <- 0 until parallelism) {
-      logInfo("Starting Flume Polling Receiver worker threads starting..")
+      logInfo("Starting Flume Polling Receiver worker threads..")
       // Threads that pull data from Flume.
-      receiverExecutor.submit(new Runnable {
-        override def run(): Unit = {
-          while (true) {
-            val connection = connections.poll()
-            val client = connection.client
-            try {
-              val eventBatch = client.getEventBatch(maxBatchSize)
-              if (!SparkSinkUtils.isErrorBatch(eventBatch)) {
-                // No error, proceed with processing data
-                val seq = eventBatch.getSequenceNumber
-                val events: java.util.List[SparkSinkEvent] = eventBatch.getEvents
-                logDebug(
-                  "Received batch of " + events.size() + " events with sequence number: " + seq)
-                try {
-                  // Convert each Flume event to a serializable SparkFlumeEvent
-                  val buffer = new ArrayBuffer[SparkFlumeEvent](events.size())
-                  var j = 0
-                  while (j < events.size()) {
-                    buffer += toSparkFlumeEvent(events(j))
-                    j += 1
-                  }
-                  store(buffer)
-                  logDebug("Sending ack for sequence number: " + seq)
-                  // Send an ack to Flume so that Flume discards the events from its channels.
-                  client.ack(seq)
-                  logDebug("Ack sent for sequence number: " + seq)
-                } catch {
-                  case e: Exception =>
-                    try {
-                      // Let Flume know that the events need to be pushed back into the channel.
-                      logDebug("Sending nack for sequence number: " + seq)
-                      client.nack(seq) // If the agent is down, even this could fail and throw
-                      logDebug("Nack sent for sequence number: " + seq)
-                    } catch {
-                      case e: Exception => logError(
-                        "Sending Nack also failed. A Flume agent is down.")
-                    }
-                    TimeUnit.SECONDS.sleep(2L) // for now just leave this as a fixed 2 seconds.
-                    logWarning("Error while attempting to store events", e)
-                }
-              } else {
-                logWarning("Did not receive events from Flume agent due to error on the Flume " +
-                  "agent: " + eventBatch.getErrorMsg)
-              }
-            } catch {
-              case e: Exception =>
-                logWarning("Error while reading data from Flume", e)
-            } finally {
-              connections.add(connection)
-            }
-          }
-        }
-      })
+      receiverExecutor.submit(new FlumeBatchFetcher(this))
     }
   }
 
@@ -153,16 +100,12 @@ private[streaming] class FlumePollingReceiver(
     channelFactory.releaseExternalResources()
   }
 
-  /**
-   * Utility method to convert [[SparkSinkEvent]] to [[SparkFlumeEvent]]
-   * @param event - Event to convert to SparkFlumeEvent
-   * @return - The SparkFlumeEvent generated from SparkSinkEvent
-   */
-  private def toSparkFlumeEvent(event: SparkSinkEvent): SparkFlumeEvent = {
-    val sparkFlumeEvent = new SparkFlumeEvent()
-    sparkFlumeEvent.event.setBody(event.getBody)
-    sparkFlumeEvent.event.setHeaders(event.getHeaders)
-    sparkFlumeEvent
+  private[flume] def getConnections: LinkedBlockingQueue[FlumeConnection] = {
+    this.connections
+  }
+
+  private[flume] def getMaxBatchSize: Int = {
+    this.maxBatchSize
   }
 }
 
@@ -171,7 +114,7 @@ private[streaming] class FlumePollingReceiver(
  * @param transceiver The transceiver to use for communication with Flume
  * @param client The client that the callbacks are received on.
  */
-private class FlumeConnection(val transceiver: NettyTransceiver,
+private[flume] class FlumeConnection(val transceiver: NettyTransceiver,
   val client: SparkFlumeProtocol.Callback)