diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 514b4ef98dc5bbcc55ed7df1b60fd7b0d36dffd2..52c9e95d6012f9f3cef23163854e81230a190034 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -320,7 +320,7 @@ object ALS extends Logging {
 
   /** Trait for least squares solvers applied to the normal equation. */
   private[recommendation] trait LeastSquaresNESolver extends Serializable {
-    /** Solves a least squares problem (possibly with other constraints). */
+    /** Solves a least squares problem with regularization (possibly with other constraints). */
     def solve(ne: NormalEquation, lambda: Double): Array[Float]
   }
 
@@ -332,20 +332,19 @@ object ALS extends Logging {
     /**
      * Solves a least squares problem with L2 regularization:
      *
-     *   min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+     *   min norm(A x - b)^2^ + lambda * norm(x)^2^
      *
      * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
-     * @param lambda regularization constant, which will be scaled by n
+     * @param lambda regularization constant
      * @return the solution x
      */
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val k = ne.k
       // Add scaled lambda to the diagonals of AtA.
-      val scaledlambda = lambda * ne.n
       var i = 0
       var j = 2
       while (i < ne.triK) {
-        ne.ata(i) += scaledlambda
+        ne.ata(i) += lambda
         i += j
         j += 1
       }
@@ -391,7 +390,7 @@ object ALS extends Logging {
     override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
       val rank = ne.k
       initialize(rank)
-      fillAtA(ne.ata, lambda * ne.n)
+      fillAtA(ne.ata, lambda)
       val x = NNLS.solve(ata, ne.atb, workspace)
       ne.reset()
       x.map(x => x.toFloat)
@@ -420,7 +419,15 @@ object ALS extends Logging {
     }
   }
 
-  /** Representing a normal equation (ALS' subproblem). */
+  /**
+   * Representing a normal equation to solve the following weighted least squares problem:
+   *
+   * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
+   *
+   * Its normal equation is given by
+   *
+   * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
+   */
   private[recommendation] class NormalEquation(val k: Int) extends Serializable {
 
     /** Number of entries in the upper triangular part of a k-by-k matrix. */
@@ -429,8 +436,6 @@ object ALS extends Logging {
     val ata = new Array[Double](triK)
     /** A^T^ * b */
     val atb = new Array[Double](k)
-    /** Number of observations. */
-    var n = 0
 
     private val da = new Array[Double](k)
     private val upper = "U"
@@ -444,28 +449,13 @@ object ALS extends Logging {
     }
 
     /** Adds an observation. */
-    def add(a: Array[Float], b: Float): this.type = {
-      require(a.length == k)
-      copyToDouble(a)
-      blas.dspr(upper, k, 1.0, da, 1, ata)
-      blas.daxpy(k, b.toDouble, da, 1, atb, 1)
-      n += 1
-      this
-    }
-
-    /**
-     * Adds an observation with implicit feedback. Note that this does not increment the counter.
-     */
-    def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
+      require(c >= 0.0)
       require(a.length == k)
-      // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
-      // so that it is never negative.
-      val confidence = 1.0 + alpha * math.abs(b)
       copyToDouble(a)
-      blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
-      // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
-      if (b > 0) {
-        blas.daxpy(k, confidence, da, 1, atb, 1)
+      blas.dspr(upper, k, c, da, 1, ata)
+      if (b != 0.0) {
+        blas.daxpy(k, c * b, da, 1, atb, 1)
       }
       this
     }
@@ -475,7 +465,6 @@ object ALS extends Logging {
       require(other.k == k)
       blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
       blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
-      n += other.n
       this
     }
 
@@ -483,7 +472,6 @@ object ALS extends Logging {
     def reset(): Unit = {
       ju.Arrays.fill(ata, 0.0)
       ju.Arrays.fill(atb, 0.0)
-      n = 0
     }
   }
 
@@ -1114,6 +1102,7 @@ object ALS extends Logging {
             ls.merge(YtY.get)
           }
           var i = srcPtrs(j)
+          var numExplicits = 0
           while (i < srcPtrs(j + 1)) {
             val encoded = srcEncodedIndices(i)
             val blockId = srcEncoder.blockId(encoded)
@@ -1121,13 +1110,23 @@ object ALS extends Logging {
             val srcFactor = sortedSrcFactors(blockId)(localIndex)
             val rating = ratings(i)
             if (implicitPrefs) {
-              ls.addImplicit(srcFactor, rating, alpha)
+              // Extension to the original paper to handle b < 0. confidence is a function of |b|
+              // instead so that it is never negative. c1 is confidence - 1.0.
+              val c1 = alpha * math.abs(rating)
+              // For rating <= 0, the corresponding preference is 0. So the term below is only added
+              // for rating > 0. Because YtY is already added, we need to adjust the scaling here.
+              if (rating > 0) {
+                numExplicits += 1
+                ls.add(srcFactor, (c1 + 1.0) / c1, c1)
+              }
             } else {
               ls.add(srcFactor, rating)
+              numExplicits += 1
             }
             i += 1
           }
-          dstFactors(j) = solver.solve(ls, regParam)
+          // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
+          dstFactors(j) = solver.solve(ls, numExplicits * regParam)
           j += 1
         }
         dstFactors
@@ -1141,7 +1140,7 @@ object ALS extends Logging {
   private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
     factorBlocks.values.aggregate(new NormalEquation(rank))(
       seqOp = (ne, factors) => {
-        factors.foreach(ne.add(_, 0.0f))
+        factors.foreach(ne.add(_, 0.0))
         ne
       },
       combOp = (ne1, ne2) => ne1.merge(ne2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 0bb06e9e8ac9c6254768794a93329e4990909c20..29d4ec5f85c1e14ee611cab0874f040dc0c54a6f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -68,39 +68,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }
   }
 
-  test("normal equation construction with explict feedback") {
+  test("normal equation construction") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 3.0f)
-      .add(Array(4.0f, 5.0f), 6.0f)
+      .add(Array(1.0f, 2.0f), 3.0)
+      .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted
     assert(ne0.k === k)
     assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 2)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5")
     // b = np.matrix("3; 6")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8)
 
     val ne1 = new NormalEquation(2)
-      .add(Array(7.0f, 8.0f), 9.0f)
+      .add(Array(7.0f, 8.0f), 9.0)
     ne0.merge(ne1)
-    assert(ne0.n === 3)
     // NumPy code that computes the expected values:
     // A = np.matrix("1 2; 4 5; 7 8")
     // b = np.matrix("3; 6; 9")
-    // ata = A.transpose() * A
-    // atb = A.transpose() * b
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+    // C = np.matrix(np.diag([1, 2, 1]))
+    // ata = A.transpose() * C * A
+    // atb = A.transpose() * C * b
+    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8)
+    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8)
 
     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f), 2.0f)
