diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 9cc321af4bde249d77efad3c605c2e89ad8d745e..6afe58bff522926f821823c37ff112d44bb1b670 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -23,6 +23,7 @@ import java.text.DateFormat
 import java.util.{Arrays, Comparator, Date, Locale}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging {
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
    * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
    * return the bytes read on r since t.
-   *
-   * @return None if the required method can't be found.
    */
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
-    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
-    val f = () => threadStats.map(_.getBytesRead).sum
-    val baselineBytesRead = f()
-    () => f() - baselineBytesRead
+    val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+    val baseline = (Thread.currentThread().getId, f())
+
+    /**
+     * This callback may be invoked both from spawned child threads and from the parent task
+     * thread (as in PythonRDD), and Hadoop FileSystem tracks its statistics in thread-local
+     * variables. We therefore keep a map of the bytes read per thread and sum its entries
+     * to get the total bytes read by this task.
+     */
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
+    }
   }
 
   /**
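
Illustrative note (not part of the patch): the callback added above takes a baseline for the
calling thread and, on every invocation, records the invoking thread's latest byte count in a
shared map and sums the entries. The self-contained Scala sketch below reproduces that pattern
with a plain thread-local counter standing in for Hadoop's FileSystem statistics; the names
(PerThreadCounterExample, counterCallback, add) are hypothetical and exist only for this example.

import scala.collection.mutable

object PerThreadCounterExample {
  // Stand-in for FileSystem's per-thread statistics: a thread-local running total.
  private val local = new ThreadLocal[Long] { override def initialValue(): Long = 0L }
  def add(n: Long): Unit = local.set(local.get + n)

  // Mirrors the pattern in getFSBytesReadOnThreadCallback: take a baseline on the calling
  // thread, then sum the latest value recorded by every thread that has invoked the callback.
  def counterCallback(): () => Long = {
    val read: () => Long = () => local.get
    val baseline = (Thread.currentThread().getId, read())
    val seen = new mutable.HashMap[Long, Long]()
    () => seen.synchronized {
      seen.put(Thread.currentThread().getId, read())
      seen.map { case (tid, v) => v - (if (tid == baseline._1) baseline._2 else 0L) }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    val callback = counterCallback()
    add(10L)            // "bytes read" on the parent thread
    callback()          // record the parent thread's count
    val child = new Thread(new Runnable {
      override def run(): Unit = { add(5L); callback() }  // child-thread reads, also recorded
    })
    child.start()
    child.join()
    println(callback()) // prints 15: parent (10) plus child (5), baseline (0) subtracted once
  }
}
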
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 4bf8ecc383542b791ae5b8c10c23afee21e04bdf..76ea8b86c53d21d76a092beb8028bcdacaee5b2a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
             null
         }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener{ context => closeIfNeeded() }
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing so that any lingering bytesRead statistics
+        // recorded on this thread are correctly added.
+        updateBytesRead()
+        closeIfNeeded()
+      }
+
       private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
       private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index ce3a9a2a1e2a801cd6653d4fe6cdc4d7adf4300f..482875e6c1ac53b1070b35b48803445eaa035dd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
         }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener(context => close())
+      context.addTaskCompletionListener { context =>
+        // Update the bytesRead before closing so that any lingering bytesRead statistics
+        // recorded on this thread are correctly added.
+        updateBytesRead()
+        close()
+      }
+
       private var havePair = false
       private var recordsSinceMetricsUpdate = 0
 
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 5d522189a0c292ccf97b643c588cf4b2cc2cfac1..6f4203da1d866a4c3e6fff5a17c8d95f9558b68d 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
   with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     }
     assert(bytesRead >= tmpFile.length())
   }
+
+  test("input metrics with old Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics with new Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
 }
 
 /**