From 4f7f1c436205630ab77d3758d7210cc1a2f0d04a Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Mon, 20 Jun 2016 21:55:34 -0700
Subject: [PATCH] [SPARK-16044][SQL] input_file_name() returns empty strings in
 data sources based on NewHadoopRDD

## What changes were proposed in this pull request?

This PR makes the `input_file_name()` function return file paths instead of empty strings for external data sources based on `NewHadoopRDD`, such as [spark-redshift](https://github.com/databricks/spark-redshift/blob/cba5eee1ab79ae8f0fa9e668373a54d2b5babf6b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala#L149) and [spark-xml](https://github.com/databricks/spark-xml/blob/master/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala#L39-L47). The fix sets the thread-local file name in `InputFileNameHolder` when `NewHadoopRDD` starts reading an input split that is a `FileSplit`, and unsets it when the record reader is closed.

With such an external data source, the code below:

```scala
df.select(input_file_name).show()
```

will produce

- **Before**
  ```
  +-----------------+
  |input_file_name()|
  +-----------------+
  |                 |
  +-----------------+
  ```

- **After**
  ```
  +--------------------+
  |   input_file_name()|
  +--------------------+
  |file:/private/var...|
  +--------------------+
  ```

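For reference, here is a minimal, self-contained sketch that reproduces the behavior with a local text file read through `newAPIHadoopFile` (the same `NewHadoopRDD` code path). The output path, app name, and `main` wrapper are illustrative only and are not part of this patch:

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.input_file_name

object InputFileNameRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-16044 repro").getOrCreate()
    import spark.implicits._

    // Write a small text file, then read it back through the new Hadoop API,
    // which goes through NewHadoopRDD (the same path used by sources like spark-xml).
    val dir = "/tmp/spark-16044-repro"  // illustrative path
    spark.range(0, 10).map(_.toString).write.mode("overwrite").text(dir)

    val rdd = spark.sparkContext.newAPIHadoopFile(
      dir,
      classOf[NewTextInputFormat],
      classOf[LongWritable],
      classOf[Text])

    // Before this patch, input_file_name() here returns an empty string;
    // after it, the row contains the actual file path under `dir`.
    val df = rdd.map(_._2.toString).toDF("value")
    df.select(input_file_name()).show(1, truncate = false)

    spark.stop()
  }
}
```

With this patch applied, the `show()` output contains the `file:` path of the text file under the temporary directory; without it, the column is empty.
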
## How was this patch tested?

Unit tests were added in `ColumnExpressionSuite`, covering `FileScanRDD`, `HadoopRDD` and `NewHadoopRDD`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #13759 from HyukjinKwon/SPARK-16044.
---
 .../spark/rdd/InputFileNameHolder.scala       |  2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala   |  7 ++++
 .../spark/sql/ColumnExpressionSuite.scala     | 34 +++++++++++++++++--
 3 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
index 108e9d2558..f40d4c8e0a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * This holds file names of the current Spark task. This is used in HadoopRDD,
- * FileScanRDD and InputFileName function in Spark SQL.
+ * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL.
  */
 private[spark] object InputFileNameHolder {
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 189dc7b331..b086baa084 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -135,6 +135,12 @@ class NewHadoopRDD[K, V](
       val inputMetrics = context.taskMetrics().inputMetrics
       val existingBytesRead = inputMetrics.bytesRead
 
+      // Sets the thread-local variable to the name of the file being read
+      split.serializableHadoopSplit.value match {
+        case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString)
+        case _ => InputFileNameHolder.unsetInputFileName()
+      }
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
@@ -201,6 +207,7 @@ class NewHadoopRDD[K, V](
 
       private def close() {
         if (reader != null) {
+          InputFileNameHolder.unsetInputFileName()
           // Close the reader and release it. Note: it's very important that we don't close the
           // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
           // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index e89fa32b15..a66c83dea0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
@@ -592,7 +594,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("input_file_name") {
+  test("input_file_name - FileScanRDD") {
     withTempPath { dir =>
       val data = sparkContext.parallelize(0 to 10).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
@@ -604,6 +606,35 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("input_file_name - HadoopRDD") {
+    withTempPath { dir =>
+      val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
+      data.write.text(dir.getCanonicalPath)
+      val df = spark.sparkContext.textFile(dir.getCanonicalPath).toDF()
+      val answer = df.select(input_file_name()).head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+    }
+  }
+
+  test("input_file_name - NewHadoopRDD") {
+    withTempPath { dir =>
+      val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
+      data.write.text(dir.getCanonicalPath)
+      val rdd = spark.sparkContext.newAPIHadoopFile(
+        dir.getCanonicalPath,
+        classOf[NewTextInputFormat],
+        classOf[LongWritable],
+        classOf[Text])
+      val df = rdd.map(pair => pair._2.toString).toDF()
+      val answer = df.select(input_file_name()).head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+    }
+  }
+
   test("columns can be compared") {
     assert('key.desc == 'key.desc)
     assert('key.desc != 'key.asc)
@@ -707,5 +738,4 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
       testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
       testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
   }
-
 }
-- 
GitLab