+      ne0.add(Array(1.0f), 2.0)
     }
     intercept[IllegalArgumentException] {
-      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+      ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0)
+    }
+    intercept[IllegalArgumentException] {
+      ne0.add(Array(1.0f, 2.0f), 0.0, -1.0)
     }
     intercept[IllegalArgumentException] {
       val ne2 = new NormalEquation(3)
@@ -108,41 +111,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     }
 
     ne0.reset()
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))
   }
 
-  test("normal equation construction with implicit feedback") {
-    val k = 2
-    val alpha = 0.5
-    val ne0 = new NormalEquation(k)
-      .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
-      .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
-      .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
-    assert(ne0.k === k)
-    assert(ne0.triK === k * (k + 1) / 2)
-    assert(ne0.n === 0) // addImplicit doesn't increase the count.
-    // NumPy code that computes the expected values:
-    // alpha = 0.5
-    // A = np.matrix("-5 -4; -2 -1; 1 2")
-    // b = np.matrix("-3; 0; 3")
-    // b1 = b > 0
-    // c = 1.0 + alpha * np.abs(b)
-    // C = np.diag(c.A1)
-    // I = np.eye(3)
-    // ata = A.transpose() * (C - I) * A
-    // atb = A.transpose() * C * b1
-    assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
-    assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
-  }
-
   test("CholeskySolver") {
     val k = 2
     val ne0 = new NormalEquation(k)
-      .add(Array(1.0f, 2.0f), 4.0f)
-      .add(Array(1.0f, 3.0f), 9.0f)
-      .add(Array(1.0f, 4.0f), 16.0f)
+      .add(Array(1.0f, 2.0f), 4.0)
+      .add(Array(1.0f, 3.0f), 9.0)
+      .add(Array(1.0f, 4.0f), 16.0)
     val ne1 = new NormalEquation(k)
       .merge(ne0)
 
@@ -154,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     // x0 = np.linalg.lstsq(A, b)[0]
     assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
 
-    assert(ne0.n === 0)
     assert(ne0.ata.forall(_ == 0.0))
     assert(ne0.atb.forall(_ == 0.0))
 
-    val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+    val x1 = chol.solve(ne1, 1.5).map(_.toDouble)
     // NumPy code that computes the expected solution, where lambda is scaled by n:
-    // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+    // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b)
     assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
   }
 
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index b094e50856f707ed897513c021e4a3cef6805eba..c5c4c13dae1057bc06748e1d8ccadbbb0b5fb17f 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     >>> ratings = sc.parallelize([r1, r2, r3])
     >>> model = ALS.trainImplicit(ratings, 1, seed=10)
     >>> model.predict(2, 2)
-    0.43...
+    0.4...
 
     >>> testset = sc.parallelize([(1, 2), (1, 1)])
     >>> model = ALS.train(ratings, 2, seed=0)
@@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
 
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
     >>> model.predict(2,2)
-    0.43...
+    0.4...
 
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
     >>> sameModel = MatrixFactorizationModel.load(sc, path)
     >>> sameModel.predict(2,2)
-    0.43...
+    0.4...
     >>> sameModel.predictAll(testset).collect()
     [Rating(...
     >>> try: