diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1dd8818dedb2ef64988d99dc3d63fdcb6c54eaba..32e2fdc3f970762bac9f5f1d3d7d1a3f9d1e9ec9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
 import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
  *
  * @since 1.4.0
  */
-final class DataFrameWriter private[sql](df: DataFrame) {
+final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
+
+  private val df = ds.toDF()
 
   /**
    * Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: SaveMode): DataFrameWriter = {
+  def mode(saveMode: SaveMode): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: String): DataFrameWriter = {
+  def mode(saveMode: String): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: OutputMode): DataFrameWriter = {
+  def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode
     this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: String): DataFrameWriter = {
+  def outputMode(outputMode: String): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode.toLowerCase match {
       case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def trigger(trigger: Trigger): DataFrameWriter = {
+  def trigger(trigger: Trigger): DataFrameWriter[T] = {
     assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def format(source: String): DataFrameWriter = {
+  def format(source: String): DataFrameWriter[T] = {
     this.source = source
     this
   }
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def option(key: String, value: String): DataFrameWriter = {
+  def option(key: String, value: String): DataFrameWriter[T] = {
     this.extraOptions += (key -> value)
     this
   }
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * (Scala-specific) Adds output options for the underlying data source.
    *
    * @since 1.4.0
    */
-  def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
+  def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
     this.extraOptions ++= options
     this
   }
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def options(options: java.util.Map[String, String]): DataFrameWriter = {
+  def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
     this.options(options.asScala)
     this
   }
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def partitionBy(colNames: String*): DataFrameWriter = {
+  def partitionBy(colNames: String*): DataFrameWriter[T] = {
     this.partitioningColumns = Option(colNames)
     this
   }
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
+  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
     this.numBuckets = Option(numBuckets)
     this.bucketColumnNames = Option(colName +: colNames)
     this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+  def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
     this.sortColumnNames = Option(colName +: colNames)
     this
   }
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def queryName(queryName: String): DataFrameWriter = {
+  def queryName(queryName: String): DataFrameWriter[T] = {
     assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       val queryName =
         extraOptions.getOrElse(
           "queryName", throw new AnalysisException("queryName must be specified for memory sink"))
-      val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
-        new Path(userSpecified).toUri.toString
-      }.orElse {
-        val checkpointConfig: Option[String] =
-          df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
-
-        checkpointConfig.map { location =>
-          new Path(location, queryName).toUri.toString
-        }
-      }.getOrElse {
+      val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
         Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
       }
 
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
           className = source,
           options = extraOptions.toMap,
           partitionColumns = normalizedParCols.getOrElse(Nil))
-
       val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
-      val checkpointLocation = extraOptions.get("checkpointLocation")
-        .orElse {
-          df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
-            new Path(l, queryName).toUri.toString
-          }
-        }.getOrElse {
-          throw new AnalysisException("checkpointLocation must be specified either " +
-            "through option() or SQLConf")
-        }
-
       df.sparkSession.sessionState.continuousQueryManager.startQuery(
         queryName,
-        checkpointLocation,
+        getCheckpointLocation(queryName, failIfNotSet = true).get,
         df,
         dataSource.createSink(outputMode),
         outputMode,
@@ -400,6 +382,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
   }
 
