diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
index 9ba476d2ba26a926612772f9fb6331f3b96c32fb..ff2f58d81142dc7f47924756972deeb95324959a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -41,9 +41,10 @@ private[spark] object InputFileBlockHolder {
    * The thread variable for the name of the current file being read. This is used by
    * the InputFileName function in Spark SQL.
    */
-  private[this] val inputBlock: ThreadLocal[FileBlock] = new ThreadLocal[FileBlock] {
-    override protected def initialValue(): FileBlock = new FileBlock
-  }
+  private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
+    new InheritableThreadLocal[FileBlock] {
+      override protected def initialValue(): FileBlock = new FileBlock
+    }
 
   /**
    * Returns the holding file name or empty string if it is unknown.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8250281dab351b827d18cbf1738e7fc4f06477f..73a5df65e0ab3344902091d0c81e868b0d0b8e5e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -435,6 +435,30 @@ class SQLTests(ReusedPySparkTestCase):
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
+    def test_udf_with_input_file_name_for_hadooprdd(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+
+        def filename(path):
+            return path
+
+        sameText = udf(filename, StringType())
+
+        rdd = self.sc.textFile('python/test_support/sql/people.json')
+        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+        row = df.select(sameText(df['file'])).first()
+        self.assertTrue(row[0].find("people.json") != -1)
+
+        rdd2 = self.sc.newAPIHadoopFile(
+            'python/test_support/sql/people.json',
+            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+            'org.apache.hadoop.io.LongWritable',
+            'org.apache.hadoop.io.Text')
+
+        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+        row2 = df2.select(sameText(df2['file'])).first()
+        self.assertTrue(row2[0].find("people.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)