diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 7158abc08894ae88fb44459bece2ecef3311bc0a..b2cd524f28b74f963add937f37d783aed65144b6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -166,10 +166,13 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
       var segment: WriteAheadLogRecordHandle = null
       if (buffer.length > 0) {
         logDebug(s"Batched ${buffer.length} records for Write Ahead Log write")
+        // multiple threads enqueue records concurrently, so the buffer may not be ordered by time
+        val sortedByTime = buffer.sortBy(_.time)
-        // We take the latest record for the timestamp. Please refer to the class Javadoc for
-        // detailed explanation
+        // We take the latest record's time as the batch timestamp. Please refer to the class
+        // Javadoc for a detailed explanation.
-        val time = buffer.last.time
-        segment = wrappedLog.write(aggregate(buffer), time)
+        val time = sortedByTime.last.time
+        segment = wrappedLog.write(aggregate(sortedByTime), time)
       }
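+      // complete every waiting writer with the handle of the single aggregated write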
       buffer.foreach(_.promise.success(segment))
     } catch {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index eaa88ea3cd380f8e32889393c0d7b307e67af85d..ef1e89df313050b4f742fc2429f764a3d76e3db8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -480,7 +480,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
     p
   }
 
-  test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") {
+  test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") {
     val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
     val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
 
@@ -500,8 +500,17 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
-    // rest of the records will be batched while it takes time for 3 to get written
+    // the rest of the records will be batched while the write for timestamp 3 stays blocked
     writeAsync(batchedWal, event2, 5L)
     writeAsync(batchedWal, event3, 8L)
-    writeAsync(batchedWal, event4, 12L)
-    writeAsync(batchedWal, event5, 10L)
+    // write event5 (timestamp 12) before event4 (timestamp 10) so that the queue is out of
+    // timestamp order, exercising the sort that runs before aggregation
+    writeAsync(batchedWal, event5, 12L)
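+    // writeAsync runs on a separate thread pool, so wait until event5 is actually queued
+    // before submitting event4; otherwise the two submissions could race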
+    eventually(timeout(1 second)) {
+      assert(blockingWal.isBlocked)
+      assert(batchedWal.invokePrivate(queueLength()) === 3)
+    }
+    writeAsync(batchedWal, event4, 10L)
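+    // the first write is still blocked, so all four subsequent writes should now be queued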
     eventually(timeout(1 second)) {
       assert(walBatchingThreadPool.getActiveCount === 5)
       assert(batchedWal.invokePrivate(queueLength()) === 4)
@@ -517,7 +526,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
-      // the file name should be the timestamp of the last record, as events should be naturally
-      // in order of timestamp, and we need the last element.
+      // the file name should be the highest timestamp among the aggregated records; events may
+      // arrive out of order, so the batch is sorted by time before being written
       val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
-      verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L))
+      verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L))
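+      // compare as sets: sorting may reorder the records in the batch, but none may be dropped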
       val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString)
       assert(records.toSet === queuedEvents)
     }