diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 5711262654a165078149a426d737853e4e8b12a9..1528e7f469bef5bc2a115ed59096b6d50132b380 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -628,7 +628,8 @@ class StreamExecution(
     // Rewire the plan to use the new attributes that were returned by the source.
     val replacementMap = AttributeMap(replacements)
     val triggerLogicalPlan = withNewSources transformAllExpressions {
-      case a: Attribute if replacementMap.contains(a) => replacementMap(a)
+      case a: Attribute if replacementMap.contains(a) =>
+        replacementMap(a).withMetadata(a.metadata)
       case ct: CurrentTimestamp =>
         CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
           ct.dataType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 552911f32ee439c1b837370645d1da59d240b839..4f19fa0bb4a97e34748f78bef079a9a9bc59158e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -391,6 +391,34 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     checkDataset[Long](df, 1L to 100L: _*)
   }
 
+  test("SPARK-21565: watermark operator accepts attributes from replacement") {
+    withTempDir { dir =>
+      dir.delete()
+
+      val df = Seq(("a", 100.0, new java.sql.Timestamp(100L)))
+        .toDF("symbol", "price", "eventTime")
+      df.write.json(dir.getCanonicalPath)
+
+      val input = spark.readStream.schema(df.schema)
+        .json(dir.getCanonicalPath)
+
+      val groupEvents = input
+        .withWatermark("eventTime", "2 seconds")
+        .groupBy("symbol", "eventTime")
+        .agg(count("price") as 'count)
+        .select("symbol", "eventTime", "count")
+      val q = groupEvents.writeStream
+        .outputMode("append")
+        .format("console")
+        .start()
+      try {
+        q.processAllAvailable()
+      } finally {
+        q.stop()
+      }
+    }
+  }
+
   private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
     val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
     assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)