diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index ce68c0968fb60b2a507125b07b02a6fbde06940e..31083469138002db28d2bcc62644697218b48b57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.locks.ReentrantLock
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -53,8 +54,12 @@ class StreamExecution(
     val trigger: Trigger)
   extends ContinuousQuery with Logging {
 
-  /** An monitor used to wait/notify when batches complete. */
-  private val awaitBatchLock = new Object
+  /**
+   * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation.
+   */
+  private val awaitBatchLock = new ReentrantLock(true)
+  private val awaitBatchLockCondition = awaitBatchLock.newCondition()
+
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
@@ -242,17 +247,22 @@ class StreamExecution(
     // method. See SPARK-14131.
     //
     // Check to see what new data is available.
-    val hasNewData = awaitBatchLock.synchronized {
-      val newData = microBatchThread.runUninterruptibly {
-        uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
-      }
-      availableOffsets ++= newData
+    val hasNewData = {
+      awaitBatchLock.lock()
+      try {
+        val newData = microBatchThread.runUninterruptibly {
+          uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
+        }
+        availableOffsets ++= newData
 
-      if (dataAvailable) {
-        true
-      } else {
-        noNewData = true
-        false
+        if (dataAvailable) {
+          true
+        } else {
+          noNewData = true
+          false
+        }
+      } finally {
+        awaitBatchLock.unlock()
       }
     }
     if (hasNewData) {
@@ -269,9 +279,12 @@ class StreamExecution(
       currentBatchId += 1
       logInfo(s"Committed offsets for batch $currentBatchId.")
     } else {
-      awaitBatchLock.synchronized {
+      awaitBatchLock.lock()
+      try {
         // Wake up any threads that are waiting for the stream to progress.
-        awaitBatchLock.notifyAll()
+        awaitBatchLockCondition.signalAll()
+      } finally {
+        awaitBatchLock.unlock()
       }
     }
   }
@@ -332,9 +345,12 @@ class StreamExecution(
       new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema))
     sink.addBatch(currentBatchId - 1, nextBatch)
 
-    awaitBatchLock.synchronized {
+    awaitBatchLock.lock()
+    try {
       // Wake up any threads that are waiting for the stream to progress.
-      awaitBatchLock.notifyAll()
+      awaitBatchLockCondition.signalAll()
+    } finally {
+      awaitBatchLock.unlock()
     }
 
     val batchTime = (System.nanoTime() - startTime).toDouble / 1000000
@@ -374,8 +390,12 @@ class StreamExecution(
     }
 
     while (notDone) {
-      logInfo(s"Waiting until $newOffset at $source")
-      awaitBatchLock.synchronized { awaitBatchLock.wait(100) }
+      awaitBatchLock.lock()
+      try {
+        awaitBatchLockCondition.await(100, TimeUnit.MILLISECONDS)
+      } finally {
+        awaitBatchLock.unlock()
+      }
     }
     logDebug(s"Unblocked at $newOffset for $source")
   }
@@ -383,16 +403,21 @@ class StreamExecution(
   /** A flag to indicate that a batch has completed with no new data available. */
   @volatile private var noNewData = false
 
-  override def processAllAvailable(): Unit = awaitBatchLock.synchronized {
-    noNewData = false
-    while (true) {
-      awaitBatchLock.wait(10000)
-      if (streamDeathCause != null) {
-        throw streamDeathCause
-      }
-      if (noNewData) {
-        return
+  override def processAllAvailable(): Unit = {
+    awaitBatchLock.lock()
+    try {
+      noNewData = false
+      while (true) {
+        awaitBatchLockCondition.await(10000, TimeUnit.MILLISECONDS)
+        if (streamDeathCause != null) {
+          throw streamDeathCause
+        }
+        if (noNewData) {
+          return
+        }
       }
+    } finally {
+      awaitBatchLock.unlock()
     }
   }