diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1dd8818dedb2ef64988d99dc3d63fdcb6c54eaba..32e2fdc3f970762bac9f5f1d3d7d1a3f9d1e9ec9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
 import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
  *
  * @since 1.4.0
  */
-final class DataFrameWriter private[sql](df: DataFrame) {
+final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
+
+  private val df = ds.toDF()
 
   /**
    * Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: SaveMode): DataFrameWriter = {
+  def mode(saveMode: SaveMode): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def mode(saveMode: String): DataFrameWriter = {
+  def mode(saveMode: String): DataFrameWriter[T] = {
     // mode() is used for non-continuous queries
     // outputMode() is used for continuous queries
     assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: OutputMode): DataFrameWriter = {
+  def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode
     this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def outputMode(outputMode: String): DataFrameWriter = {
+  def outputMode(outputMode: String): DataFrameWriter[T] = {
     assertStreaming("outputMode() can only be called on continuous queries")
     this.outputMode = outputMode.toLowerCase match {
       case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def trigger(trigger: Trigger): DataFrameWriter = {
+  def trigger(trigger: Trigger): DataFrameWriter[T] = {
     assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def format(source: String): DataFrameWriter = {
+  def format(source: String): DataFrameWriter[T] = {
     this.source = source
     this
   }
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def option(key: String, value: String): DataFrameWriter = {
+  def option(key: String, value: String): DataFrameWriter[T] = {
     this.extraOptions += (key -> value)
     this
   }
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * Adds an output option for the underlying data source.
    *
    * @since 2.0.0
    */
-  def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
+  def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
 
   /**
    * (Scala-specific) Adds output options for the underlying data source.
    *
    * @since 1.4.0
    */
-  def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
+  def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
     this.extraOptions ++= options
     this
   }
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def options(options: java.util.Map[String, String]): DataFrameWriter = {
+  def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
     this.options(options.asScala)
     this
   }
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def partitionBy(colNames: String*): DataFrameWriter = {
+  def partitionBy(colNames: String*): DataFrameWriter[T] = {
     this.partitioningColumns = Option(colNames)
     this
   }
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
+  def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
     this.numBuckets = Option(numBuckets)
     this.bucketColumnNames = Option(colName +: colNames)
     this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0
    */
   @scala.annotation.varargs
-  def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+  def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
     this.sortColumnNames = Option(colName +: colNames)
     this
   }
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   @Experimental
-  def queryName(queryName: String): DataFrameWriter = {
+  def queryName(queryName: String): DataFrameWriter[T] = {
     assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       val queryName =
         extraOptions.getOrElse(
           "queryName", throw new AnalysisException("queryName must be specified for memory sink"))
-      val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
-        new Path(userSpecified).toUri.toString
-      }.orElse {
-        val checkpointConfig: Option[String] =
-          df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
-
-        checkpointConfig.map { location =>
-          new Path(location, queryName).toUri.toString
-        }
-      }.getOrElse {
+      val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
         Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
       }
 
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
         className = source,
         options = extraOptions.toMap,
         partitionColumns = normalizedParCols.getOrElse(Nil))
 
-      val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
-      val checkpointLocation = extraOptions.get("checkpointLocation")
-        .orElse {
-          df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
-            new Path(l, queryName).toUri.toString
-          }
-        }.getOrElse {
-          throw new AnalysisException("checkpointLocation must be specified either " +
-            "through option() or SQLConf")
-        }
-
       df.sparkSession.sessionState.continuousQueryManager.startQuery(
         queryName,
-        checkpointLocation,
+        getCheckpointLocation(queryName, failIfNotSet = true).get,
         df,
         dataSource.createSink(outputMode),
         outputMode,
@@ -400,6 +382,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
   }
 
+  /**
+   * :: Experimental ::
+   * Starts the execution of the streaming query, which will continually send results to the given
+   * [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data
+   * generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned
+   * [[ContinuousQuery]] object can be used to interact with the stream.
+   *
+   * Scala example:
+   * {{{
+   *   datasetOfString.write.foreach(new ForeachWriter[String] {
+   *
+   *     def open(partitionId: Long, version: Long): Boolean = {
+   *       // open connection
+   *     }
+   *
+   *     def process(record: String) = {
+   *       // write string to connection
+   *     }
+   *
+   *     def close(errorOrNull: Throwable): Unit = {
+   *       // close the connection
+   *     }
+   *   })
+   * }}}
+   *
+   * Java example:
+   * {{{
+   *   datasetOfString.write().foreach(new ForeachWriter<String>() {
+   *
+   *     @Override
+   *     public boolean open(long partitionId, long version) {
+   *       // open connection
+   *     }
+   *
+   *     @Override
+   *     public void process(String value) {
+   *       // write string to connection
+   *     }
+   *
+   *     @Override
+   *     public void close(Throwable errorOrNull) {
+   *       // close the connection
+   *     }
+   *   });
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
+    assertNotBucketed("foreach")
+    assertStreaming(
+      "foreach() can only be called on streaming Datasets/DataFrames.")
+
+    val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+    val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
+    df.sparkSession.sessionState.continuousQueryManager.startQuery(
+      queryName,
+      getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
+        Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
+      },
+      df,
+      sink,
+      outputMode,
+      trigger)
+  }
+
+  /**
+   * Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
+   * location is not set, an [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`,
+   * `None` will be returned if the checkpoint location is not set.
+   */
+  private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
+    val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+      new Path(userSpecified).toUri.toString
+    }.orElse {
+      df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
+        new Path(location, queryName).toUri.toString
+      }
+    }
+    if (failIfNotSet && checkpointLocation.isEmpty) {
+      throw new AnalysisException("checkpointLocation must be specified either " +
+        """through option("checkpointLocation", ...) or """ +
or """ + + s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""") + } + checkpointLocation + } + /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 162524a9efc3abb6e8f8b2f008fc08ed0a40fa95..16bbf30a9437091e34bca727722811f5c1574713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2400,7 +2400,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def write: DataFrameWriter = new DataFrameWriter(toDF()) + def write: DataFrameWriter[T] = new DataFrameWriter[T](this) /** * Returns the content of the Dataset as a Dataset of JSON strings. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..09f07426a6bfa453d64e51de240c19c70b102750 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.streaming.ContinuousQuery + +/** + * :: Experimental :: + * A class to consume data generated by a [[ContinuousQuery]]. Typically this is used to send the + * generated data to external systems. Each partition will use a new deserialized instance, so you + * usually should do all the initialization (e.g. opening a connection or initiating a transaction) + * in the `open` method. + * + * Scala example: + * {{{ + * datasetOfString.write.foreach(new ForeachWriter[String] { + * + * def open(partitionId: Long, version: Long): Boolean = { + * // open connection + * } + * + * def process(record: String) = { + * // write string to connection + * } + * + * def close(errorOrNull: Throwable): Unit = { + * // close the connection + * } + * }) + * }}} + * + * Java example: + * {{{ + * datasetOfString.write().foreach(new ForeachWriter<String>() { + * + * @Override + * public boolean open(long partitionId, long version) { + * // open connection + * } + * + * @Override + * public void process(String value) { + * // write string to connection + * } + * + * @Override + * public void close(Throwable errorOrNull) { + * // close the connection + * } + * }); + * }}} + * @since 2.0.0 + */ +@Experimental +abstract class ForeachWriter[T] extends Serializable { + + /** + * Called when starting to process one partition of new data in the executor. 
+   * for data deduplication when there are failures. When recovering from a failure, some data may
+   * be generated multiple times, but it will always have the same version.
+   *
+   * If this method finds, using the `partitionId` and `version`, that this partition has already
+   * been processed, it can return `false` to skip further data processing. However, `close` will
+   * still be called for cleaning up resources.
+   *
+   * @param partitionId the partition id.
+   * @param version a unique id for data deduplication.
+   * @return `true` if the corresponding partition and version id should be processed. `false`
+   *         indicates the partition should be skipped.
+   */
+  def open(partitionId: Long, version: Long): Boolean
+
+  /**
+   * Called to process the data on the executor side. This method will be called only when `open`
+   * returns `true`.
+   */
+  def process(value: T): Unit
+
+  /**
+   * Called when stopping to process one partition of new data on the executor side. This is
+   * guaranteed to be called whether `open` returns `true` or `false`. However,
+   * `close` won't be called in the following cases:
+   *  - JVM crashes without throwing a `Throwable`
+   *  - `open` throws a `Throwable`.
+   *
+   * @param errorOrNull the error thrown during processing data or null if there was no error.
+   */
+  def close(errorOrNull: Throwable): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
new file mode 100644
index 0000000000000000000000000000000000000000..14b9b1cb09317d899f8e2857a7525c6a62b04481
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+
+/**
+ * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    data.as[T].foreachPartition { iter =>
+      if (writer.open(TaskContext.getPartitionId(), batchId)) {
+        var isFailed = false
+        try {
+          while (iter.hasNext) {
+            writer.process(iter.next())
+          }
+        } catch {
+          case e: Throwable =>
+            isFailed = true
+            writer.close(e)
+        }
+        if (!isFailed) {
+          writer.close(null)
+        }
+      } else {
+        writer.close(null)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e1fb3b947837bc28bcac5f9b163cf29ed22b3768
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  test("foreach") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(2).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter())
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val expectedEventsForPartition0 = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 3),
+        ForeachSinkSuite.Close(None)
+      )
+      val expectedEventsForPartition1 = Seq(
+        ForeachSinkSuite.Open(partition = 1, version = 0),
+        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 4),
+        ForeachSinkSuite.Close(None)
+      )
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 2)
+      assert {
+        allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
+          allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
+      }
+      query.stop()
+    }
+  }
+
+  test("foreach with error") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(1).write
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter() {
+          override def process(value: Int): Unit = {
+            super.process(value)
+            throw new RuntimeException("error")
+          }
+        })
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
+      assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
+      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
+      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
+      assert(errorEvent.error.get.getMessage === "error")
+      query.stop()
+    }
+  }
+}
+
+/** A global object to collect events in the executor */
+object ForeachSinkSuite {
+
+  trait Event
+
+  case class Open(partition: Long, version: Long) extends Event
+
+  case class Process[T](value: T) extends Event
+
+  case class Close(error: Option[Throwable]) extends Event
+
+  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
+
+  def addEvents(events: Seq[Event]): Unit = {
+    _allEvents.add(events)
+  }
+
+  def allEvents(): Seq[Seq[Event]] = {
+    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
+  }
+
+  def clear(): Unit = {
+    _allEvents.clear()
+  }
+}
+
+/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
+class TestForeachWriter extends ForeachWriter[Int] {
+  ForeachSinkSuite.clear()
+
+  private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
+
+  override def open(partitionId: Long, version: Long): Boolean = {
+    events += ForeachSinkSuite.Open(partition = partitionId, version = version)
+    true
+  }
+
+  override def process(value: Int): Unit = {
+    events += ForeachSinkSuite.Process(value)
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    events += ForeachSinkSuite.Close(error = Option(errorOrNull))
+    ForeachSinkSuite.addEvents(events)
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index bab0092c37d34fd5f104be41443183f62aa09428..fc01ff3f5aa07b6936810f2d1bec0a25fbd6f26e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+      def withBucket(
+          writer: DataFrameWriter[Row],
+          bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
         bucketSpec.map { spec =>
           writer.bucketBy(
             spec.numBuckets,
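
A minimal usage sketch of the API added above (not part of the patch): it assumes `streamingDs` is a streaming `Dataset[String]` created elsewhere, the query name and checkpoint path are illustrative only, and the writer simply prints each record in place of a real connection.

import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.streaming.ContinuousQuery

// Assumed to be a streaming Dataset[String]; how it is created is outside this patch.
val streamingDs: Dataset[String] = ???

val query: ContinuousQuery = streamingDs.write
  .queryName("print-records")                            // optional; a name is generated if omitted
  .option("checkpointLocation", "/tmp/print-records-cp") // optional for foreach; a temp dir is used if unset
  .foreach(new ForeachWriter[String] {
    // Runs once per partition and batch version; return false to skip data already handled.
    override def open(partitionId: Long, version: Long): Boolean = true

    // Called for each record while open() returned true.
    override def process(record: String): Unit = println(record)

    // Always called afterwards, with the failure (or null) from processing.
    override def close(errorOrNull: Throwable): Unit = ()
  })

// The returned ContinuousQuery is managed like any other stream, e.g. query.stop().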