From 217667174e267adba5469cf26b3e4418e3d1cc90 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman <shivaram@eecs.berkeley.edu>
Date: Wed, 17 Jul 2013 16:08:34 -0700
Subject: [PATCH] Return Array[Double] from SGD instead of DoubleMatrix

---
 .../scala/spark/mllib/optimization/GradientDescent.scala    | 4 ++--
 .../scala/spark/mllib/regression/LogisticRegression.scala   | 6 ++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 2c5038757b..4c996c0903 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -50,7 +50,7 @@ object GradientDescent {
     stepSize: Double,
     numIters: Int,
     initialWeights: Array[Double],
-    miniBatchFraction: Double=1.0) : (DoubleMatrix, Array[Double]) = {
+    miniBatchFraction: Double=1.0) : (Array[Double], Array[Double]) = {
 
     val stochasticLossHistory = new ArrayBuffer[Double](numIters)
 
@@ -75,6 +75,6 @@ object GradientDescent {
       reg_val = update._2
     }
 
-    (weights, stochasticLossHistory.toArray)
+    (weights.toArray, stochasticLossHistory.toArray)
   }
 }
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
index ab865af0c6..711e205c39 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
@@ -126,10 +126,8 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D
       initalWeightsWithIntercept,
       miniBatchFraction)
 
-    val weightsArray = weights.toArray()
-
-    val intercept = weightsArray(0)
-    val weightsScaled = weightsArray.tail
+    val intercept = weights(0)
+    val weightsScaled = weights.tail
 
     val model = new LogisticRegressionModel(weightsScaled, intercept, stochasticLosses)
 
-- 
GitLab