Commit 2a40de40 authored by Shixiong Zhu, committed by Tathagata Das

[SPARK-18497][SS] Make ForeachSink support watermark

## What changes were proposed in this pull request?

The issue in ForeachSink is that the newly created Dataset still uses the old QueryExecution. When `foreachPartition` is called, `QueryExecution.toString` is invoked and fails because it does not know how to plan EventTimeWatermark.

This PR fixes the issue by replacing the QueryExecution with an IncrementalExecution.
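For context, here is a minimal end-to-end sketch of the kind of query this fix targets: a watermarked windowed aggregation written to the foreach sink, mirroring the regression test added below. The session setup, the `PrintingForeachWriter` class, and the example values are illustrative and not part of this change; `MemoryStream` is the internal streaming test source also used by the suite.

```scala
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.{count, window}

// Illustrative writer (hypothetical name): prints each aggregated count on the executors.
class PrintingForeachWriter extends ForeachWriter[Long] {
  override def open(partitionId: Long, version: Long): Boolean = true
  override def process(value: Long): Unit = println(s"count = $value")
  override def close(errorOrNull: Throwable): Unit = ()
}

object ForeachWithWatermarkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-with-watermark")
      .getOrCreate()
    import spark.implicits._
    implicit val sqlContext = spark.sqlContext

    // A watermarked windowed aggregation: its logical plan contains
    // EventTimeWatermark, the operator the old code path could not re-plan.
    val source = MemoryStream[Int]
    val counts = source.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds"))
      .agg(count("*").as("count"))
      .select($"count".as[Long])

    // Write the aggregation to the foreach sink, as the regression test does.
    val query = counts.writeStream
      .outputMode("complete")
      .foreach(new PrintingForeachWriter)
      .start()

    source.addData(10, 11, 12)
    query.processAllAvailable()
    query.stop()
    spark.stop()
  }
}
```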

## How was this patch tested?

`test("foreach with watermark")`.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #15934 from zsxwing/SPARK-18497.
parent 6f7ff750
ForeachSink.scala
@@ -47,22 +47,22 @@ class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Seria
     // method supporting incremental planning. But in the long run, we should generally make newly
     // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to
     // resolve).
+    val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution]
     val datasetWithIncrementalExecution =
-      new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) {
+      new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) {
         override lazy val rdd: RDD[T] = {
           val objectType = exprEnc.deserializer.dataType
           val deserialized = CatalystSerde.deserialize[T](logicalPlan)
           // was originally: sparkSession.sessionState.executePlan(deserialized) ...
-          val incrementalExecution = new IncrementalExecution(
+          val newIncrementalExecution = new IncrementalExecution(
             this.sparkSession,
             deserialized,
-            data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
-            data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
-            data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId,
-            data.queryExecution.asInstanceOf[IncrementalExecution].currentEventTimeWatermark)
-          incrementalExecution.toRdd.mapPartitions { rows =>
+            incrementalExecution.outputMode,
+            incrementalExecution.checkpointLocation,
+            incrementalExecution.currentBatchId,
+            incrementalExecution.currentEventTimeWatermark)
+          newIncrementalExecution.toRdd.mapPartitions { rows =>
             rows.map(_.get(0, objectType))
           }.asInstanceOf[RDD[T]]
         }
ForeachSinkSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.SparkException
 import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
@@ -169,6 +170,40 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
       assert(errorEvent.error.get.getMessage === "error")
     }
   }
+
+  test("foreach with watermark") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Complete)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      val expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+    } finally {
+      query.stop()
+    }
+  }
 }

 /** A global object to collect events in the executor */