diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 31d474a20fa857a262d40bf8a8c3966ba847d7ae..6790c86f651b4370eada5e53fb8c64a99104841e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -62,7 +62,7 @@ class LogisticRegressionModel (
   override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    val score = 1.0/ (1.0 + math.exp(-margin))
+    val score = 1.0 / (1.0 + math.exp(-margin))
     threshold match {
       case Some(t) => if (score < t) 0.0 else 1.0
       case None => score
@@ -204,6 +204,8 @@ class LogisticRegressionWithLBFGS private (
    */
   def this() = this(1E-4, 100, 0.0)
 
+  this.setFeatureScaling(true)
+
   private val gradient = new LogisticGradient()
   private val updater = new SimpleUpdater()
   // Have to return new LBFGS object every time since users can reset the parameters anytime.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 54854252d7477027f960f774d54ab3b8ad7dded6..20c1fdd2269ceef7b6b44488bf8c3fed49b960bf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.regression
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.feature.StandardScaler
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
@@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
   protected var validateData: Boolean = true
 
+  /**
+   * Whether to perform feature scaling before model training to reduce the condition numbers
+   * which can significantly help the optimizer converging faster. The scaling correction will be
+   * translated back to resulting model weights, so it's transparent to users.
+   * Note: This technique is used in both libsvm and glmnet packages. Default false.
+   */
+  private var useFeatureScaling = false
+
+  /**
+   * Set if the algorithm should use feature scaling to improve the convergence during optimization.
+   */
+  private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = {
+    this.useFeatureScaling = useFeatureScaling
+    this
+  }
+
   /**
    * Create a model given the weights and intercept
    */
@@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       throw new SparkException("Input validation failed.")
     }
 
+    /**
+     * Scaling columns to unit variance as a heuristic to reduce the condition number:
+     *
+     * During the optimization process, the convergence (rate) depends on the condition number of
+     * the training dataset. Scaling the variables often reduces this condition number
+     * heuristically, thus improving the convergence rate. Without reducing the condition number,
+     * some training datasets mixing the columns with different scales may not be able to converge.
+     *
+     * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return
+     * the weights in the original scale.
+     * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+     *
+     * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing
+     * the variance of each column (without subtracting the mean), and train the model in the
+     * scaled space. Then we transform the coefficients from the scaled space to the original scale
+     * as GLMNET and LIBSVM do.
+     *
+     * Currently, it's only enabled in LogisticRegressionWithLBFGS
+     */
+    val scaler = if (useFeatureScaling) {
+      (new StandardScaler).fit(input.map(x => x.features))
+    } else {
+      null
+    }
+
     // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+      if(useFeatureScaling) {
+        input.map(labeledPoint =>
+          (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
+      } else {
+        input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+      }
     } else {
-      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+      if (useFeatureScaling) {
+        input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features)))
+      } else {
+        input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+      }
     }
 
     val initialWeightsWithIntercept = if (addIntercept) {
@@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
 
     val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
-    val weights =
+    var weights =
       if (addIntercept) {
         Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
       } else {
         weightsWithIntercept
       }
 
+    /**
+     * The weights and intercept are trained in the scaled space; we're converting them back to
+     * the original scale.
+     *
+     * Math shows that if we only perform standardization without subtracting means, the intercept
+     * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i
+     * is the coefficient in the original space, and v_i is the variance of the column i.
+     */
+    if (useFeatureScaling) {
+      weights = scaler.transform(weights)
+    }
+
     createModel(weights, intercept)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 2289c6cdc19ded8b9e748f51a2797407a9d7530d..bc05b2046878f6cbc93aaf89a653404962c5f673 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -185,6 +185,63 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("numerical stability of scaling features using logistic regression with LBFGS") {
+    /**
+     * If we rescale the features, the condition number will be changed so the convergence rate
+     * and the solution will not equal to the original solution multiple by the scaling factor
+     * which it should be.
+     *
+     * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first,
+     * no matter how we multiple a scaling factor into the dataset, the convergence rate should be
+     * the same, and the solution should equal to the original solution multiple by the scaling
+     * factor.
+     */
+
+    val nPoints = 10000
+    val A = 2.0
+    val B = -1.5
+
+    val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+    val initialWeights = Vectors.dense(0.0)
+
+    val testRDD1 = sc.parallelize(testData, 2)
+
+    val testRDD2 = sc.parallelize(
+      testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2)
+
+    val testRDD3 = sc.parallelize(
+      testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2)
+
+    testRDD1.cache()
+    testRDD2.cache()
+    testRDD3.cache()
+
+    val lrA = new LogisticRegressionWithLBFGS().setIntercept(true)
+    val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false)
+
+    val modelA1 = lrA.run(testRDD1, initialWeights)
+    val modelA2 = lrA.run(testRDD2, initialWeights)
+    val modelA3 = lrA.run(testRDD3, initialWeights)
+
+    val modelB1 = lrB.run(testRDD1, initialWeights)
+    val modelB2 = lrB.run(testRDD2, initialWeights)
+    val modelB3 = lrB.run(testRDD3, initialWeights)
+
+    // For model trained with feature standardization, the weights should
+    // be the same in the scaled space. Note that the weights here are already
+    // in the original space, we transform back to scaled space to compare.
+    assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01)
+    assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01)
+
+    // Training data with different scales without feature standardization
+    // will not yield the same result in the scaled space due to poor
+    // convergence rate.
+    assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1)
+    assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
+  }
+
 }
 
 class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {