From 3c91afec20607e0d853433a904105ee22df73c73 Mon Sep 17 00:00:00 2001
From: Nezih Yigitbasi <nyigitbasi@netflix.com>
Date: Tue, 19 Apr 2016 14:35:26 -0700
Subject: [PATCH] [SPARK-14042][CORE] Add custom coalescer support

## What changes were proposed in this pull request?

This PR adds support for specifying an optional custom coalescer to the `coalesce()` method. Currently I have only added this to the `RDD` interface; once we sort out the details, we can proceed with adding it to the other APIs (`Dataset`, etc.).
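
To make the intended usage concrete, here is a minimal sketch (the `RoundRobinCoalescer` and `CoalesceExample` names and the local-mode setup are made up for illustration; `PartitionCoalescer`, `PartitionGroup`, and the new optional `partitionCoalescer` parameter are the pieces this PR adds):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.{PartitionCoalescer, PartitionGroup, RDD}

// Illustrative coalescer: round-robins the parent's partitions into at most
// maxPartitions groups, ignoring locality entirely. Marked Serializable,
// mirroring the test's SizeBasedCoalescer, to avoid "Task not serializable" issues.
class RoundRobinCoalescer extends PartitionCoalescer with Serializable {
  override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
    val numGroups = math.min(maxPartitions, parent.partitions.length)
    val groups = Array.fill(numGroups)(new PartitionGroup())
    parent.partitions.zipWithIndex.foreach { case (p, i) =>
      groups(i % numGroups).partitions += p
    }
    groups.filter(_.numPartitions > 0)
  }
}

object CoalesceExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("custom-coalescer-example"))
    val rdd = sc.parallelize(1 to 1000, 10)

    // Existing behavior: no coalescer passed, DefaultPartitionCoalescer is used.
    val byDefault = rdd.coalesce(2)

    // New in this PR: plug in a custom coalescer.
    val byCustom = rdd.coalesce(2, shuffle = false,
      partitionCoalescer = Some(new RoundRobinCoalescer))

    assert(byDefault.partitions.length == 2 && byCustom.partitions.length == 2)
    sc.stop()
  }
}
```

When no coalescer is supplied, `CoalescedRDD` falls back to the existing `DefaultPartitionCoalescer`, so current callers are unaffected.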

## How was this patch tested?

Added a unit test for this functionality.
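
Roughly, the test coalesces a `HadoopRDD` over ten small text files with a size-based coalescer and checks the resulting groups; a condensed sketch of that flow is below (it assumes the suite's shared `sc` and `tempDir` plus the `SizeBasedCoalescer` helper this patch adds to `RDDSuite.scala`, and it only compiles inside the `org.apache.spark.rdd` package because `CoalescedRDDPartition` and `HadoopPartition` are `private[spark]`):

```scala
import java.io.File

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}

val maxSplitSize = 512
val outDir = new File(tempDir, "output").getAbsolutePath
sc.makeRDD(1 to 1000, 10).saveAsTextFile(outDir)

val hadoopRDD =
  sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
val coalesced =
  hadoopRDD.coalesce(2, partitionCoalescer = Option(new SizeBasedCoalescer(maxSplitSize)))

// Every coalesced partition's parent splits must fit within the split-size budget.
coalesced.partitions.foreach { p =>
  val parents = p.asInstanceOf[CoalescedRDDPartition].parents
  val splitSizes = parents.map(
    _.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit].getLength)
  assert(splitSizes.sum <= maxSplitSize)
}
```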

/cc rxin (per our discussion on the mailing list)

Author: Nezih Yigitbasi <nyigitbasi@netflix.com>

Closes #11865 from nezihyigitbasi/custom_coalesce_policy.
---
 .../org/apache/spark/rdd/CoalescedRDD.scala   | 99 +++++++++----------
 .../main/scala/org/apache/spark/rdd/RDD.scala |  9 +-
 .../apache/spark/rdd/coalesce-public.scala    | 52 ++++++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 99 ++++++++++++++++++-
 project/MimaExcludes.scala                    |  4 +
 5 files changed, 209 insertions(+), 54 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 35665ab7c0..e75f1dbf81 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -70,23 +70,23 @@ private[spark] case class CoalescedRDDPartition(
  * parent partitions
  * @param prev RDD to be coalesced
  * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive)
- * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
+ * @param partitionCoalescer [[PartitionCoalescer]] implementation to use for coalescing
  */
 private[spark] class CoalescedRDD[T: ClassTag](
     @transient var prev: RDD[T],
     maxPartitions: Int,
-    balanceSlack: Double = 0.10)
+    partitionCoalescer: Option[PartitionCoalescer] = None)
   extends RDD[T](prev.context, Nil) {  // Nil since we implement getDependencies
 
   require(maxPartitions > 0 || maxPartitions == prev.partitions.length,
     s"Number of partitions ($maxPartitions) must be positive.")
 
   override def getPartitions: Array[Partition] = {
-    val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack)
+    val pc = partitionCoalescer.getOrElse(new DefaultPartitionCoalescer())
 
-    pc.run().zipWithIndex.map {
+    pc.coalesce(maxPartitions, prev).zipWithIndex.map {
       case (pg, i) =>
-        val ids = pg.arr.map(_.index).toArray
+        val ids = pg.partitions.map(_.index).toArray
         new CoalescedRDDPartition(i, prev, ids, pg.prefLoc)
     }
   }
@@ -144,15 +144,15 @@ private[spark] class CoalescedRDD[T: ClassTag](
  * desired partitions is greater than the number of preferred machines (can happen), it needs to
  * start picking duplicate preferred machines. This is determined using coupon collector estimation
  * (2n log(n)). The load balancing is done using power-of-two randomized bins-balls with one twist:
- * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two
- * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions
- * according to locality. (contact alig for questions)
- *
+ * it tries to also achieve locality. This is done by allowing a slack (balanceSlack, where
+ * 1.0 is all locality, 0 is all balance) between two bins. If two bins are within the slack
+ * in terms of balance, the algorithm will assign partitions according to locality.
+ * (contact alig for questions)
  */
 
-private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
-
-  def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
+private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
+  extends PartitionCoalescer {
+  def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.numPartitions < o2.numPartitions
   def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
     if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get)
 
@@ -167,14 +167,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
   // hash used for the first maxPartitions (to avoid duplicates)
   val initialHash = mutable.Set[Partition]()
 
-  // determines the tradeoff between load-balancing the partitions sizes and their locality
-  // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality
-  val slack = (balanceSlack * prev.partitions.length).toInt
-
   var noLocality = true  // if true if no preferredLocations exists for parent RDD
 
   // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
-  def currPrefLocs(part: Partition): Seq[String] = {
+  def currPrefLocs(part: Partition, prev: RDD[_]): Seq[String] = {
     prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host)
   }
 
@@ -192,7 +188,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
     def resetIterator(): Iterator[(String, Partition)] = {
       val iterators = (0 to 2).map { x =>
         prev.partitions.iterator.flatMap { p =>
-          if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None
+          if (currPrefLocs(p, prev).size > x) Some((currPrefLocs(p, prev)(x), p)) else None
         }
       }
       iterators.reduceLeft((x, y) => x ++ y)
@@ -215,8 +211,9 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
   /**
    * Sorts and gets the least element of the list associated with key in groupHash
    * The returned PartitionGroup is the least loaded of all groups that represent the machine "key"
+   *
    * @param key string representing a partitioned group on preferred machine key
-   * @return Option of PartitionGroup that has least elements for key
+   * @return Option of [[PartitionGroup]] that has least elements for key
    */
   def getLeastGroupHash(key: String): Option[PartitionGroup] = {
     groupHash.get(key).map(_.sortWith(compare).head)
@@ -224,7 +221,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
 
   def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = {
     if (!initialHash.contains(part)) {
-      pgroup.arr += part           // already assign this element
+      pgroup.partitions += part           // already assign this element
       initialHash += part // needed to avoid assigning partitions to multiple buckets
       true
     } else { false }
@@ -236,12 +233,12 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
    * until it has seen most of the preferred locations (2 * n log(n))
    * @param targetLen
    */
-  def setupGroups(targetLen: Int) {
+  def setupGroups(targetLen: Int, prev: RDD[_]) {
     val rotIt = new LocationIterator(prev)
 
     // deal with empty case, just create targetLen partition groups with no preferred location
     if (!rotIt.hasNext) {
-      (1 to targetLen).foreach(x => groupArr += PartitionGroup())
+      (1 to targetLen).foreach(x => groupArr += new PartitionGroup())
       return
     }
 
@@ -259,7 +256,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
       tries += 1
       val (nxt_replica, nxt_part) = rotIt.next()
       if (!groupHash.contains(nxt_replica)) {
-        val pgroup = PartitionGroup(nxt_replica)
+        val pgroup = new PartitionGroup(Some(nxt_replica))
         groupArr += pgroup
         addPartToPGroup(nxt_part, pgroup)
         groupHash.put(nxt_replica, ArrayBuffer(pgroup)) // list in case we have multiple
@@ -269,7 +266,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
 
     while (numCreated < targetLen) {  // if we don't have enough partition groups, create duplicates
       var (nxt_replica, nxt_part) = rotIt.next()
-      val pgroup = PartitionGroup(nxt_replica)
+      val pgroup = new PartitionGroup(Some(nxt_replica))
       groupArr += pgroup
       groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup
       var tries = 0
@@ -285,17 +282,29 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
   /**
    * Takes a parent RDD partition and decides which of the partition groups to put it in
    * Takes locality into account, but also uses power of 2 choices to load balance
-   * It strikes a balance between the two use the balanceSlack variable
+   * It strikes a balance between the two using the balanceSlack variable
    * @param p partition (ball to be thrown)
+   * @param balanceSlack determines the trade-off between load-balancing the partitions sizes and
+   *                     their locality. e.g., balanceSlack=0.10 means that it allows up to 10%
+   *                     imbalance in favor of locality
    * @return partition group (bin to be put in)
    */
-  def pickBin(p: Partition): PartitionGroup = {
-    val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs
+  def pickBin(p: Partition, prev: RDD[_], balanceSlack: Double): PartitionGroup = {
+    val slack = (balanceSlack * prev.partitions.length).toInt
+    // least loaded pref locs
+    val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare)
     val prefPart = if (pref == Nil) None else pref.head
 
     val r1 = rnd.nextInt(groupArr.size)
     val r2 = rnd.nextInt(groupArr.size)
-    val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2)
+    val minPowerOfTwo = {
+      if (groupArr(r1).numPartitions < groupArr(r2).numPartitions) {
+        groupArr(r1)
+      }
+      else {
+        groupArr(r2)
+      }
+    }
     if (prefPart.isEmpty) {
       // if no preferred locations, just use basic power of two
       return minPowerOfTwo
@@ -303,55 +312,45 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack:
 
     val prefPartActual = prefPart.get
 
-    if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows
+    // more imbalance than the slack allows
+    if (minPowerOfTwo.numPartitions + slack <= prefPartActual.numPartitions) {
       minPowerOfTwo  // prefer balance over locality
     } else {
       prefPartActual // prefer locality over balance
     }
   }
 
-  def throwBalls() {
+  def throwBalls(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
     if (noLocality) {  // no preferredLocations in parent RDD, no randomization needed
       if (maxPartitions > groupArr.size) { // just return prev.partitions
         for ((p, i) <- prev.partitions.zipWithIndex) {
-          groupArr(i).arr += p
+          groupArr(i).partitions += p
         }
       } else { // no locality available, then simply split partitions based on positions in array
         for (i <- 0 until maxPartitions) {
           val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt
           val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt
-          (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) }
+          (rangeStart until rangeEnd).foreach{ j => groupArr(i).partitions += prev.partitions(j) }
         }
       }
     } else {
       for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group
-        pickBin(p).arr += p
+        pickBin(p, prev, balanceSlack).partitions += p
       }
     }
   }
 
-  def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray
+  def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.numPartitions > 0).toArray
 
   /**
    * Runs the packing algorithm and returns an array of PartitionGroups that if possible are
    * load balanced and grouped by locality
-   * @return array of partition groups
+    *
+    * @return array of partition groups
    */
-  def run(): Array[PartitionGroup] = {
-    setupGroups(math.min(prev.partitions.length, maxPartitions))   // setup the groups (bins)
-    throwBalls() // assign partitions (balls) to each group (bins)
+  def coalesce(maxPartitions: Int, prev: RDD[_]): Array[PartitionGroup] = {
+    setupGroups(math.min(prev.partitions.length, maxPartitions), prev)   // setup the groups (bins)
+    throwBalls(maxPartitions, prev, balanceSlack) // assign partitions (balls) to each group (bins)
     getPartitions
   }
 }
-
-private case class PartitionGroup(prefLoc: Option[String] = None) {
-  var arr = mutable.ArrayBuffer[Partition]()
-  def size: Int = arr.size
-}
-
-private object PartitionGroup {
-  def apply(prefLoc: String): PartitionGroup = {
-    require(prefLoc != "", "Preferred location must not be empty")
-    PartitionGroup(Some(prefLoc))
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index f6e0148f78..499a8b9aa1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -433,7 +433,9 @@ abstract class RDD[T: ClassTag](
    * coalesce(1000, shuffle = true) will result in 1000 partitions with the
    * data distributed using a hash partitioner.
    */
-  def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null)
+  def coalesce(numPartitions: Int, shuffle: Boolean = false,
+               partitionCoalescer: Option[PartitionCoalescer] = Option.empty)
+              (implicit ord: Ordering[T] = null)
       : RDD[T] = withScope {
     if (shuffle) {
       /** Distributes elements evenly across output partitions, starting from a random partition. */
@@ -451,9 +453,10 @@ abstract class RDD[T: ClassTag](
       new CoalescedRDD(
         new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
-        numPartitions).values
+        numPartitions,
+        partitionCoalescer).values
     } else {
-      new CoalescedRDD(this, numPartitions)
+      new CoalescedRDD(this, numPartitions, partitionCoalescer)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala
new file mode 100644
index 0000000000..d8a80aa5ae
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Partition
+
+/**
+ * ::DeveloperApi::
+ * A PartitionCoalescer defines how to coalesce the partitions of a given RDD.
+ */
+@DeveloperApi
+trait PartitionCoalescer {
+
+  /**
+   * Coalesce the partitions of the given RDD.
+   *
+   * @param maxPartitions the maximum number of partitions to have after coalescing
+   * @param parent the parent RDD whose partitions to coalesce
+   * @return an array of [[PartitionGroup]]s, where each element is itself an array of
+   * [[Partition]]s and represents a partition after coalescing is performed.
+   */
+  def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup]
+}
+
+/**
+ * ::DeveloperApi::
+ * A group of [[Partition]]s
+ * @param prefLoc preferred location for the partition group
+ */
+@DeveloperApi
+class PartitionGroup(val prefLoc: Option[String] = None) {
+  val partitions = mutable.ArrayBuffer[Partition]()
+  def numPartitions: Int = partitions.size
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 24daedab20..8dc463d56d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.rdd
 
-import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.reflect.ClassTag
 
 import com.esotericsoftware.kryo.KryoException
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
@@ -31,6 +33,20 @@ import org.apache.spark.rdd.RDDSuiteUtils._
 import org.apache.spark.util.Utils
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext {
+  var tempDir: File = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    tempDir = Utils.createTempDir()
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterAll()
+    }
+  }
 
   test("basic operations") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
@@ -951,6 +967,32 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(thrown.getMessage.contains("SPARK-5063"))
   }
 
+  test("custom RDD coalescer") {
+    val maxSplitSize = 512
+    val outDir = new File(tempDir, "output").getAbsolutePath
+    sc.makeRDD(1 to 1000, 10).saveAsTextFile(outDir)
+    val hadoopRDD =
+      sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
+    val coalescedHadoopRDD =
+      hadoopRDD.coalesce(2, partitionCoalescer = Option(new SizeBasedCoalescer(maxSplitSize)))
+    assert(coalescedHadoopRDD.partitions.size <= 10)
+    var totalPartitionCount = 0L
+    coalescedHadoopRDD.partitions.foreach(partition => {
+      var splitSizeSum = 0L
+      partition.asInstanceOf[CoalescedRDDPartition].parents.foreach(partition => {
+        val split = partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit]
+        splitSizeSum += split.getLength
+        totalPartitionCount += 1
+      })
+      assert(splitSizeSum <= maxSplitSize)
+    })
+    assert(totalPartitionCount == 10)
+  }
+
+  // NOTE
+  // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
+  // running after them and if they access sc those tests will fail as sc is already closed, because
+  // sc is shared (this suite mixins SharedSparkContext)
   test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
     val existingRDD = sc.parallelize(1 to 100)
     sc.stop()
@@ -971,5 +1013,60 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assertFails { sc.parallelize(1 to 100) }
     assertFails { sc.textFile("/nonexistent-path") }
   }
+}
 
+/**
+ * Coalesces partitions based on their size assuming that the parent RDD is a [[HadoopRDD]].
+ * Took this class out of the test suite to prevent "Task not serializable" exceptions.
+ */
+class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Serializable {
+  override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
+    val partitions: Array[Partition] = parent.asInstanceOf[HadoopRDD[Any, Any]].getPartitions
+    val groups = ArrayBuffer[PartitionGroup]()
+    var currentGroup = new PartitionGroup()
+    var currentSum = 0L
+    var totalSum = 0L
+    var index = 0
+
+    // sort partitions based on the size of the corresponding input splits
+    partitions.sortWith((partition1, partition2) => {
+      val partition1Size = partition1.asInstanceOf[HadoopPartition].inputSplit.value.getLength
+      val partition2Size = partition2.asInstanceOf[HadoopPartition].inputSplit.value.getLength
+      partition1Size < partition2Size
+    })
+
+    def updateGroups(): Unit = {
+      groups += currentGroup
+      currentGroup = new PartitionGroup()
+      currentSum = 0
+    }
+
+    def addPartition(partition: Partition, splitSize: Long): Unit = {
+      currentGroup.partitions += partition
+      currentSum += splitSize
+      totalSum += splitSize
+    }
+
+    while (index < partitions.size) {
+      val partition = partitions(index)
+      val fileSplit =
+        partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit]
+      val splitSize = fileSplit.getLength
+      if (currentSum + splitSize < maxSize) {
+        addPartition(partition, splitSize)
+        index += 1
+        if (index == partitions.size) {
+          updateGroups
+        }
+      } else {
+        if (currentGroup.partitions.size == 0) {
+          addPartition(partition, splitSize)
+          index += 1
+        } else {
+          updateGroups
+        }
+      }
+    }
+    groups.toArray
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ff35dc010d..b2c80afb53 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -49,6 +49,10 @@ object MimaExcludes {
           "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"),
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.status.api.v1.ApplicationAttemptInfo.<init>$default$5"),
+        // SPARK-14042 Add custom coalescer support
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.coalesce"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rdd.PartitionCoalescer$LocationIterator"),
+        ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.rdd.PartitionCoalescer"),
         // SPARK-12600 Remove SQL deprecated methods
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"),
-- 
GitLab