diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 7702f535ad2f4275b44806bc01ffce2644190fbc..cefa8be0c600707a28587147c2f98295d66a4529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -104,7 +104,7 @@ object ResolvedDataSource extends Logging {
         s"Data source $providerName does not support streamed reading")
     }
 
-    provider.createSource(sqlContext, options, userSpecifiedSchema)
+    provider.createSource(sqlContext, userSpecifiedSchema, providerName, options)
   }
 
   def createSink(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
new file mode 100644
index 0000000000000000000000000000000000000000..14ba9f69bb1d7193d81ac201912d927d0f6ba124
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io._
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.io.Codec
+
+import com.google.common.base.Charsets.UTF_8
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.collection.OpenHashSet
+
+/**
+ * A very simple source that reads text files from the given directory as they appear.
+ *
+ * TODO Clean up the metadata files periodically
+ */
+class FileStreamSource(
+    sqlContext: SQLContext,
+    metadataPath: String,
+    path: String,
+    dataSchema: Option[StructType],
+    providerName: String,
+    dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {
+
+  private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+  private var maxBatchId = -1
+  private val seenFiles = new OpenHashSet[String]
+
+  /** Map of batch id to files. This map is also stored in `metadataPath`. */
+  private val batchToMetadata = new HashMap[Long, Seq[String]]
+
+  {
+    // Restore file paths from the metadata files
+    val existingBatchFiles = fetchAllBatchFiles()
+    if (existingBatchFiles.nonEmpty) {
+      val existingBatchIds = existingBatchFiles.map(_.getPath.getName.toInt)
+      maxBatchId = existingBatchIds.max
+      // Recover "batchToMetadata" and "seenFiles" from existing metadata files.
+      existingBatchIds.sorted.foreach { batchId =>
+        val files = readBatch(batchId)
+        if (files.isEmpty) {
+          // Assert that the corrupted file must be the latest metadata file.
+          if (batchId != maxBatchId) {
+            throw new IllegalStateException("Invalid metadata files")
+          }
+          maxBatchId = maxBatchId - 1
+        } else {
+          batchToMetadata(batchId) = files
+          files.foreach(seenFiles.add)
+        }
+      }
+    }
+  }
+
+  /** Returns the schema of the data from this source */
+  override lazy val schema: StructType = {
+    dataSchema.getOrElse {
+      val filesPresent = fetchAllFiles()
+      if (filesPresent.isEmpty) {
+        if (providerName == "text") {
+          // Add a default schema for "text"
+          new StructType().add("value", StringType)
+        } else {
+          throw new IllegalArgumentException("No schema specified")
+        }
+      } else {
+        // There are some existing files. Use them to infer the schema.
+        dataFrameBuilder(filesPresent.toArray).schema
+      }
+    }
+  }
+
+  /**
+   * Returns the maximum offset that can be retrieved from the source.
+   *
+   * `synchronized` on this method is for solving race conditions in tests. In the normal usage,
+   * there is no race here, so the cost of `synchronized` should be rare.
+   */
+  private def fetchMaxOffset(): LongOffset = synchronized {
+    val filesPresent = fetchAllFiles()
+    val newFiles = new ArrayBuffer[String]()
+    filesPresent.foreach { file =>
+      if (!seenFiles.contains(file)) {
+        logDebug(s"new file: $file")
+        newFiles.append(file)
+        seenFiles.add(file)
+      } else {
+        logDebug(s"old file: $file")
+      }
+    }
+
+    if (newFiles.nonEmpty) {
+      maxBatchId += 1
+      writeBatch(maxBatchId, newFiles)
+    }
+
+    new LongOffset(maxBatchId)
+  }
+
+  /**
+   * For test only. Run `func` with the internal lock to make sure when `func` is running,
+   * the current offset won't be changed and no new batch will be emitted.
+   */
+  def withBatchingLocked[T](func: => T): T = synchronized {
+    func
+  }
+
+  /** Return the latest offset in the source */
+  def currentOffset: LongOffset = synchronized {
+    new LongOffset(maxBatchId)
+  }
+
+  /**
+   * Returns the next batch of data that is available after `start`, if any is available.
+   */
+  override def getNextBatch(start: Option[Offset]): Option[Batch] = {
+    val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L)
+    val end = fetchMaxOffset()
+    val endId = end.offset
+
+    if (startId + 1 <= endId) {
+      val files = (startId + 1 to endId).filter(_ >= 0).flatMap { batchId =>
+        batchToMetadata.getOrElse(batchId, Nil)
+      }.toArray
+      logDebug(s"Return files from batches ${startId + 1}:$endId")
+      logDebug(s"Streaming ${files.mkString(", ")}")
+      Some(new Batch(end, dataFrameBuilder(files)))
+    } else {
+      None
+    }
+  }
+
+  private def fetchAllBatchFiles(): Seq[FileStatus] = {
+    try fs.listStatus(new Path(metadataPath)) catch {
+      case _: java.io.FileNotFoundException =>
+        fs.mkdirs(new Path(metadataPath))
+        Seq.empty
+    }
+  }
+
+  private def fetchAllFiles(): Seq[String] = {
+    fs.listStatus(new Path(path))
+      .filterNot(_.getPath.getName.startsWith("_"))
+      .map(_.getPath.toUri.toString)
+  }
+
+  /**
+   * Write the metadata of a batch to disk. The file format is as follows:
+   *
+   * {{{
+   *   <FileStreamSource.VERSION>
+   *   START
+   *   -/a/b/c
+   *   -/d/e/f
+   *   ...
+   *   END
+   * }}}
+   *
+   * Note: <FileStreamSource.VERSION> means the value of `FileStreamSource.VERSION`. Every file
+   * path starts with "-" so that we can know if a line is a file path easily.
+   */
+  private def writeBatch(id: Int, files: Seq[String]): Unit = {
+    assert(files.nonEmpty, "create a new batch without any file")
+    val output = fs.create(new Path(metadataPath + "/" + id), true)
+    val writer = new PrintWriter(new OutputStreamWriter(output, UTF_8))
+    try {
+      // scalastyle:off println
+      writer.println(FileStreamSource.VERSION)
+      writer.println(FileStreamSource.START_TAG)
+      files.foreach(file => writer.println(FileStreamSource.PATH_PREFIX + file))
+      writer.println(FileStreamSource.END_TAG)
+      // scalastyle:on println
+    } finally {
+      writer.close()
+    }
+    batchToMetadata(id) = files
+  }
+
+  /** Read the file names of the specified batch id from the metadata file */
+  private def readBatch(id: Int): Seq[String] = {
+    val input = fs.open(new Path(metadataPath + "/" + id))
+    try {
+      FileStreamSource.readBatch(input)
+    } finally {
+      input.close()
+    }
+  }
+}
+
+object FileStreamSource {
+
+  private val START_TAG = "START"
+  private val END_TAG = "END"
+  private val PATH_PREFIX = "-"
+  val VERSION = "FILESTREAM_V1"
+
+  /**
+   * Parse a metadata file and return the content. If the metadata file is corrupted, it will
+   * return an empty `Seq`.
+   */
+  def readBatch(input: InputStream): Seq[String] = {
+    val lines = scala.io.Source.fromInputStream(input)(Codec.UTF8).getLines().toArray
+    if (lines.length < 4) {
+      // version + start tag + end tag + at least one file path
+      return Nil
+    }
+    if (lines.head != VERSION) {
+      return Nil
+    }
+    if (lines(1) != START_TAG) {
+      return Nil
+    }
+    if (lines.last != END_TAG) {
+      return Nil
+    }
+    lines.slice(2, lines.length - 1).map(_.stripPrefix(PATH_PREFIX)) // Drop character "-"
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 737be7dfd12f6087978a42994f53b0c7d1cf29af..428a313ca9dc21f2cf4ff8533e435a25aa6a0274 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.streaming.{Sink, Source}
+import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.collection.BitSet
@@ -131,8 +131,9 @@ trait SchemaRelationProvider {
 trait StreamSourceProvider {
   def createSource(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): Source
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source
 }
 
 /**
@@ -169,7 +170,7 @@ trait StreamSinkProvider {
  * @since 1.4.0
  */
 @Experimental
-trait HadoopFsRelationProvider {
+trait HadoopFsRelationProvider extends StreamSourceProvider {
   /**
    * Returns a new base relation with the given parameters, a user defined schema, and a list of
   * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity
@@ -196,6 +197,30 @@ trait HadoopFsRelationProvider {
     }
     createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters)
   }
+
+  override def createSource(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    val path = parameters.getOrElse("path", {
+      throw new IllegalArgumentException("'path' is not specified")
+    })
+    val metadataPath = parameters.getOrElse("metadataPath", s"$path/_metadata")
+
+    def dataFrameBuilder(files: Array[String]): DataFrame = {
+      val relation = createRelation(
+        sqlContext,
+        files,
+        schema,
+        partitionColumns = None,
+        bucketSpec = None,
+        parameters)
+      DataFrame(sqlContext, LogicalRelation(relation))
+    }
+
+    new FileStreamSource(sqlContext, metadataPath, path, schema, providerName, dataFrameBuilder)
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index f45abbf2496a2744580de28f14e3dbb4b92212d8..7e388ea6023430e558a01e8f3775346dfe999136 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -59,6 +59,8 @@ trait StreamTest extends QueryTest with Timeouts {
 
   implicit class RichSource(s: Source) {
     def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
+
+    def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s))
   }
 
   /** How long to wait for an active stream to catch up when checking a result. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 36212e4395985871f74f44988565efd2e2ee4995..b762f9b90ed86ab9476fe555e97fe1edd3e1eeac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -33,8 +33,9 @@ object LastOptions {
 class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
   override def createSource(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): Source = {
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
     LastOptions.parameters = parameters
     LastOptions.schema = schema
     new Source {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7a4ee0ef264d842582cb176ace652dfb8051a340
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.io.{ByteArrayInputStream, File, FileNotFoundException, InputStream}
+
+import com.google.common.base.Charsets.UTF_8
+
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.FileStreamSource._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.Utils
+
+class FileStreamSourceTest extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  case class AddTextFileData(source: FileStreamSource, content: String, src: File, tmp: File)
+    extends AddData {
+
+    override def addData(): Offset = {
+      source.withBatchingLocked {
+        val file = Utils.tempFileWith(new File(tmp, "text"))
+        stringToFile(file, content).renameTo(new File(src, file.getName))
+        source.currentOffset
+      } + 1
+    }
+  }
+
+  case class AddParquetFileData(
+      source: FileStreamSource,
+      content: Seq[String],
+      src: File,
+      tmp: File) extends AddData {
+
+    override def addData(): Offset = {
+      source.withBatchingLocked {
+        val file = Utils.tempFileWith(new File(tmp, "parquet"))
+        content.toDS().toDF().write.parquet(file.getCanonicalPath)
+        file.renameTo(new File(src, file.getName))
+        source.currentOffset
+      } + 1
+    }
+  }
+
+  /** Use `format` and `path` to create FileStreamSource via DataFrameReader */
+  def createFileStreamSource(
+      format: String,
+      path: String,
+      schema: Option[StructType] = None): FileStreamSource = {
+    val reader =
+      if (schema.isDefined) {
+        sqlContext.read.format(format).schema(schema.get)
+      } else {
+        sqlContext.read.format(format)
+      }
+    reader.stream(path)
+      .queryExecution.analyzed
+      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
+      .head
+  }
+
+  val valueSchema = new StructType().add("value", StringType)
+}
+
+class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
+
+  import testImplicits._
+
+  private def createFileStreamSourceAndGetSchema(
+      format: Option[String],
+      path: Option[String],
+      schema: Option[StructType] = None): StructType = {
+    val reader = sqlContext.read
+    format.foreach(reader.format)
+    schema.foreach(reader.schema)
+    val df =
+      if (path.isDefined) {
+        reader.stream(path.get)
+      } else {
+        reader.stream()
+      }
+    df.queryExecution.analyzed
+      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
+      .head
+      .schema
+  }
+
+  test("FileStreamSource schema: no path") {
+    val e = intercept[IllegalArgumentException] {
+      createFileStreamSourceAndGetSchema(format = None, path = None, schema = None)
+    }
+    assert("'path' is not specified" === e.getMessage)
+  }
+
+  test("FileStreamSource schema: path doesn't exist") {
+    intercept[FileNotFoundException] {
+      createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None)
+    }
+  }
+
+  test("FileStreamSource schema: text, no existing files, no schema") {
+    withTempDir { src =>
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: text, existing files, no schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "a\nb\nc")
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: text, existing files, schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "a\nb\nc")
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("FileStreamSource schema: parquet, no existing files, no schema") {
+    withTempDir { src =>
+      val e = intercept[IllegalArgumentException] {
+        createFileStreamSourceAndGetSchema(
+          format = Some("parquet"), path = Some(new File(src, "1").getCanonicalPath),
+          schema = None)
+      }
+      assert("No schema specified" === e.getMessage)
+    }
+  }
+
+  test("FileStreamSource schema: parquet, existing files, no schema") {
+    withTempDir { src =>
+      Seq("a", "b", "c").toDS().as("userColumn").toDF()
+        .write.parquet(new File(src, "1").getCanonicalPath)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: parquet, existing files, schema") {
+    withTempPath { src =>
+      Seq("a", "b", "c").toDS().as("oldUserColumn").toDF()
+        .write.parquet(new File(src, "1").getCanonicalPath)
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("parquet"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("FileStreamSource schema: json, no existing files, no schema") {
+    withTempDir { src =>
+      val e = intercept[IllegalArgumentException] {
+        createFileStreamSourceAndGetSchema(
+          format = Some("json"), path = Some(src.getCanonicalPath), schema = None)
+      }
+      assert("No schema specified" === e.getMessage)
+    }
+  }
+
+  test("FileStreamSource schema: json, existing files, no schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}")
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("json"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("c", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: json, existing files, schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}")
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("json"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("read from text files") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+
+    val textSource = createFileStreamSource("text", src.getCanonicalPath)
+    val filtered = textSource.toDF().filter($"value" contains "keep")
+
+    testStream(filtered)(
+      AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp),
+      CheckAnswer("keep2", "keep3"),
+      StopStream,
+      AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp),
+      StartStream,
+      CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+      AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp),
+      CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+    )
+
+    Utils.deleteRecursively(src)
+    Utils.deleteRecursively(tmp)
+  }
+
+  test("read from json files") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+
+    val textSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema))
+    val filtered = textSource.toDF().filter($"value" contains "keep")
+
+    testStream(filtered)(
+      AddTextFileData(
+        textSource,
+        "{'value': 'drop1'}\n{'value': 'keep2'}\n{'value': 'keep3'}",
+        src,
+        tmp),
+      CheckAnswer("keep2", "keep3"),
+      StopStream,
+      AddTextFileData(
+        textSource,
+        "{'value': 'drop4'}\n{'value': 'keep5'}\n{'value': 'keep6'}",
+        src,
+        tmp),
+      StartStream,
+      CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+      AddTextFileData(
+        textSource,
+        "{'value': 'drop7'}\n{'value': 'keep8'}\n{'value': 'keep9'}",
+        src,
+        tmp),
+      CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+    )
+
+    Utils.deleteRecursively(src)
+    Utils.deleteRecursively(tmp)
+  }
+
+  test("read from json files with inferring schema") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+
+    // Add a file so that we can infer its schema
+    stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}")
+
+    val textSource = createFileStreamSource("json", src.getCanonicalPath)
+
+    // FileStreamSource should infer the column "c"
+    val filtered = textSource.toDF().filter($"c" contains "keep")
+
+    testStream(filtered)(
+      AddTextFileData(textSource, "{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp),
+      CheckAnswer("keep2", "keep3", "keep5", "keep6")
+    )
+
+    Utils.deleteRecursively(src)
+    Utils.deleteRecursively(tmp)
+  }
+
+  test("read from parquet files") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+
+    val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema))
+    val filtered = fileSource.toDF().filter($"value" contains "keep")
+
+    testStream(filtered)(
+      AddParquetFileData(fileSource, Seq("drop1", "keep2", "keep3"), src, tmp),
+      CheckAnswer("keep2", "keep3"),
+      StopStream,
+      AddParquetFileData(fileSource, Seq("drop4", "keep5", "keep6"), src, tmp),
+      StartStream,
+      CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+      AddParquetFileData(fileSource, Seq("drop7", "keep8", "keep9"), src, tmp),
+      CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+    )
+
+    Utils.deleteRecursively(src)
+    Utils.deleteRecursively(tmp)
+  }
+
+  test("file stream source without schema") {
+    val src = Utils.createTempDir("streaming.src")
+
+    // Only "text" doesn't need a schema
+    createFileStreamSource("text", src.getCanonicalPath)
+
+    // Both "json" and "parquet" require a schema if no existing file to infer
+    intercept[IllegalArgumentException] {
+      createFileStreamSource("json", src.getCanonicalPath)
+    }
+    intercept[IllegalArgumentException] {
+      createFileStreamSource("parquet", src.getCanonicalPath)
+    }
+
+    Utils.deleteRecursively(src)
+  }
+
+  test("fault tolerance") {
+    def assertBatch(batch1: Option[Batch], batch2: Option[Batch]): Unit = {
+      (batch1, batch2) match {
+        case (Some(b1), Some(b2)) =>
+          assert(b1.end === b2.end)
+          assert(b1.data.as[String].collect() === b2.data.as[String].collect())
+        case (None, None) =>
+        case _ => fail(s"batch ($batch1) is not equal to batch ($batch2)")
equal to batch ($batch2)") + } + } + + val src = Utils.createTempDir("streaming.src") + val tmp = Utils.createTempDir("streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + val textSource2 = createFileStreamSource("text", src.getCanonicalPath) + assert(textSource2.currentOffset === textSource.currentOffset) + assertBatch(textSource2.getNextBatch(None), textSource.getNextBatch(None)) + for (f <- 0L to textSource.currentOffset.offset) { + val offset = LongOffset(f) + assertBatch(textSource2.getNextBatch(Some(offset)), textSource.getNextBatch(Some(offset))) + } + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("fault tolerance with corrupted metadata file") { + val src = Utils.createTempDir("streaming.src") + assert(new File(src, "_metadata").mkdirs()) + stringToFile( + new File(src, "_metadata/0"), + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") + stringToFile(new File(src, "_metadata/1"), s"${FileStreamSource.VERSION}\nSTART\n-") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + // the metadata file of batch is corrupted, so currentOffset should be 0 + assert(textSource.currentOffset === LongOffset(0)) + + Utils.deleteRecursively(src) + } + + test("fault tolerance with normal metadata file") { + val src = Utils.createTempDir("streaming.src") + assert(new File(src, "_metadata").mkdirs()) + stringToFile( + new File(src, "_metadata/0"), + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n") + stringToFile( + new File(src, "_metadata/1"), + s"${FileStreamSource.VERSION}\nSTART\n-/x/y/z\nEND\n") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + assert(textSource.currentOffset === LongOffset(1)) + + Utils.deleteRecursively(src) + } + + test("readBatch") { + def stringToStream(str: String): InputStream = new ByteArrayInputStream(str.getBytes(UTF_8)) + + // Invalid metadata + assert(readBatch(stringToStream("")) === Nil) + assert(readBatch(stringToStream(FileStreamSource.VERSION)) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\n")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n")) === Nil) + assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEN")) === Nil) + + // Valid metadata + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND")) === Seq("/a/b/c")) + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND\n")) === Seq("/a/b/c")) + assert(readBatch(stringToStream( + s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n")) + === Seq("/a/b/c", "/e/f/g")) + } +} + +class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { + + import 
+
+  test("file source stress test") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+
+    val textSource = createFileStreamSource("text", src.getCanonicalPath)
+    val ds = textSource.toDS[String]().map(_.toInt + 1)
+    runStressTest(ds, data => {
+      AddTextFileData(textSource, data.mkString("\n"), src, tmp)
+    })
+
+    Utils.deleteRecursively(src)
+    Utils.deleteRecursively(tmp)
+  }
+}
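
For reference, a minimal usage sketch of the new file stream source through the `DataFrameReader.stream` path that the suites above exercise. It assumes an already constructed `SQLContext`; the directory paths, the object name, and the "keep" filter are illustrative placeholders and not part of this patch.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{StringType, StructType}

// Hypothetical example, not part of the patch: each file dropped into the input
// directory becomes part of the next batch, and the per-batch file lists are
// persisted under <path>/_metadata so the source can recover after a restart.
object FileStreamSourceUsageSketch {
  def example(sqlContext: SQLContext): Unit = {
    // "text" needs no user schema: FileStreamSource falls back to a single "value" column.
    val lines = sqlContext.read
      .format("text")
      .stream("/tmp/stream-in-text")    // placeholder input directory

    // "json" and "parquet" need a schema unless existing files allow inference.
    val jsonSchema = new StructType().add("value", StringType)
    val json = sqlContext.read
      .format("json")
      .schema(jsonSchema)
      .stream("/tmp/stream-in-json")    // placeholder input directory

    // Downstream transformations are ordinary DataFrame operations.
    val kept = lines.filter(lines("value").contains("keep"))
    val keptJson = json.filter(json("value").contains("keep"))
  }
}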