From 5bce4580939c27876f11cd75f0dc2190fb9fa908 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Thu, 7 Jul 2016 23:19:41 -0700
Subject: [PATCH] [SPARK-16430][SQL][STREAMING] Add option maxFilesPerTrigger

## What changes were proposed in this pull request?

An option that limits the number of files the file stream source reads per trigger (for example, one file at a time) enables rate limiting. It has the additional convenience that a static set of files can be used like a stream for testing, as this allows those files to be considered one at a time.

This PR adds option `maxFilesPerTrigger`.
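
For example, a minimal usage sketch (the directory path is illustrative):

```scala
// Consider at most one new file per trigger from the monitored directory.
val stream = spark
  .readStream
  .option("maxFilesPerTrigger", 1)
  .text("/path/to/directory/")
```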

## How was this patch tested?

New unit test

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #14094 from tdas/SPARK-16430.
---
 .../streaming/FileStreamSource.scala          | 40 ++++++----
 .../sql/streaming/DataStreamReader.scala      | 10 +++
 .../sql/streaming/FileStreamSourceSuite.scala | 76 +++++++++++++++++++
 3 files changed, 112 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 11bf3c0bd2..72b335a42e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
 
 import org.apache.hadoop.fs.Path
 
@@ -46,6 +46,9 @@ class FileStreamSource(
   private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath)
   private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
 
+  /** Maximum number of new files to be considered in each batch */
+  private val maxFilesPerBatch = getMaxFilesPerBatch()
+
   private val seenFiles = new OpenHashSet[String]
   metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) =>
     files.foreach(seenFiles.add)
@@ -58,19 +61,17 @@ class FileStreamSource(
    * there is no race here, so the cost of `synchronized` should be rare.
    */
   private def fetchMaxOffset(): LongOffset = synchronized {
-    val filesPresent = fetchAllFiles()
-    val newFiles = new ArrayBuffer[String]()
-    filesPresent.foreach { file =>
-      if (!seenFiles.contains(file)) {
-        logDebug(s"new file: $file")
-        newFiles.append(file)
-        seenFiles.add(file)
-      } else {
-        logDebug(s"old file: $file")
-      }
+    val newFiles = fetchAllFiles().filter(!seenFiles.contains(_))
+    val batchFiles =
+      if (maxFilesPerBatch.nonEmpty) newFiles.take(maxFilesPerBatch.get) else newFiles
+    batchFiles.foreach { file =>
+      seenFiles.add(file)
+      logDebug(s"New file: $file")
     }
-
-    if (newFiles.nonEmpty) {
+    logTrace(s"Number of new files = ${newFiles.size}")
+    logTrace(s"Number of files selected for batch = ${batchFiles.size}")
+    logTrace(s"Number of seen files = ${seenFiles.size}")
+    if (batchFiles.nonEmpty) {
       maxBatchId += 1
-      metadataLog.add(maxBatchId, newFiles)
-      logInfo(s"Max batch id increased to $maxBatchId with ${newFiles.size} new files")
+      metadataLog.add(maxBatchId, batchFiles)
+      logInfo(s"Max batch id increased to $maxBatchId with ${batchFiles.size} new files")
@@ -118,7 +119,7 @@ class FileStreamSource(
     val startTime = System.nanoTime
     val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath)
     val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType))
-    val files = catalog.allFiles().map(_.getPath.toUri.toString)
+    val files = catalog.allFiles().sortBy(_.getModificationTime).map(_.getPath.toUri.toString)
     val endTime = System.nanoTime
     val listingTimeMs = (endTime.toDouble - startTime) / 1000000
     if (listingTimeMs > 2000) {
@@ -131,6 +132,17 @@ class FileStreamSource(
     files
   }
 
+  private def getMaxFilesPerBatch(): Option[Int] = {
+    new CaseInsensitiveMap(options)
+      .get("maxFilesPerTrigger")
+      .map { str =>
+        Try(str.toInt).toOption.filter(_ > 0).getOrElse {
+          throw new IllegalArgumentException(
+            s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer")
+        }
+      }
+  }
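+
+  /**
+   * For test only. Runs `func` while holding this source's lock, so that `fetchMaxOffset`
+   * (and hence batch construction) is blocked while `func` executes. A minimal sketch of the
+   * helper the new "max files per trigger" test calls to add several files atomically.
+   */
+  def withBatchingLocked[T](func: => T): T = synchronized {
+    func
+  }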
+
   override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1)
 
   override def toString: String = s"FileStreamSource[$qualifiedBasePath]"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 248247a257..2e606b21bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -161,6 +161,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * schema in advance, use the version that specifies the schema to avoid the extra scan.
    *
    * You can set the following JSON-specific options to deal with non-standard JSON files:
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
    * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
    * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
    * type. If the values do not fit in decimal, then it infers them as doubles.</li>
@@ -199,6 +201,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * specify the schema explicitly using [[schema]].
    *
    * You can set the following CSV-specific options to deal with CSV files:
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
    * <li>`sep` (default `,`): sets the single character as a separator for each
    * field and value.</li>
    * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
@@ -251,6 +255,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * Loads a Parquet file stream, returning the result as a [[DataFrame]].
    *
    * You can set the following Parquet-specific option(s) for reading Parquet files:
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
    * <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
    * whether we should merge schemas collected from all Parquet part-files. This will override
    * `spark.sql.parquet.mergeSchema`.</li>
@@ -276,6 +282,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    *   spark.readStream().text("/path/to/directory/")
    * }}}
    *
+   * You can set the following text-specific options to deal with text files:
+   * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+   * considered in every trigger.</li>
+   *
    * @since 2.0.0
    */
   @Experimental
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 8a34cf95f9..29ce578bcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -585,6 +585,82 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     }
   }
 
+  test("max files per trigger") {
+    withTempDir { case src =>
+      var lastFileModTime: Option[Long] = None
+
+      /** Create a text file with a single data item */
+      def createFile(data: Int): File = {
+        val file = stringToFile(new File(src, s"$data.txt"), data.toString)
+        if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
+        lastFileModTime = Some(file.lastModified)
+        file
+      }
+
+      createFile(1)
+      createFile(2)
+      createFile(3)
+
+      // Set up a query to read text files 2 at a time
+      val df = spark
+        .readStream
+        .option("maxFilesPerTrigger", 2)
+        .text(src.getCanonicalPath)
+      val q = df
+        .writeStream
+        .format("memory")
+        .queryName("file_data")
+        .start()
+        .asInstanceOf[StreamExecution]
+      q.processAllAvailable()
+      val memorySink = q.sink.asInstanceOf[MemorySink]
+      val fileSource = q.logicalPlan.collect {
+        case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] =>
+          source.asInstanceOf[FileStreamSource]
+      }.head
+
+      /** Check the data read in the last batch */
+      def checkLastBatchData(data: Int*): Unit = {
+        val schema = StructType(Seq(StructField("value", StringType)))
+        val df = spark.createDataFrame(
+          spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
+        checkAnswer(df, data.map(_.toString).toDF("value"))
+      }
+
+      /** Check how many batches have executed since the last time this check was made */
+      var lastBatchId = -1L
+      def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
+        require(lastBatchId >= 0)
+        assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
+        lastBatchId = memorySink.latestBatchId.get
+      }
+
+      checkLastBatchData(3)  // (1 and 2) should be in the first batch, (3) in the second (last) batch
+      lastBatchId = memorySink.latestBatchId.get
+
+      fileSource.withBatchingLocked {
+        createFile(4)
+        createFile(5)   // 4 and 5 should be in a batch
+        createFile(6)
+        createFile(7)   // 6 and 7 should be in the last batch
+      }
+      q.processAllAvailable()
+      checkLastBatchData(6, 7)
+      checkNumBatchesSinceLastCheck(2)
+
+      fileSource.withBatchingLocked {
+        createFile(8)
+        createFile(9)    // 8 and 9 should be in a batch
+        createFile(10)
+        createFile(11)   // 10 and 11 should be in a batch
+        createFile(12)   // 12 should be in the last batch
+      }
+      q.processAllAvailable()
+      checkLastBatchData(12)
+      checkNumBatchesSinceLastCheck(3)
+    }
+  }
+
   test("explain") {
     withTempDirs { case (src, tmp) =>
       src.mkdirs()
-- 
GitLab