diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index e42df5dd61c702a799534c337be008d5968bcde9..5e79232a2043b76586d89dab33a84f7787a5db89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -120,7 +120,9 @@ case class FlatMapGroupsWithStateExec(
         val filteredIter = watermarkPredicateForData match {
           case Some(predicate) if timeoutConf == EventTimeTimeout =>
             iter.filter(row => !predicate.eval(row))
-          case None =>
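+          // A watermark may be defined even when the timeout is not event-time based
+          // (SPARK-20714); in that case, pass the rows through unfiltered.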
+          case _ =>
             iter
         }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 85aa7dbe9ed86215884751699e84d2bbdb4ffd4c..89cfba6c559d6480ead8b79aae46208d988ff8b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -589,7 +589,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     )
   }
 
-  test("flatMapGroupsWithState - streaming with event time timeout") {
+  test("flatMapGroupsWithState - streaming with event time timeout + watermark") {
     // Function to maintain the max event time
     // Returns the max event time in the state, or -1 if the state was removed by timeout
     val stateFunc = (
@@ -761,6 +761,46 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     assert(e.getMessage === "The output mode of function should be append or update")
   }
 
+  def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
+    test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) {
+      // Function to maintain a running count for each key, setting a timeout when enabled.
+      // Returns (key, count), or (key, "-1") if the state was just removed due to a timeout.
+      val stateFunc =
+        (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
+          if (state.hasTimedOut) {
+            state.remove()
+            Iterator((key, "-1"))
+          } else {
+            val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+            state.update(RunningCount(count))
+            if (timeoutConf == ProcessingTimeTimeout) state.setTimeoutDuration("10 seconds")
+            Iterator((key, count.toString))
+          }
+        }
+
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[(String, Long)]
+      val result =
+        inputData.toDF().toDF("key", "time")
+          .selectExpr("key", "cast(time as timestamp) as timestamp")
+          .withWatermark("timestamp", "10 second")
+          .as[(String, Long)]
+          .groupByKey(x => x._1)
+          .flatMapGroupsWithState(Update, timeoutConf)(stateFunc)
+
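+      // Before the fix in FlatMapGroupsWithStateExec, this query failed with a MatchError
+      // because the watermark predicate was defined while the timeout was not EventTimeTimeout.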
+      testStream(result, Update)(
+        StartStream(ProcessingTime("1 second"), triggerClock = clock),
+        AddData(inputData, ("a", 1L)),
+        AdvanceManualClock(1 * 1000),
+        CheckLastBatch(("a", "1"))
+      )
+    }
+  }
+  testWithTimeout(NoTimeout)
+  testWithTimeout(ProcessingTimeTimeout)
+
   def testStateUpdateWithData(
       testName: String,
       stateUpdates: GroupState[Int] => Unit,