diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 36d262fed425aa351f5dc70b2dd7abf7b8c38a2b..8ebc7e27ed4dd75dbfa4c7dc4e68b44fa5fc0951 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.mllib.recommendation
 
-import scala.collection.mutable.{ArrayBuffer, BitSet}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.math.{abs, sqrt}
 import scala.util.Random
 import scala.util.Sorting
@@ -25,7 +26,7 @@ import scala.util.hashing.byteswap32
 
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.{Logging, HashPartitioner, Partitioner}
 import org.apache.spark.storage.StorageLevel
@@ -39,7 +40,8 @@ import org.apache.spark.mllib.optimization.NNLS
  * of the elements within this block, and the list of destination blocks that each user or
  * product will need to send its feature vector to.
  */
-private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet])
+private[recommendation]
+case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet])
 
 
 /**
@@ -382,7 +384,7 @@ class ALS private (
     val userIds = ratings.map(_.user).distinct.sorted
     val numUsers = userIds.length
     val userIdToPos = userIds.zipWithIndex.toMap
-    val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks))
+    val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks))
     for (r <- ratings) {
       shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true
     }
@@ -797,4 +799,120 @@ object ALS {
     : MatrixFactorizationModel = {
     trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
   }
+
+  /**
+   * :: DeveloperApi ::
+   * Statistics of a block in ALS computation.
+   *
+   * @param category type of this block, "user" or "product"
+   * @param index index of this block
+   * @param count number of users or products inside this block, the same as the number of
+   *              least-squares problems to solve on this block in each iteration
+   * @param numRatings total number of ratings inside this block, the same as the number of outer
+   *                   products we need to make on this block in each iteration
+   * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve
+   *                   before each iteration
+   * @param numOutLinks total number of outgoing links, the same as the number of vectors to send
+   *                    for the next iteration
+   */
+  @DeveloperApi
+  case class BlockStats(
+      category: String,
+      index: Int,
+      count: Long,
+      numRatings: Long,
+      numInLinks: Long,
+      numOutLinks: Long)
+
+  /**
+   * :: DeveloperApi ::
+   * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the
+   * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing
+   * load balance.
+   *
+   * @param ratings an RDD of ratings
+   * @param numUserBlocks number of user blocks
+   * @param numProductBlocks number of product blocks
+   * @return statistics of user blocks and product blocks
+   */
+  @DeveloperApi
+  def analyzeBlocks(
+      ratings: RDD[Rating],
+      numUserBlocks: Int,
+      numProductBlocks: Int): Array[BlockStats] = {
+
+    val userPartitioner = new ALSPartitioner(numUserBlocks)
+    val productPartitioner = new ALSPartitioner(numProductBlocks)
+
+    val ratingsByUserBlock = ratings.map { rating =>
+      (userPartitioner.getPartition(rating.user), rating)
+    }
+    val ratingsByProductBlock = ratings.map { rating =>
+      (productPartitioner.getPartition(rating.product),
+        Rating(rating.product, rating.user, rating.rating))
+    }
+
+    val als = new ALS()
+    val (userIn, userOut) =
+      als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner)
+    val (prodIn, prodOut) =
+      als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner)
+
+    def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = {
+      outLinks.map { x =>
+        val grid = new mutable.HashMap[(Int, Int), Long]()
+        val uPartition = x._1
+        x._2.shouldSend.foreach { ss =>
+          ss.foreach { pPartition =>
+            val pair = (uPartition, pPartition)
+            grid.put(pair, grid.getOrElse(pair, 0L) + 1L)
+          }
+        }
+        grid
+      }.reduce { (grid1, grid2) =>
+        grid2.foreach { x =>
+          grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2)
+        }
+        grid1
+      }.toMap
+    }
+
+    val userSendGrid = sendGrid(userOut)
+    val prodSendGrid = sendGrid(prodOut)
+
+    val userInbound = new Array[Long](numUserBlocks)
+    val prodInbound = new Array[Long](numProductBlocks)
+    val userOutbound = new Array[Long](numUserBlocks)
+    val prodOutbound = new Array[Long](numProductBlocks)
+
+    for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) {
+      userOutbound(u) += userSendGrid.getOrElse((u, p), 0L)
+      prodInbound(p) += userSendGrid.getOrElse((u, p), 0L)
+      userInbound(u) += prodSendGrid.getOrElse((p, u), 0L)
+      prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L)
+    }
+
+    val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap()
+    val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap()
+
+    val userRatings = countRatings(userIn)
+    val prodRatings = countRatings(prodIn)
+
+    val userStats = Array.tabulate(numUserBlocks)(
+      u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u)))
+    val productStatus = Array.tabulate(numProductBlocks)(
+      p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p)))
+
+    (userStats ++ productStatus).toArray
+  }
+
+  private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = {
+    inLinks.mapValues { ilb =>
+      var numRatings = 0L
+      ilb.ratingsForBlock.foreach { ar =>
+        ar.foreach { p => numRatings += p._1.length }
+      }
+      numRatings
+    }.collectAsMap().toMap
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 81bebec8c7a3954f202cf616d2a5b7b43dd94991..017c39edb185f7cb9d2adfb9fb3006c090a0fdb1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -22,11 +22,11 @@ import scala.math.abs
 import scala.util.Random
 
 import org.scalatest.FunSuite
-
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.recommendation.ALS.BlockStats
 
 object ALSSuite {
 
@@ -67,8 +67,10 @@ object ALSSuite {
       case true =>
         // Generate raw values from [0,9], or if negativeWeights, from [-2,7]
         val raw = new DoubleMatrix(users, products,
-          Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*)
-        val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
+          Array.fill(users * products)(
+            (if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*)
+        val prefs =
+          new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
         (raw, prefs)
       case false => (userMatrix.mmul(productMatrix), null)
     }
@@ -160,6 +162,22 @@ class ALSSuite extends FunSuite with LocalSparkContext {
     testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
   }
 
+  test("analyze one user block and one product block") {
+    val localRatings = Seq(
+      Rating(0, 100, 1.0),
+      Rating(0, 101, 2.0),
+      Rating(0, 102, 3.0),
+      Rating(1, 102, 4.0),
+      Rating(2, 103, 5.0))
+    val ratings = sc.makeRDD(localRatings, 2)
+    val stats = ALS.analyzeBlocks(ratings, 1, 1)
+    assert(stats.size === 2)
+    assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3))
+    assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4))
+  }
+
+  // TODO: add tests for analyzing multiple user/product blocks
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *