Commit 5854f77c authored by jerryshao, committed by Wenchen Fan

[SPARK-20244][CORE] Handle incorrect bytesRead metrics when using PySpark

## What changes were proposed in this pull request?

Hadoop FileSystem's statistics are based on thread-local variables, which is fine as long as the whole RDD computation chain runs in a single thread. But if a child RDD creates another thread to consume the iterator obtained from a Hadoop RDD, the bytesRead computation goes wrong, because the iterator's `next()` and `close()` may then run in different threads. This can happen when using PySpark with PythonRDD.

So this change builds a map that tracks `bytesRead` per thread and sums the entries together. The callback is used by three RDDs: `HadoopRDD`, `NewHadoopRDD` and `FileScanRDD`. I assume `FileScanRDD` cannot be called directly from PySpark, so only `HadoopRDD` and `NewHadoopRDD` are fixed here.
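
To make the failure mode concrete, here is a minimal self-contained sketch (illustrative names only, not Spark code) using a plain `ThreadLocal` counter as a stand-in for Hadoop's per-thread `FileSystem` statistics: a callback that only reads the calling thread's counter never sees bytes read on a child thread.

```scala
// Stand-alone sketch of the failure mode (all names are illustrative, not Spark's).
object ThreadLocalCounterDemo {
  // Stand-in for Hadoop's thread-local FileSystem statistics.
  private val bytesRead = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }

  private def read(n: Long): Unit = bytesRead.set(bytesRead.get + n)

  def main(args: Array[String]): Unit = {
    // Naive callback: only sees the counter of whichever thread invokes it.
    val callback = () => bytesRead.get

    // Simulate PythonRDD consuming the Hadoop iterator on a separate thread.
    val child = new Thread(new Runnable {
      override def run(): Unit = read(1024)
    })
    child.start()
    child.join()

    // Prints 0, not 1024: the child's thread-local increment is invisible here.
    println(callback())
  }
}
```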

## How was this patch tested?

Unit tests and local cluster verification.

Author: jerryshao <sshao@hortonworks.com>

Closes #17617 from jerryshao/SPARK-20244.
parent 24db3582
@@ -23,6 +23,7 @@ import java.text.DateFormat
import java.util.{Arrays, Comparator, Date, Locale}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal
import com.google.common.primitives.Longs
@@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging {
* Returns a function that can be called to find Hadoop FileSystem bytes read. If
* getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
* return the bytes read on r since t.
*
* @return None if the required method can't be found.
*/
private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
val f = () => threadStats.map(_.getBytesRead).sum
val baselineBytesRead = f()
() => f() - baselineBytesRead
val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
val baseline = (Thread.currentThread().getId, f())

/**
 * This function may be called from both spawned child threads and the parent task thread
 * (in PythonRDD), and Hadoop FileSystem uses thread-local variables to track its statistics.
 * So we keep a map of the bytes read per thread (children and parent) and sum the entries
 * to get the total bytes read of this task.
 */
new Function0[Long] {
  private val bytesReadMap = new mutable.HashMap[Long, Long]()

  override def apply(): Long = {
    bytesReadMap.synchronized {
      bytesReadMap.put(Thread.currentThread().getId, f())
      bytesReadMap.map { case (k, v) =>
        v - (if (k == baseline._1) baseline._2 else 0)
      }.sum
    }
  }
}
}
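For intuition, the same summation logic as the new callback above can be demonstrated with a self-contained example in which a plain `ThreadLocal` counter stands in for Hadoop's `FileSystem` statistics (all names here are illustrative; only the per-thread map and the baseline subtraction mirror the patch):

```scala
// Demo of the per-thread map approach: each caller records its own thread-local
// count; the sum subtracts the creator thread's baseline once. Illustrative only.
import scala.collection.mutable

object PerThreadBytesReadDemo {
  private val counter = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }
  private def readBytes(n: Long): Unit = counter.set(counter.get + n)

  def makeCallback(): () => Long = {
    val baseline = (Thread.currentThread().getId, counter.get)
    val bytesReadMap = mutable.HashMap[Long, Long]()
    () => bytesReadMap.synchronized {
      bytesReadMap.put(Thread.currentThread().getId, counter.get)
      bytesReadMap.map { case (tid, bytes) =>
        bytes - (if (tid == baseline._1) baseline._2 else 0L)
      }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    readBytes(100)                    // read on the parent before the callback is created
    val callback = makeCallback()     // baseline = (parentThreadId, 100)

    readBytes(50)                     // parent reads 50 more
    val child = new Thread(new Runnable {
      override def run(): Unit = {
        readBytes(200)                // child reads 200 on its own thread-local counter
        callback()                    // records the child's 200 in the shared map
      }
    })
    child.start()
    child.join()

    println(callback())               // (150 - 100) + 200 = 250
  }
}
```

Each caller records its own thread's running total in the shared map; the creator thread's baseline is subtracted exactly once, so only bytes read after the callback was created are counted.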
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
null
}
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener{ context => closeIfNeeded() }
context.addTaskCompletionListener { context =>
  // Update the bytes read before closing, so that lingering bytesRead statistics from
  // this thread are correctly added to the task's input metrics.
  updateBytesRead()
  closeIfNeeded()
}
private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
}
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
context.addTaskCompletionListener { context =>
  // Update the bytesRead before closing, so that lingering bytesRead statistics from
  // this thread are correctly added to the task's input metrics.
  updateBytesRead()
  close()
}
private var havePair = false
private var recordsSinceMetricsUpdate = 0
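Both the `HadoopRDD` and `NewHadoopRDD` hunks follow the same pattern: flush the per-thread byte count into the task's metrics before the reader is torn down. A rough sketch of that ordering, with stand-in names (only `updateBytesRead` and the close call mirror the diff):

```scala
// Rough sketch of the flush-before-close ordering (stand-in names, not Spark's API).
import scala.collection.mutable

class ReaderSketch(getBytesRead: () => Long) {
  private val completionCallbacks = mutable.ArrayBuffer[() => Unit]()
  private var bytesReadMetric = 0L
  private var closed = false

  // Stand-in for context.addTaskCompletionListener { ... }
  private def addTaskCompletionListener(f: () => Unit): Unit = completionCallbacks += f

  addTaskCompletionListener { () =>
    updateBytesRead() // fold lingering per-thread reads into the metric first...
    closeIfNeeded()   // ...then release the underlying record reader
  }

  private def updateBytesRead(): Unit = { bytesReadMetric = getBytesRead() }

  private def closeIfNeeded(): Unit = if (!closed) { closed = true /* close reader here */ }

  // Invoked by the task runner once the task finishes (success or failure).
  def onTaskCompletion(): Unit = completionCallbacks.foreach(_.apply())

  def currentBytesReadMetric: Long = bytesReadMetric
}
```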
@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{SharedSparkContext, SparkFunSuite}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.util.Utils
import org.apache.spark.util.{ThreadUtils, Utils}
class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
}
assert(bytesRead >= tmpFile.length())
}
test("input metrics with old Hadoop API in different thread") {
val bytesRead = runAndReturnBytesRead {
sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
val buf = new ArrayBuffer[String]()
ThreadUtils.runInNewThread("testThread", false) {
iter.flatMap(_.split(" ")).foreach(buf.append(_))
}
buf.iterator
}.count()
}
assert(bytesRead >= tmpFile.length())
}
test("input metrics with new Hadoop API in different thread") {
val bytesRead = runAndReturnBytesRead {
sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
classOf[Text]).mapPartitions { iter =>
val buf = new ArrayBuffer[String]()
ThreadUtils.runInNewThread("testThread", false) {
iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
}
buf.iterator
}.count()
}
assert(bytesRead >= tmpFile.length())
}
}
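The new tests consume the partition iterator on a separate thread via `ThreadUtils.runInNewThread`, mimicking what PythonRDD does. Roughly, such a helper runs a block on a fresh thread, waits for it, and rethrows any failure on the caller. A simplified approximation is sketched below (the real Spark utility does more, e.g. stack-trace rewriting, so treat this as an assumption about its general shape):

```scala
// Simplified approximation of a run-in-new-thread helper like the one the tests use.
// Assumption: Spark's ThreadUtils.runInNewThread also rewrites stack traces; this
// sketch only shows the execution and ordering semantics the tests rely on.
object RunInNewThreadSketch {
  def runInNewThread[T](threadName: String, isDaemon: Boolean)(body: => T): T = {
    var result: Option[T] = None
    var error: Option[Throwable] = None
    val thread = new Thread(threadName) {
      override def run(): Unit = {
        try {
          result = Some(body)
        } catch {
          case t: Throwable => error = Some(t)
        }
      }
    }
    thread.setDaemon(isDaemon)
    thread.start()
    thread.join() // join() also makes the child's writes visible to this thread
    error.foreach(e => throw e)
    result.get
  }
}
```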