diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 351ef404a8e39d15dea3d561edb7f2d64a6cd271..3820968324bfe807db8e575284a44cd419f88658 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected val encoder = encoderFor[A]
   protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
+
+  @GuardedBy("this")
   protected val batches = new ArrayBuffer[Dataset[A]]
 
+  @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   def schema: StructType = encoder.schema
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def addData(data: TraversableOnce[A]): Offset = {
     import sqlContext.implicits._
+    val ds = data.toVector.toDS()
+    logDebug(s"Adding ds: $ds")
     this.synchronized {
       currentOffset = currentOffset + 1
-      val ds = data.toVector.toDS()
-      logDebug(s"Adding ds: $ds")
       batches.append(ds)
       currentOffset
     }
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   override def toString: String = s"MemoryStream[${output.mkString(",")}]"
 
-  override def getOffset: Option[Offset] = if (batches.isEmpty) {
-    None
-  } else {
-    Some(currentOffset)
+  override def getOffset: Option[Offset] = synchronized {
+    if (batches.isEmpty) {
+      None
+    } else {
+      Some(currentOffset)
+    }
   }
 
   /**
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     val startOrdinal =
       start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
     val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
-    val newBlocks = batches.slice(startOrdinal, endOrdinal)
+    val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
 
     logDebug(
       s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
@@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  */
 class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An ordered list of batches that have been written to this [[Sink]]. */
+  @GuardedBy("this")
   private val batches = new ArrayBuffer[Array[Row]]()
 
   /** Returns all rows that are stored in this [[Sink]]. */
@@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
-  def lastBatch: Seq[Row] = batches.last
+  def lastBatch: Seq[Row] = synchronized { batches.last }
 
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
@@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     }.mkString("\n")
   }
 
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
     if (batchId == batches.size) {
       logDebug(s"Committing batch $batchId")
       batches.append(data.collect())
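
Note on the pattern applied above: every mutable field is annotated with @GuardedBy("this") and is only touched inside this.synchronized, so that addData (called from the test/driver thread) cannot race with getOffset/getBatch on the stream, or addBatch/lastBatch on the sink, which run on the stream execution thread. The standalone sketch below is not part of the patch; the class and method names are hypothetical, and it assumes the jcip/JSR-305 annotations jar (a Spark dependency) is on the classpath. It only illustrates the same locking discipline in isolation.

import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer

// Hypothetical, self-contained illustration of the @GuardedBy / synchronized
// discipline used in MemoryStream and MemorySink above: all reads and writes
// of the mutable state take the same monitor (`this`).
class GuardedBatches[A] {

  @GuardedBy("this")
  private val batches = new ArrayBuffer[A]

  @GuardedBy("this")
  private var currentOffset: Long = -1L

  // Writer side (analogous to addData / addBatch): the shared buffer and the
  // offset are only mutated while holding the lock.
  def add(value: A): Long = synchronized {
    currentOffset += 1
    batches.append(value)
    currentOffset
  }

  // Reader side (analogous to getOffset): the emptiness check and the offset
  // read happen under the same lock, so they are mutually consistent.
  def latestOffset: Option[Long] = synchronized {
    if (batches.isEmpty) None else Some(currentOffset)
  }

  // Analogous to getBatch: only the slice is taken under the lock; the caller
  // can process the returned elements without holding it.
  def batchesBetween(startExclusive: Long, endInclusive: Long): Seq[A] =
    synchronized {
      batches.slice(startExclusive.toInt + 1, endInclusive.toInt + 1).toList
    }
}

In the patch itself the Dataset is built and logged before entering the synchronized block (the `toVector.toDS()` lines moved out of it), so the lock is held only for the cheap buffer append and offset increment.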