diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 50df68b14483dd0a01f973439c42ba545362a90b..66320bd050c148980ec19be14184d8b9cf1c9de1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -412,6 +412,14 @@ class SQLTests(ReusedPySparkTestCase):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertIn("people1.json", row[0])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
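Reviewer note: `input_file_name()` is evaluated on the JVM by the projection that
builds the UDF's arguments, not inside the Python worker, which is why the
batching order fixed in the next file matters for this test. For reference, a
minimal Scala-side equivalent of the new test, assuming a running SparkSession
named `spark` and the same fixture file (these names are illustrative, not part
of the patch):

```scala
import org.apache.spark.sql.functions.{input_file_name, udf}

// Pass the file name through a no-op UDF, mirroring sourceFile() above.
val sourceFile = udf((path: String) => path)
val row = spark.read.json("python/test_support/sql/people1.json")
  .select(sourceFile(input_file_name()))
  .first()
assert(row.getString(0).contains("people1.json"))
```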
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index dcaf2c76d479d1ffb64be51123eabbed6bb28e94..7a5ac48f1b69db7fa33176005baeca9e10c887ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { inputRow =>
-          queue.add(inputRow.asInstanceOf[UnsafeRow])
-          val row = projection(inputRow)
-          if (needConversion) {
-            EvaluatePython.toJava(row, schema)
-          } else {
-            // fast path for these types that does not need conversion in Python
-            val fields = new Array[Any](row.numFields)
-            var i = 0
-            while (i < row.numFields) {
-              val dt = dataTypes(i)
-              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-              i += 1
-            }
-            fields
+      val inputIterator = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        val row = projection(inputRow)
+        if (needConversion) {
+          EvaluatePython.toJava(row, schema)
+        } else {
+          // fast path for these types that do not need conversion in Python
+          val fields = new Array[Any](row.numFields)
+          var i = 0
+          while (i < row.numFields) {
+            val dt = dataTypes(i)
+            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+            i += 1
           }
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
+          fields
+        }
+      }.grouped(100).map(x => pickle.dumps(x.toArray))
 
       val context = TaskContext.get()
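The behavioral difference driving this change: in the old shape,
`iter.grouped(100)` drains up to 100 raw rows from the child iterator before the
per-row work (`queue.add` and the projection) runs, so any state keyed to the
most recently pulled row, such as the thread-local file name behind
`input_file_name()`, has already advanced past the row being projected. Mapping
first and grouping afterwards runs that work as each row is pulled. A minimal,
Spark-free sketch of the two shapes, where the mutable `current` stands in for
that thread-local (all names here are illustrative, not from the patch):

```scala
object GroupedVsMap {
  def main(args: Array[String]): Unit = {
    var current = -1 // side state updated as each element is pulled

    def source: Iterator[Int] = Iterator.tabulate(4) { i => current = i; i }

    // Old shape: grouped(2) pulls a whole batch of raw elements first, so by
    // the time the per-row work runs, `current` points at the batch's last row.
    println(source.grouped(2).flatMap(_.map(i => (i, current))).toList)
    // List((0,1), (1,1), (2,3), (3,3))

    // New shape: the per-row work runs as each element is pulled, before
    // grouping, so `current` always matches the row being processed.
    println(source.map(i => (i, current)).grouped(2).flatMap(_.iterator).toList)
    // List((0,0), (1,1), (2,2), (3,3))
  }
}
```

Either shape still pickles and ships batches of 100 rows to the Python worker;
only the point at which the per-row side effects fire changes.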