diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala
index 4c5649b5378164bf162c33b90e7445b7ac412877..a80cfdb585d0cf981ce4a76f9bdcaa15239349cd 100644
--- a/src/scala/spark/Shuffle.scala
+++ b/src/scala/spark/Shuffle.scala
@@ -1,5 +1,9 @@
 package spark
 
+import java.net._
+import java.util.BitSet
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
 /**
  * A trait for shuffle system. Given an input RDD and combiner functions
  * for PairRDDExtras.combineByKey(), returns an output RDD.
@@ -13,3 +17,83 @@ trait Shuffle[K, V, C] {
               mergeCombiners: (C, C) => C)
   : RDD[(K, C)]
 }
+
+/**
+ * An object containing common shuffle config parameters
+ */
+private object Shuffle 
+extends Logging {
+  // ShuffleTracker info
+  private var MasterHostAddress_ = System.getProperty(
+    "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
+  private var MasterTrackerPort_ = System.getProperty(
+    "spark.shuffle.masterTrackerPort", "22222").toInt
+
+  // Used throughout the code for small and large waits/timeouts
+  private var MinKnockInterval_ = System.getProperty(
+    "spark.shuffle.minKnockInterval", "1000").toInt
+  private var MaxKnockInterval_ = System.getProperty(
+    "spark.shuffle.maxKnockInterval", "5000").toInt
+
+  // Maximum number of connections
+  private var MaxRxConnections_ = System.getProperty(
+    "spark.shuffle.maxRxConnections", "4").toInt
+  private var MaxTxConnections_ = System.getProperty(
+    "spark.shuffle.maxTxConnections", "8").toInt
+
+  def MasterHostAddress = MasterHostAddress_
+  def MasterTrackerPort = MasterTrackerPort_
+
+  def MinKnockInterval = MinKnockInterval_
+  def MaxKnockInterval = MaxKnockInterval_
+  
+  def MaxRxConnections = MaxRxConnections_
+  def MaxTxConnections = MaxTxConnections_
+
+  // Returns a standard ThreadFactory except all threads are daemons
+  private def newDaemonThreadFactory: ThreadFactory = {
+    new ThreadFactory {
+      def newThread(r: Runnable): Thread = {
+        val t = Executors.defaultThreadFactory.newThread(r)
+        t.setDaemon(true)
+        t
+      }
+    }
+  }
+
+  // Wrapper over newFixedThreadPool
+  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+    val threadPool =
+      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+
+    threadPool
+  }
+  
+  // Wrapper over newCachedThreadPool
+  def newDaemonCachedThreadPool: ThreadPoolExecutor = {
+    val threadPool =
+      Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+
+    threadPool
+  }
+}
+
+@serializable
+case class SplitInfo(hostAddress: String, listenPort: Int,
+  inputId: Int) {
+
+  var hasSplits = 0
+  var hasSplitsBitVector: BitSet = null
+}
+
+object SplitInfo {
+  // Constants for special values of listenPort
+  val MappersBusy = -1
+
+  // Other constants
+  val UnusedParam = 0
+}
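Note: the `Shuffle` object above is now the single home for shuffle-wide config and the daemon-thread-pool helpers. A minimal sketch of how a caller might exercise them, assuming only the property names and accessors defined above (the override values and the runnable body are illustrative, not part of this change):

```scala
// Illustrative sketch only. Overrides must happen before the Shuffle
// object is first touched, since its properties are read once at
// object initialization.
System.setProperty("spark.shuffle.maxRxConnections", "8")   // per reducer
System.setProperty("spark.shuffle.maxTxConnections", "16")  // per machine

// Daemon pools ensure shuffle worker threads never block JVM shutdown
val pool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections)
pool.execute(new Runnable {
  def run(): Unit = {
    // A real task would contact a mapper here; this just shows the
    // tracker endpoint the config resolves to
    println(Shuffle.MasterHostAddress + ":" + Shuffle.MasterTrackerPort)
  }
})
pool.shutdown()
```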
diff --git a/src/scala/spark/ShuffleTrackerStrategy.scala b/src/scala/spark/ShuffleTrackerStrategy.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d982f48bb520a4fe6316094448ae1361cab47bb0
--- /dev/null
+++ b/src/scala/spark/ShuffleTrackerStrategy.scala
@@ -0,0 +1,75 @@
+package spark
+
+/**
+ * A trait for implementing tracker strategies for the shuffle system.
+ */
+trait ShuffleTrackerStrategy {
+  // Initialize
+  def initialize(outputLocs_ : Array[SplitInfo]): Unit
+  
+  // Select a split, update internal stats, and send it back
+  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
+  
+  // A reducer is done. Update internal stats
+  def deleteReducerFrom(reducerSplitInfo: SplitInfo, 
+    serverSplitIndex: Int): Unit
+}
+
+/**
+ * A simple ShuffleTrackerStrategy that tries to balance the total number of 
+ * connections created for each mapper.
+ */
+class BalanceConnectionsShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+  var outputLocs: Array[SplitInfo] = null
+  var curConnectionsPerLoc: Array[Int] = null
+  var totalConnectionsPerLoc: Array[Int] = null
+  
+  // The order of elements in the outputLocs (splitIndex) is used to pass 
+  // information back and forth between the tracker, mappers, and reducers
+  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+    outputLocs = outputLocs_
+    
+    // Now initialize other data structures
+    curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
+    totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
+  }
+  
+  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
+    var minConnections = Int.MaxValue
+    var splitIndex = -1
+    
+    for (i <- 0 until curConnectionsPerLoc.size) {
+      // TODO: Use of MaxRxConnections instead of MaxTxConnections is 
+      // intentional here. MaxTxConnections is per machine whereas 
+      // MaxRxConnections is per mapper/reducer. Will have to find a better way.
+      if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
+        totalConnectionsPerLoc(i) < minConnections && 
+        !reducerSplitInfo.hasSplitsBitVector.get(i)) {
+        minConnections = totalConnectionsPerLoc(i)
+        splitIndex = i
+      }
+    }
+  
+    if (splitIndex != -1) {
+      curConnectionsPerLoc(splitIndex) += 1
+      totalConnectionsPerLoc(splitIndex) += 1
+
+      // Log current per-mapper connection counts to help debug balancing
+      logInfo(curConnectionsPerLoc.mkString(" "))
+    }
+
+    splitIndex
+  }
+  
+  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+    serverSplitIndex: Int): Unit = synchronized {
+    // Decrease number of active connections
+    curConnectionsPerLoc(serverSplitIndex) -= 1
+
+    assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
+
+    // Log current per-mapper connection counts to help debug balancing
+    logInfo(curConnectionsPerLoc.mkString(" "))
+  }
+}
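Because the tracker only depends on the `ShuffleTrackerStrategy` trait, alternative policies can be swapped in without touching the tracker itself. A hypothetical sketch under that assumption (the class name and round-robin policy are illustrative, not part of this change):

```scala
// Hypothetical strategy: hand out mapper splits round-robin, skipping
// any split the requesting reducer already holds. Illustrative only.
class RoundRobinShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
  private var outputLocs: Array[SplitInfo] = null
  private var nextIndex = 0

  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
    outputLocs = outputLocs_
  }

  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
    // Probe each split at most once, starting after the last one served
    for (attempt <- 0 until outputLocs.size) {
      val i = (nextIndex + attempt) % outputLocs.size
      if (!reducerSplitInfo.hasSplitsBitVector.get(i)) {
        nextIndex = (i + 1) % outputLocs.size
        return i
      }
    }
    -1  // This reducer already has every split
  }

  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
    serverSplitIndex: Int): Unit = synchronized {
    // Round-robin keeps no per-connection state, so nothing to clean up
  }
}
```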
diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
index 44822f5143cc24d4847c07dd5c240d10e484ade1..77be47cc3769bdb1c7236c9a0f852f4cdd6231ad 100644
--- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
@@ -97,8 +97,7 @@ extends Shuffle[K, V, C] with Logging {
       combiners = new HashMap[K, C]
       
       var threadPool = 
-        TrackedCustomParallelLocalFileShuffle.newDaemonFixedThreadPool(
-          TrackedCustomParallelLocalFileShuffle.MaxRxConnections)
+        Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections)
         
       // Start consumer
       var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
@@ -107,8 +106,8 @@ extends Shuffle[K, V, C] with Logging {
       logInfo("ShuffleConsumer started...")
         
       while (hasSplits < totalSplits) {
-        var numThreadsToCreate = Math.min(totalSplits, 
-          TrackedCustomParallelLocalFileShuffle.MaxRxConnections) - 
+        var numThreadsToCreate = 
+          Math.min(totalSplits, Shuffle.MaxRxConnections) - 
           threadPool.getActiveCount
       
         while (hasSplits < totalSplits && numThreadsToCreate > 0) {
@@ -133,7 +132,7 @@ extends Shuffle[K, V, C] with Logging {
         }
         
         // Sleep for a while before creating new threads
-        Thread.sleep(TrackedCustomParallelLocalFileShuffle.MinKnockInterval)
+        Thread.sleep(Shuffle.MinKnockInterval)
       }
       
       threadPool.shutdown()
@@ -181,9 +180,8 @@ extends Shuffle[K, V, C] with Logging {
   
   // Talks to the tracker and receives instruction
   private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
-    val clientSocketToTracker = new Socket(
-      TrackedCustomParallelLocalFileShuffle.MasterHostAddress, 
-      TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+    val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, 
+      Shuffle.MasterTrackerPort)
     val oosTracker =
       new ObjectOutputStream(clientSocketToTracker.getOutputStream)
     oosTracker.flush()
@@ -219,8 +217,7 @@ extends Shuffle[K, V, C] with Logging {
   
   class ShuffleTracker(outputLocs: Array[SplitInfo])
   extends Thread with Logging {
-    var threadPool = 
-      TrackedCustomParallelLocalFileShuffle.newDaemonCachedThreadPool
+    var threadPool = Shuffle.newDaemonCachedThreadPool
     var serverSocket: ServerSocket = null
 
     // Create trackerStrategy object
@@ -236,8 +233,7 @@ extends Shuffle[K, V, C] with Logging {
     trackerStrategy.initialize(outputLocs)
 
     override def run: Unit = {
-      serverSocket = new ServerSocket(
-        TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+      serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
       logInfo("ShuffleTracker" + serverSocket)
       
       try {
@@ -392,8 +388,7 @@ extends Shuffle[K, V, C] with Logging {
       }
       
       var timeOutTimer = new Timer
-      timeOutTimer.schedule(timeOutTask, 
-        TrackedCustomParallelLocalFileShuffle.MaxKnockInterval)
+      timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
         
       // Create a temp variable to be used in different places
       val requestPath = "http://%s:%d/shuffle/%s".format(
@@ -480,9 +475,8 @@ extends Shuffle[K, V, C] with Logging {
     // Connect to the tracker and update its stats
     private def sendLeavingNotification(): Unit = synchronized {
       if (!alreadySentLeavingNotification) {
-        val clientSocketToTracker =
-          new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress, 
-            TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+        val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress, 
+          Shuffle.MasterTrackerPort)
         val oosTracker =
           new ObjectOutputStream(clientSocketToTracker.getOutputStream)
         oosTracker.flush()
@@ -499,7 +493,7 @@ extends Shuffle[K, V, C] with Logging {
           oosTracker.writeObject(getLocalSplitInfo)
           oosTracker.flush()
           
-          // Send serverSplitInfo so that tracker can update its stats       
+          // Send serverSplitInfo so that tracker can update its stats
           oosTracker.writeObject(splitIndex)
           oosTracker.flush()
           
@@ -538,89 +532,6 @@ extends Shuffle[K, V, C] with Logging {
   }  
 }
 
-trait ShuffleTrackerStrategy {
-  def initialize(outputLocs_ : Array[SplitInfo]): Unit
-  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
-  def deleteReducerFrom(reducerSplitInfo: SplitInfo, 
-    serverSplitIndex: Int): Unit
-}
-
-class BalanceConnectionsShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
-  var outputLocs: Array[SplitInfo] = null
-  var curConnectionsPerLoc: Array[Int] = null
-  var totalConnectionsPerLoc: Array[Int] = null
-  
-  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
-    outputLocs = outputLocs_
-    
-    // Now initialize other data structures
-    curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
-    totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
-  }
-  
-  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
-    var minConnections = Int.MaxValue
-    var splitIndex = -1
-    
-    for (i <- 0 until curConnectionsPerLoc.size) {
-      // TODO: Use of MaxRxConnections instead of MaxTxConnections is 
-      // intentional here. MaxTxConnections is per machine whereas 
-      // MaxRxConnections is per mapper/reducer. Will have to find a better way.
-      if (curConnectionsPerLoc(i) < 
-        TrackedCustomParallelLocalFileShuffle.MaxRxConnections &&
-        totalConnectionsPerLoc(i) < minConnections && 
-        !reducerSplitInfo.hasSplitsBitVector.get(i)) {
-        minConnections = totalConnectionsPerLoc(i)
-        splitIndex = i
-      }
-    }
-  
-    if (splitIndex != -1) {
-      curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
-      totalConnectionsPerLoc(splitIndex) = 
-        totalConnectionsPerLoc(splitIndex) + 1
-        
-      curConnectionsPerLoc.foreach { i =>
-        print ("" + i + " ")
-      }
-      println("")
-    }
-  
-    return splitIndex
-  }
-  
-  def deleteReducerFrom(reducerSplitInfo: SplitInfo, 
-    serverSplitIndex: Int): Unit = synchronized {
-    // Decrease number of active connections
-    curConnectionsPerLoc(serverSplitIndex) = 
-      curConnectionsPerLoc(serverSplitIndex) - 1
-
-    assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
-
-    curConnectionsPerLoc.foreach {  i =>
-      print ("" + i + " ")
-    }
-    println("")
-  }
-}
-
-@serializable
-case class SplitInfo(val hostAddress: String, val listenPort: Int,
-  val inputId: Int) { 
-
-  var hasSplits = 0
-  var hasSplitsBitVector: BitSet = null
-}
-
-object SplitInfo {
-  // Constants for special values of listenPort
-  val MappersBusy = -1
-
-  // Other constants
-  val UnusedParam = 0
-}
-
 object TrackedCustomParallelLocalFileShuffle extends Logging {
   // Tracker communication constants
   val ReducerEntering = 0
@@ -639,27 +550,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
   // Random number generator
   var ranGen = new Random
   
-  // Load config parameters
-
-  // ShuffleTracker info
-  private var MasterHostAddress_ = System.getProperty(
-    "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
-  private var MasterTrackerPort_ = System.getProperty(
-    "spark.shuffle.masterTrackerPort", "22222").toInt
-
-  // Used thoughout the code for small and large waits/timeouts
-  private var MinKnockInterval_ = System.getProperty(
-    "spark.shuffle.minKnockInterval", "1000").toInt
-  private var MaxKnockInterval_ =  System.getProperty(
-    "spark.shuffle.maxKnockInterval", "5000").toInt
-
-  // Maximum number of connections
-  private var MaxRxConnections_ = System.getProperty(
-    "spark.shuffle.maxRxConnections", "4").toInt
-  private var MaxTxConnections_ = System.getProperty(
-    "spark.shuffle.maxTxConnections", "8").toInt
-
-  
   private def initializeIfNeeded() = synchronized {
     if (!initialized) {
       // TODO: localDir should be created by some mechanism common to Spark
@@ -702,15 +592,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
     }
   }
   
-  def MasterHostAddress = MasterHostAddress_
-  def MasterTrackerPort = MasterTrackerPort_
-
-  def MinKnockInterval = MinKnockInterval_
-  def MaxKnockInterval = MaxKnockInterval_
-  
-  def MaxRxConnections = MaxRxConnections_
-  def MaxTxConnections = MaxTxConnections_
-  
   def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
     initializeIfNeeded()
     val dir = new File(shuffleDir, shuffleId + "/" + inputId)
@@ -723,41 +604,9 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
     nextShuffleId.getAndIncrement()
   }
   
-  // Returns a standard ThreadFactory except all threads are daemons
-  private def newDaemonThreadFactory: ThreadFactory = {
-    new ThreadFactory {
-      def newThread(r: Runnable): Thread = {
-        var t = Executors.defaultThreadFactory.newThread(r)
-        t.setDaemon(true)
-        return t
-      }
-    }
-  }
-
-  // Wrapper over newFixedThreadPool
-  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
-    var threadPool =
-      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
-    threadPool.setThreadFactory(newDaemonThreadFactory)
-    
-    return threadPool
-  }
-  
-  // Wrapper over newCachedThreadPool
-  def newDaemonCachedThreadPool: ThreadPoolExecutor = {
-    var threadPool =
-      Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-  
-    threadPool.setThreadFactory(newDaemonThreadFactory)
-    
-    return threadPool
-  }
-  
   class ShuffleServer
   extends Thread with Logging {
-    var threadPool = newDaemonFixedThreadPool(
-      CustomParallelLocalFileShuffle.MaxTxConnections)
+    var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
 
     var serverSocket: ServerSocket = null