diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
index 9ba476d2ba26a926612772f9fb6331f3b96c32fb..ff2f58d81142dc7f47924756972deeb95324959a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -41,9 +41,10 @@ private[spark] object InputFileBlockHolder {
    * The thread variable for the name of the current file being read. This is used by
    * the InputFileName function in Spark SQL.
    */
-  private[this] val inputBlock: ThreadLocal[FileBlock] = new ThreadLocal[FileBlock] {
-    override protected def initialValue(): FileBlock = new FileBlock
-  }
+  private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
+    new InheritableThreadLocal[FileBlock] {
+      override protected def initialValue(): FileBlock = new FileBlock
+    }
 
   /**
    * Returns the holding file name or empty string if it is unknown.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8250281dab351b827d18cbf1738e7fc4f06477f..73a5df65e0ab3344902091d0c81e868b0d0b8e5e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -435,6 +435,30 @@ class SQLTests(ReusedPySparkTestCase):
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
+    def test_udf_with_input_file_name_for_hadooprdd(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+
+        def filename(path):
+            return path
+
+        sameText = udf(filename, StringType())
+
+        rdd = self.sc.textFile('python/test_support/sql/people.json')
+        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+        row = df.select(sameText(df['file'])).first()
+        self.assertTrue(row[0].find("people.json") != -1)
+
+        rdd2 = self.sc.newAPIHadoopFile(
+            'python/test_support/sql/people.json',
+            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+            'org.apache.hadoop.io.LongWritable',
+            'org.apache.hadoop.io.Text')
+
+        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+        row2 = df2.select(sameText(df2['file'])).first()
+        self.assertTrue(row2[0].find("people.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)