+  /**
+   * :: Experimental ::
+   * Starts the execution of the streaming query, which will continually send results to the given
+   * [[ForeachWriter]] as as new data arrives. The [[ForeachWriter]] can be used to send the data
+   * generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned The returned
+   * [[ContinuousQuery]] object can be used to interact with the stream.
+   *
+   * Scala example:
+   * {{{
+   *   datasetOfString.write.foreach(new ForeachWriter[String] {
+   *
+   *     def open(partitionId: Long, version: Long): Boolean = {
+   *       // open connection
+   *     }
+   *
+   *     def process(record: String) = {
+   *       // write string to connection
+   *     }
+   *
+   *     def close(errorOrNull: Throwable): Unit = {
+   *       // close the connection
+   *     }
+   *   })
+   * }}}
+   *
+   * Java example:
+   * {{{
+   *  datasetOfString.write().foreach(new ForeachWriter<String>() {
+   *
+   *    @Override
+   *    public boolean open(long partitionId, long version) {
+   *      // open connection
+   *    }
+   *
+   *    @Override
+   *    public void process(String value) {
+   *      // write string to connection
+   *    }
+   *
+   *    @Override
+   *    public void close(Throwable errorOrNull) {
+   *      // close the connection
+   *    }
+   *  });
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
+    assertNotBucketed("foreach")
+    assertStreaming(
+      "foreach() can only be called on streaming Datasets/DataFrames.")
+
+    val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+    val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
+    df.sparkSession.sessionState.continuousQueryManager.startQuery(
+      queryName,
+      getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
+        Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
+      },
+      df,
+      sink,
+      outputMode,
+      trigger)
+  }
+
+  /**
+   * Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
+   * location is not set, [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`, `None`
+   * will be returned if the checkpoint location is not set.
+   */
+  private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
+    val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+      new Path(userSpecified).toUri.toString
+    }.orElse {
+      df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
+        new Path(location, queryName).toUri.toString
+      }
+    }
+    if (failIfNotSet && checkpointLocation.isEmpty) {
+      throw new AnalysisException("checkpointLocation must be specified either " +
+        """through option("checkpointLocation", ...) or """ +
+        s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
+    }
+    checkpointLocation
+  }
+
   /**
    * Inserts the content of the [[DataFrame]] to the specified table. It requires that
    * the schema of the [[DataFrame]] is the same as the schema of the table.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 162524a9efc3abb6e8f8b2f008fc08ed0a40fa95..16bbf30a9437091e34bca727722811f5c1574713 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def write: DataFrameWriter = new DataFrameWriter(toDF())
+  def write: DataFrameWriter[T] = new DataFrameWriter[T](this)
 
   /**
    * Returns the content of the Dataset as a Dataset of JSON strings.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
new file mode 100644
index 0000000000000000000000000000000000000000..09f07426a6bfa453d64e51de240c19c70b102750
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.streaming.ContinuousQuery
+
+/**
+ * :: Experimental ::
+ * A class to consume data generated by a [[ContinuousQuery]]. Typically this is used to send the
+ * generated data to external systems. Each partition will use a new deserialized instance, so you
+ * usually should do all the initialization (e.g. opening a connection or initiating a transaction)
+ * in the `open` method.
+ *
+ * Scala example:
+ * {{{
+ *   datasetOfString.write.foreach(new ForeachWriter[String] {
+ *
+ *     def open(partitionId: Long, version: Long): Boolean = {
+ *       // open connection
+ *     }
+ *
+ *     def process(record: String) = {
+ *       // write string to connection
+ *     }
+ *
+ *     def close(errorOrNull: Throwable): Unit = {
+ *       // close the connection
+ *     }
+ *   })
+ * }}}
+ *
+ * Java example:
+ * {{{
+ *  datasetOfString.write().foreach(new ForeachWriter<String>() {
+ *
+ *    @Override
+ *    public boolean open(long partitionId, long version) {
+ *      // open connection
+ *    }
+ *
+ *    @Override
+ *    public void process(String value) {
+ *      // write string to connection
+ *    }
+ *
+ *    @Override
+ *    public void close(Throwable errorOrNull) {
+ *      // close the connection
+ *    }
+ *  });
+ * }}}
+ * @since 2.0.0
+ */
+@Experimental
+abstract class ForeachWriter[T] extends Serializable {
+
+  /**
+   * Called when starting to process one partition of new data in the executor. The `version` is
+   * for data deduplication when there are failures. When recovering from a failure, some data may
+   * be generated multiple times but they will always have the same version.
+   *
+   * If this method finds using the `partitionId` and `version` that this partition has already been
+   * processed, it can return `false` to skip the further data processing. However, `close` still
+   * will be called for cleaning up resources.
+   *
+   * @param partitionId the partition id.
+   * @param version a unique id for data deduplication.
+   * @return `true` if the corresponding partition and version id should be processed. `false`
+   *         indicates the partition should be skipped.
+   */
+  def open(partitionId: Long, version: Long): Boolean
+
+  /**
+   * Called to process the data in the executor side. This method will be called only when `open`
+   * returns `true`.
+   */
+  def process(value: T): Unit
+
+  /**
+   * Called when stopping to process one partition of new data in the executor side. This is
+   * guaranteed to be called either `open` returns `true` or `false`. However,
+   * `close` won't be called in the following cases:
+   *  - JVM crashes without throwing a `Throwable`
+   *  - `open` throws a `Throwable`.
+   *
+   * @param errorOrNull the error thrown during processing data or null if there was no error.
+   */
+  def close(errorOrNull: Throwable): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
new file mode 100644
index 0000000000000000000000000000000000000000..14b9b1cb09317d899f8e2857a7525c6a62b04481
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+
+/**
+ * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    data.as[T].foreachPartition { iter =>
+      if (writer.open(TaskContext.getPartitionId(), batchId)) {
+        var isFailed = false
+        try {
+          while (iter.hasNext) {
+            writer.process(iter.next())
+          }
+        } catch {
+          case e: Throwable =>
+            isFailed = true
+            writer.close(e)
+        }
+        if (!isFailed) {
+          writer.close(null)
+        }
+      } else {
+        writer.close(null)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e1fb3b947837bc28bcac5f9b163cf29ed22b3768
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  test("foreach") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(2).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter())
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val expectedEventsForPartition0 = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      val expectedEventsForPartition1 = Seq(
+        ForeachSinkSuite.Open(partition = 1, version = 0),
+        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 4),
+        ForeachSinkSuite.Close(None)
+      )
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 2)
+      assert {
+        allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
+          allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
+      }
+      query.stop()
+    }
+  }
+
+  test("foreach with error") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(1).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter() {
+          override def process(value: Int): Unit = {
+            super.process(value)
+            throw new RuntimeException("error")
+          }
+        })
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
+      assert(allEvents(0)(1) ===  ForeachSinkSuite.Process(value = 1))
+      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
+      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
+      assert(errorEvent.error.get.getMessage === "error")
+      query.stop()
+    }
+  }
+}
+
+/** A global object to collect events in the executor */
+object ForeachSinkSuite {
+
+  trait Event
+
+  case class Open(partition: Long, version: Long) extends Event
+
+  case class Process[T](value: T) extends Event
+
+  case class Close(error: Option[Throwable]) extends Event
+
+  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
+
+  def addEvents(events: Seq[Event]): Unit = {
+    _allEvents.add(events)
+  }
+
+  def allEvents(): Seq[Seq[Event]] = {
+    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
+  }
+
+  def clear(): Unit = {
+    _allEvents.clear()
+  }
+}
+
+/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
+class TestForeachWriter extends ForeachWriter[Int] {
+  ForeachSinkSuite.clear()
+
+  private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
+
+  override def open(partitionId: Long, version: Long): Boolean = {
+    events += ForeachSinkSuite.Open(partition = partitionId, version = version)
+    true
+  }
+
+  override def process(value: Int): Unit = {
+    events += ForeachSinkSuite.Process(value)
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    events += ForeachSinkSuite.Close(error = Option(errorOrNull))
+    ForeachSinkSuite.addEvents(events)
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index bab0092c37d34fd5f104be41443183f62aa09428..fc01ff3f5aa07b6936810f2d1bec0a25fbd6f26e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+      def withBucket(
+          writer: DataFrameWriter[Row],
+          bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
         bucketSpec.map { spec =>
           writer.bucketBy(
             spec.numBuckets,