From 7863c623791d088684107f833fdecb4b5fdab4ec Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Mon, 5 Dec 2016 20:35:24 -0800
Subject: [PATCH] [SPARK-18721][SS] Fix ForeachSink with watermark + append

## What changes were proposed in this pull request?

Right now ForeachSink creates a new physical plan, so StreamExecution cannot retrieval metrics and watermark.

This PR changes ForeachSink to manually convert InternalRows to objects without creating a new plan.

## How was this patch tested?

`test("foreach with watermark: append")`.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16160 from zsxwing/SPARK-18721.
---
 .../sql/execution/streaming/ForeachSink.scala | 45 ++++--------
 .../streaming/ForeachSinkSuite.scala          | 68 ++++++++++++++++++-
 2 files changed, 79 insertions(+), 34 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
index c93fcfb77c..de09fb568d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -18,9 +18,8 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 
 /**
  * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
@@ -32,46 +31,26 @@ import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
 class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    // TODO: Refine this method when SPARK-16264 is resolved; see comments below.
-
     // This logic should've been as simple as:
     // ```
     //   data.as[T].foreachPartition { iter => ... }
     // ```
     //
     // Unfortunately, doing that would just break the incremental planing. The reason is,
-    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just
-    // does not support `IncrementalExecution`.
+    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
+    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
+    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
+    // updated in the new plan, and StreamExecution cannot retrieval them.
     //
-    // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()`
-    // method supporting incremental planning. But in the long run, we should generally make newly
-    // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to
-    // resolve).
-    val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution]
-    val datasetWithIncrementalExecution =
-      new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) {
-        override lazy val rdd: RDD[T] = {
-          val objectType = exprEnc.deserializer.dataType
-          val deserialized = CatalystSerde.deserialize[T](logicalPlan)
-
-          // was originally: sparkSession.sessionState.executePlan(deserialized) ...
-          val newIncrementalExecution = new IncrementalExecution(
-            this.sparkSession,
-            deserialized,
-            incrementalExecution.outputMode,
-            incrementalExecution.checkpointLocation,
-            incrementalExecution.currentBatchId,
-            incrementalExecution.currentEventTimeWatermark)
-          newIncrementalExecution.toRdd.mapPartitions { rows =>
-            rows.map(_.get(0, objectType))
-          }.asInstanceOf[RDD[T]]
-        }
-      }
-    datasetWithIncrementalExecution.foreachPartition { iter =>
+    // Hence, we need to manually convert internal rows to objects using encoder.
+    val encoder = encoderFor[T].resolveAndBind(
+      data.logicalPlan.output,
+      data.sparkSession.sessionState.analyzer)
+    data.queryExecution.toRdd.foreachPartition { iter =>
       if (writer.open(TaskContext.getPartitionId(), batchId)) {
         try {
           while (iter.hasNext) {
-            writer.process(iter.next())
+            writer.process(encoder.fromRow(iter.next()))
           }
         } catch {
           case e: Throwable =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index ee6261036f..4a3eeb70b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -171,7 +171,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
     }
   }
 
-  test("foreach with watermark") {
+  test("foreach with watermark: complete") {
     val inputData = MemoryStream[Int]
 
     val windowedAggregation = inputData.toDF()
@@ -204,6 +204,72 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
       query.stop()
     }
   }
+
+  test("foreach with watermark: append") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Append)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+      inputData.addData(25) // Advance watermark to 15 seconds
+      query.processAllAvailable()
+      inputData.addData(25) // Evict items less than previous watermark
+      query.processAllAvailable()
+
+      // There should be 3 batches and only does the last batch contain a value.
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 3)
+      val expectedEvents = Seq(
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 0),
+          ForeachSinkSuite.Close(None)
+        ),
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 1),
+          ForeachSinkSuite.Close(None)
+        ),
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 2),
+          ForeachSinkSuite.Process(value = 3),
+          ForeachSinkSuite.Close(None)
+        )
+      )
+      assert(allEvents === expectedEvents)
+    } finally {
+      query.stop()
+    }
+  }
+
+  test("foreach sink should support metrics") {
+    val inputData = MemoryStream[Int]
+    val query = inputData.toDS()
+      .writeStream
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+      val recentProgress = query.recentProgresses.filter(_.numInputRows != 0).headOption
+      assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3,
+        s"recentProgresses[${query.recentProgresses.toList}] doesn't contain correct metrics")
+    } finally {
+      query.stop()
+    }
+  }
 }
 
 /** A global object to collect events in the executor */
-- 
GitLab