Commit 20424dad authored by Xiangrui Meng, committed by Reynold Xin

[SPARK-2174][MLLIB] treeReduce and treeAggregate

In `reduce` and `aggregate`, the driver node spends time linear in the number of partitions merging partial results. This becomes a bottleneck when there are many partitions and the data from each partition is large.

SPARK-1485 (#506) tracks the progress of implementing AllReduce on Spark. I tried several implementations, including butterfly, reduce + broadcast, and treeReduce + broadcast. treeReduce + BT broadcast seems to be the right way to go for Spark. Using a binary tree may introduce some communication overhead, because the driver still needs to coordinate the data shuffling. In my experiments, n -> sqrt(n) -> 1 gives the best performance in general, which is why I set `depth = 2` in the MLlib algorithms. But it certainly needs more testing.
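As a rough illustration of the n -> sqrt(n) -> 1 schedule, here is a minimal sketch (not part of the patch) that mirrors the level-shrinking rule used by `treeAggregate` (`scale = max(ceil(numPartitions^(1/depth)), 2)`, stopping once an extra level would not reduce wall-clock time) for a hypothetical 1024-partition RDD:

~~~
// Illustrative sketch only (not from the patch): reproduces the partition-shrinking
// rule in treeAggregate to show how many partial results exist at each tree level.
def levels(initialPartitions: Int, depth: Int): Seq[Int] = {
  val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
  var numPartitions = initialPartitions
  val counts = scala.collection.mutable.ArrayBuffer(numPartitions)
  // Stop adding tree levels once another level no longer pays off.
  while (numPartitions > scale + numPartitions / scale) {
    numPartitions /= scale
    counts += numPartitions
  }
  counts += 1  // final combine of the remaining partials on the driver
  counts
}

println(levels(1024, depth = 2).mkString(" -> "))  // 1024 -> 32 -> 1
~~~

With depth = 2 and 1024 partitions this prints `1024 -> 32 -> 1`: one intermediate level combines roughly sqrt(n) partials per reducer before the driver does the final merge.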

I left `treeReduce` and `treeAggregate` public for easy testing. Some numbers from a test on a 32-node m3.2xlarge cluster:

code:

~~~
import breeze.linalg._
import org.apache.log4j._

Logger.getRootLogger.setLevel(Level.OFF)

for (n <- Seq(1, 10, 100, 1000, 10000, 100000, 1000000)) {
  val vv = sc.parallelize(0 until 1024, 1024).map(i => DenseVector.zeros[Double](n))
  var start = System.nanoTime(); vv.treeReduce(_ + _, 2); println((System.nanoTime() - start) / 1e9)
  start = System.nanoTime(); vv.reduce(_ + _); println((System.nanoTime() - start) / 1e9)
}
~~~

out:

| n | `treeReduce(_ + _, 2)` time (s) | `reduce` time (s) |
|---|---------------------------------|-------------------|
| 10 | 0.215538731 | 0.204206899 |
| 100 | 0.278405907 | 0.205732582 |
| 1000 | 0.208972182 | 0.214298272 |
| 10000 | 0.194792071 | 0.349353687 |
| 100000 | 0.347683285 | 6.086671892 |
| 1000000 | 2.589350682 | 66.572906702 |

CC: @pwendell

This is clearly more scalable than the default implementation. My question is whether we should use this implementation inside `reduce` and `aggregate` or expose them as separate methods. The concern is that some users use `reduce` and `aggregate` effectively as a `collect`, in which case adding intermediate stages does not reduce the amount of data moved; however, in that case `collect` is the more appropriate operation anyway.
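If they stay as separate methods, callers opt in explicitly through the MLlib implicits added by this patch. A minimal usage sketch, assuming a live SparkContext `sc` (e.g. in spark-shell); the RDD and values are illustrative:

~~~
import org.apache.spark.mllib.rdd.RDDFunctions._

val rdd = sc.makeRDD(-1000 until 1000, 10)
// Tree-based counterparts of reduce/aggregate; depth defaults to 2.
val total = rdd.treeReduce(_ + _)
val sum = rdd.treeAggregate(0L)((c, x) => c + x, _ + _, depth = 3)
~~~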

Author: Xiangrui Meng <meng@databricks.com>

Closes #1110 from mengxr/tree and squashes the following commits:

c6cd267 [Xiangrui Meng] make depth default to 2
b04b96a [Xiangrui Meng] address comments
9bcc5d3 [Xiangrui Meng] add depth for readability
7495681 [Xiangrui Meng] fix compile error
142a857 [Xiangrui Meng] merge master
d58a087 [Xiangrui Meng] move treeReduce and treeAggregate to mllib
8a2a59c [Xiangrui Meng] Merge branch 'master' into tree
be6a88a [Xiangrui Meng] use treeAggregate in mllib
0f94490 [Xiangrui Meng] add docs
eb71c33 [Xiangrui Meng] add treeReduce
fe42a5e [Xiangrui Meng] add treeAggregate
parent 96ba04bb
mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}

 /**
@@ -79,7 +80,7 @@ class RowMatrix(
   private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
     val n = numCols().toInt
     val vbr = rows.context.broadcast(v)
-    rows.aggregate(BDV.zeros[Double](n))(
+    rows.treeAggregate(BDV.zeros[Double](n))(
       seqOp = (U, r) => {
         val rBrz = r.toBreeze
         val a = rBrz.dot(vbr.value)
@@ -91,9 +92,7 @@ class RowMatrix(
             s"Do not support vector operation from type ${rBrz.getClass.getName}.")
         }
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)
   }

   /**
@@ -104,13 +103,11 @@ class RowMatrix(
     val nt: Int = n * (n + 1) / 2

     // Compute the upper triangular part of the gram matrix.
-    val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+    val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)

     RowMatrix.triuToFull(n, GU.data)
   }
@@ -290,9 +287,10 @@ class RowMatrix(
         s"We need at least $mem bytes of memory.")
     }

-    val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+    val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
-      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
+      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
+        (s1._1 + s2._1, s1._2 += s2._2)
     )

     updateNumRows(m)
@@ -353,10 +351,9 @@ class RowMatrix(
    * Computes column-wise summary statistics.
    */
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
-    val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
+    val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-    )
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
     updateNumRows(summary.count)
     summary
   }
mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._

 /**
  * Class used to solve an optimization problem using Gradient Descent.
@@ -177,7 +178,7 @@ object GradientDescent extends Logging {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](n), 0.0))(
+        .treeAggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
             (grad, loss + l)
mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._

 /**
  * :: DeveloperApi ::
@@ -199,7 +200,7 @@ object LBFGS extends Logging {
       val n = weights.length
       val bcWeights = data.context.broadcast(weights)

-      val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
+      val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))(
          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
            val l = localGradient.compute(
              features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag

+import org.apache.spark.HashPartitioner
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils

 /**
  * Machine learning specific RDD functions.
@@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
       new SlidingRDD[T](self, windowSize)
     }
   }
+
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = self.context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (self.partitions.size == 0) {
+      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = self.context.clean(seqOp)
+    val cleanCombOp = self.context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
 }

 private[mllib]
mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toList
     assert(sliding.collect().toList === expected)
   }
+
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
 }