From 0d17593b32c12c3e39575430aa85cf20e56fae6a Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Wed, 13 Apr 2016 13:20:29 -0700
Subject: [PATCH] [SPARK-14461][ML] GLM training summaries should provide
 solver

## What changes were proposed in this pull request?
GLM training summaries should provide solver.

## How was this patch tested?
Unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12253 from yanboliang/spark-14461.
---
 .../ml/regression/GeneralizedLinearRegression.scala    | 10 +++++++---
 .../regression/GeneralizedLinearRegressionSuite.scala  |  4 ++++
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 00cf25dc54..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
         predictionColName,
         model,
         wlsModel.diagInvAtWA.toArray,
-        1)
+        1,
+        getSolver)
       return model.setSummary(trainingSummary)
     }
 
@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       predictionColName,
       model,
       irlsModel.diagInvAtWA.toArray,
-      irlsModel.numIterations)
+      irlsModel.numIterations,
+      getSolver)
 
     model.setSummary(trainingSummary)
   }
@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
  * @param model the model that should be summarized
  * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
  * @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
  */
 @Since("2.0.0")
 @Experimental
@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
     @Since("2.0.0") val predictionCol: String,
     @Since("2.0.0") val model: GeneralizedLinearRegressionModel,
     private val diagInvAtWA: Array[Double],
-    @Since("2.0.0") val numIterations: Int) extends Serializable {
+    @Since("2.0.0") val numIterations: Int,
+    @Since("2.0.0") val solver: String) extends Serializable {
 
   import GeneralizedLinearRegression._
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 4905f3e068..3ecc210abd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite
     assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
     assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
     assert(summary.aic ~== aicR absTol 1E-3)
+    assert(summary.solver === "irls")
   }
 
   test("glm summary: binomial family with weight") {
@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite
     assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
     assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
     assert(summary.aic ~== aicR absTol 1E-3)
+    assert(summary.solver === "irls")
   }
 
   test("glm summary: poisson family with weight") {
@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite
     assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
     assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
     assert(summary.aic ~== aicR absTol 1E-3)
+    assert(summary.solver === "irls")
   }
 
   test("glm summary: gamma family with weight") {
@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite
     assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
     assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
     assert(summary.aic ~== aicR absTol 1E-3)
+    assert(summary.solver === "irls")
   }
 
   test("read/write") {
-- 
GitLab