diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index c919cdb4cd65f19b547a3acd9d2e4d1a0435fba9..e0dd4c9f0e2dc5ce946a604e38be8de7485e3a21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -24,6 +24,12 @@ import org.apache.spark.unsafe.types.CalendarInterval
 object EventTimeWatermark {
   /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */
   val delayKey = "spark.watermarkDelayMs"
+
+  def getDelayMs(delay: CalendarInterval): Long = {
+    // We define month as `31 days` to simplify calculation.
+    val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
+    delay.milliseconds + delay.months * millisPerMonth
+  }
 }
 
 /**
@@ -37,9 +43,10 @@ case class EventTimeWatermark(
   // Update the metadata on the eventTime column to include the desired delay.
   override val output: Seq[Attribute] = child.output.map { a =>
     if (a semanticEquals eventTime) {
+      val delayMs = EventTimeWatermark.getDelayMs(delay)
       val updatedMetadata = new MetadataBuilder()
         .withMetadata(a.metadata)
-        .putLong(EventTimeWatermark.delayKey, delay.milliseconds)
+        .putLong(EventTimeWatermark.delayKey, delayMs)
         .build()
       a.withMetadata(updatedMetadata)
     } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 5a9a99e11188e1cf051f9077b6e8aa4f681c3a67..25cf609fc336ef619a4b3c16929f904533164105 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -84,10 +84,7 @@ case class EventTimeWatermarkExec(
     child: SparkPlan) extends SparkPlan {
 
   val eventTimeStats = new EventTimeStatsAccum()
-  val delayMs = {
-    val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
-    delay.milliseconds + delay.months * millisPerMonth
-  }
+  val delayMs = EventTimeWatermark.getDelayMs(delay)
 
   sparkContext.register(eventTimeStats)
 
@@ -105,10 +102,16 @@ case class EventTimeWatermarkExec(
   override val output: Seq[Attribute] = child.output.map { a =>
     if (a semanticEquals eventTime) {
       val updatedMetadata = new MetadataBuilder()
-          .withMetadata(a.metadata)
-          .putLong(EventTimeWatermark.delayKey, delayMs)
-          .build()
-
+        .withMetadata(a.metadata)
+        .putLong(EventTimeWatermark.delayKey, delayMs)
+        .build()
+      a.withMetadata(updatedMetadata)
+    } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
+      // Remove existing watermark
+      val updatedMetadata = new MetadataBuilder()
+        .withMetadata(a.metadata)
+        .remove(EventTimeWatermark.delayKey)
+        .build()
       a.withMetadata(updatedMetadata)
     } else {
       a