diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 8171dcc046379fdbf41a673ee145e57e8511d1c3..ad1fddbde7b00dbdad5939d81b36d2337a09e30a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ForkJoinTaskSupport
+import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport}
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag
 
@@ -58,6 +58,11 @@ private[spark] class UnionPartition[T: ClassTag](
   }
 }
 
+object UnionRDD {
+  private[spark] lazy val partitionEvalTaskSupport =
+    new ForkJoinTaskSupport(new ForkJoinPool(8))
+}
+
 @DeveloperApi
 class UnionRDD[T: ClassTag](
     sc: SparkContext,
@@ -68,13 +73,10 @@ class UnionRDD[T: ClassTag](
   private[spark] val isPartitionListingParallel: Boolean =
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
 
-  @transient private lazy val partitionEvalTaskSupport =
-      new ForkJoinTaskSupport(new ForkJoinPool(8))
-
   override def getPartitions: Array[Partition] = {
     val parRDDs = if (isPartitionListingParallel) {
       val parArray = rdds.par
-      parArray.tasksupport = partitionEvalTaskSupport
+      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
       parArray
     } else {
       rdds