diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 73e6abc6dad3790af2ff1229d665c9dde57fc5ed..47567032b01958de314463d36aa8bcd871a50d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -133,20 +133,24 @@ object TextInputCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
-    val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
-    val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-    val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-    val tokenRDD = csv.rdd.mapPartitions { iter =>
-      val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-      val linesWithoutHeader =
-        CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-      val parser = new CsvParser(parsedOptions.asParserSettings)
-      linesWithoutHeader.map(parser.parseLine)
+    val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+    CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match {
+      case Some(firstLine) =>
+        val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val tokenRDD = csv.rdd.mapPartitions { iter =>
+          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+          val linesWithoutHeader =
+            CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+          val parser = new CsvParser(parsedOptions.asParserSettings)
+          linesWithoutHeader.map(parser.parseLine)
+        }
+        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+      case None =>
+        // If the first line could not be read, just return the empty schema.
+        Some(StructType(Nil))
     }
-
-    Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
   }
 
   private def createBaseDataset(
@@ -190,28 +194,28 @@ object WholeFileCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): Option[StructType] = {
-    val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+    val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+    csv.flatMap { lines =>
       UnivocityParser.tokenizeStream(
         CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
-        false,
+        shouldDropHeader = false,
         new CsvParser(parsedOptions.asParserSettings))
-    }.take(1).headOption
-
-    if (maybeFirstRow.isDefined) {
-      val firstRow = maybeFirstRow.get
-      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-      val tokenRDD = csv.flatMap { lines =>
-        UnivocityParser.tokenizeStream(
-          CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
-          parsedOptions.headerFlag,
-          new CsvParser(parsedOptions.asParserSettings))
-      }
-      Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
-    } else {
-      // If the first row could not be read, just return the empty schema.
-      Some(StructType(Nil))
+    }.take(1).headOption match {
+      case Some(firstRow) =>
+        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val tokenRDD = csv.flatMap { lines =>
+          UnivocityParser.tokenizeStream(
+            CodecStreams.createInputStreamWithCloseResource(
+              lines.getConfiguration,
+              lines.getPath()),
+            parsedOptions.headerFlag,
+            new CsvParser(parsedOptions.asParserSettings))
+        }
+        Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+      case None =>
+        // If the first row could not be read, just return the empty schema.
+        Some(StructType(Nil))
     }
   }
 
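A note on the take(1).headOption pattern introduced in the first hunk above (and kept in the second): Dataset.first() throws NoSuchElementException when the comment- and empty-line-filtered input contains no lines, whereas take(1).headOption returns None, which the match expression maps to an empty StructType. A minimal spark-shell style sketch of that difference, assuming an existing SparkSession named spark (not part of this patch):

    import spark.implicits._

    val emptyDs = spark.emptyDataset[String]

    // emptyDs.first()          // throws java.util.NoSuchElementException on an empty Dataset
    emptyDs.take(1).headOption  // returns None, the case the new code turns into StructType(Nil)
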
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 56071803f685fa1a0f76ba0748c9e008c28814b2..eaedede349134281650389cf9414d7ef45985ae4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1077,14 +1077,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }
 
-  test("Empty file produces empty dataframe with empty schema - wholeFile option") {
-    withTempPath { path =>
-      path.createNewFile()
-
+  test("Empty file produces empty dataframe with empty schema") {
+    Seq(false, true).foreach { wholeFile =>
       val df = spark.read.format("csv")
         .option("header", true)
-        .option("wholeFile", true)
-        .load(path.getAbsolutePath)
+        .option("wholeFile", wholeFile)
+        .load(testFile(emptyFile))
 
       assert(df.schema === spark.emptyDataFrame.schema)
       checkAnswer(df, spark.emptyDataFrame)
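
Taken together, the source and test changes mean that schema inference on a CSV file with no readable rows now yields an empty schema on both the line-based and the wholeFile code paths. A rough end-to-end sketch of the resulting behaviour, assuming a local SparkSession and a zero-byte file at an illustrative path (neither is part of the patch):

    import org.apache.spark.sql.SparkSession

    object EmptyCsvSchemaExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("empty-csv").getOrCreate()

        // "/tmp/empty.csv" stands in for any existing zero-byte CSV file.
        val df = spark.read
          .format("csv")
          .option("header", true)
          .option("wholeFile", true)   // also exercises the stream-based WholeFileCSVDataSource path
          .load("/tmp/empty.csv")

        // After this change both code paths infer StructType(Nil) for a file with no readable rows.
        assert(df.schema.isEmpty)
        spark.stop()
      }
    }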