Skip to content
Snippets Groups Projects
Commit c5c6aded authored by MechCoder's avatar MechCoder Committed by Joseph K. Bradley
Browse files

[SPARK-9112] [ML] Implement Stats for LogisticRegression

I have added support for summary statistics in LogisticRegression. The API is similar to that of LinearRegression, and is exposed through LogisticRegressionTrainingSummary and LogisticRegressionSummary.

I have some questions, which I have asked inline.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7538 from MechCoder/log_reg_stats and squashes the following commits:

2e9f7c7 [MechCoder] Change defs into lazy vals
d775371 [MechCoder] Clean up class inheritance
9586125 [MechCoder] Add abstraction to handle Multiclass Metrics
40ad8ef [MechCoder] minor
640376a [MechCoder] remove unnecessary dataframe stuff and add docs
80d9954 [MechCoder] Added tests
fbed861 [MechCoder] DataFrame support for metrics
70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression
parent 9f94c85f
No related branches found
No related tags found
No related merge requests found
...@@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable ...@@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
/** /**
...@@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String) ...@@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String)
if (handlePersistence) instances.unpersist() if (handlePersistence) instances.unpersist()
copyValues(new LogisticRegressionModel(uid, weights, intercept)) val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
model.transform(dataset),
$(probabilityCol),
$(labelCol),
objectiveHistory)
model.setSummary(logRegSummary)
} }
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
...@@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] ( ...@@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] (
override val numClasses: Int = 2 override val numClasses: Int = 2
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

/**
 * Returns the summary of this model on the training set.
 * A SparkException is thrown when no training summary has been attached,
 * i.e. when `trainingSummary == None`.
 */
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
  // By-name argument: only evaluated (and thrown) when the Option is empty.
  throw new SparkException(
    "No training summary available for this LogisticRegressionModel",
    new NullPointerException())
}
/** Stores the given training summary on this model and returns the model itself. */
private[classification] def setSummary(
    summary: LogisticRegressionTrainingSummary): this.type = {
  trainingSummary = Some(summary)
  this
}
/** True when a training summary has been attached to this model instance. */
def hasSummary: Boolean = trainingSummary.nonEmpty
/**
 * Evaluates the model on a test set.
 * @param dataset Test dataset to evaluate model on.
 * @return a summary built from this model's `transform` output on `dataset`.
 */
// TODO: decide on a good name before exposing to public API
private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
  val scored = this.transform(dataset)
  new BinaryLogisticRegressionSummary(scored, $(probabilityCol), $(labelCol))
}
/** /**
* Predict label for the given feature vector. * Predict label for the given feature vector.
* The behavior of this can be adjusted using [[thresholds]]. * The behavior of this can be adjusted using [[thresholds]].
...@@ -440,6 +480,128 @@ private[classification] class MultiClassSummarizer extends Serializable { ...@@ -440,6 +480,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
} }
} }
/**
 * Abstraction for multinomial Logistic Regression Training results.
 * Extends [[LogisticRegressionSummary]] with the per-iteration objective history.
 */
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
/** Objective function (scaled loss + regularization) at each iteration. */
def objectiveHistory: Array[Double]
/** Number of training iterations until termination; defined as the length of `objectiveHistory`. */
def totalIterations: Int = objectiveHistory.length
}
/**
 * Abstraction for Logistic Regression Results for a given model.
 * Exposes the predictions DataFrame together with the names of the columns
 * that hold the predicted probabilities and the true labels.
 */
sealed trait LogisticRegressionSummary extends Serializable {
/** DataFrame output by the model's `transform` method. */
def predictions: DataFrame
/** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
def probabilityCol: String
/** Field in "predictions" which gives the true label of each sample. */
def labelCol: String
}
/**
 * :: Experimental ::
 * Logistic regression training results: the binary summary metrics plus the
 * objective history of the training run.
 * @param predictions DataFrame output by the model's `transform` method.
 * @param probabilityCol field in "predictions" which gives the calibrated probability of
 *                       each sample as a vector.
 * @param labelCol field in "predictions" which gives the true label of each sample.
 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
 */
@Experimental
class BinaryLogisticRegressionTrainingSummary private[classification] (
    predictions: DataFrame,
    probabilityCol: String,
    labelCol: String,
    val objectiveHistory: Array[Double])
  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
  with LogisticRegressionTrainingSummary
/**
 * :: Experimental ::
 * Binary Logistic regression results for a given model.
 * @param predictions DataFrame output by the model's `transform` method.
 * @param probabilityCol field in "predictions" which gives the calibrated probability of
 *                       each sample.
 * @param labelCol field in "predictions" which gives the true label of each sample.
 */
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
    @transient override val predictions: DataFrame,
    override val probabilityCol: String,
    override val labelCol: String) extends LogisticRegressionSummary {

  // A stable identifier is needed so that the `toDF` implicits can be imported.
  private val sqlContext = predictions.sqlContext
  import sqlContext.implicits._

  /**
   * Underlying BinaryClassificationMetrics object that all curves below are derived from.
   */
  // TODO: Allow the user to vary the number of bins using a setBins method in
  // BinaryClassificationMetrics. For now the default is set to 100.
  @transient private val binaryMetrics = {
    // Pair the probability stored at index 1 of the probability vector with the true label.
    val scoreAndLabels = predictions.select(probabilityCol, labelCol).map {
      case Row(score: Vector, label: Double) => (score(1), label)
    }
    new BinaryClassificationMetrics(scoreAndLabels, 100)
  }

  /**
   * Returns the receiver operating characteristic (ROC) curve,
   * which is a DataFrame having two fields (FPR, TPR)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
   */
  @transient lazy val roc: DataFrame = {
    binaryMetrics.roc().toDF("FPR", "TPR")
  }

  /**
   * Computes the area under the receiver operating characteristic (ROC) curve.
   */
  lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()

  /**
   * Returns the precision-recall curve, which is a DataFrame containing
   * two fields recall, precision with (0.0, 1.0) prepended to it.
   */
  @transient lazy val pr: DataFrame = {
    binaryMetrics.pr().toDF("recall", "precision")
  }

  /**
   * Returns a DataFrame with two fields (threshold, F-Measure) curve with beta = 1.0.
   */
  @transient lazy val fMeasureByThreshold: DataFrame =
    binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")

  /**
   * Returns a DataFrame with two fields (threshold, precision) curve.
   * Every possible probability obtained in transforming the dataset is used
   * as a threshold in calculating the precision.
   */
  @transient lazy val precisionByThreshold: DataFrame =
    binaryMetrics.precisionByThreshold().toDF("threshold", "precision")

  /**
   * Returns a DataFrame with two fields (threshold, recall) curve.
   * Every possible probability obtained in transforming the dataset is used
   * as a threshold in calculating the recall.
   */
  @transient lazy val recallByThreshold: DataFrame =
    binaryMetrics.recallByThreshold().toDF("threshold", "recall")
}
/** /**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for samples in sparse or dense vector in a online fashion. * in binary classification for samples in sparse or dense vector in a online fashion.
......
...@@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable { ...@@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
} }
} }
} }
/**
 * Fits a default LogisticRegression on the suite's dataset and checks that the
 * training summary reports as many iterations as it has objective-history entries.
 */
@Test
public void logisticRegressionTrainingSummary() {
  LogisticRegression lr = new LogisticRegression();
  LogisticRegressionModel model = lr.fit(dataset);
  LogisticRegressionTrainingSummary summary = model.summary();
  // Use a JUnit assertion rather than the Java `assert` statement: the latter is
  // skipped unless the JVM runs with -ea, so the check could silently never execute.
  // Fully qualified so that no additional import is required.
  org.junit.Assert.assertEquals(summary.objectiveHistory().length, summary.totalIterations());
}
} }
...@@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptR relTol 1E-5) assert(model1.intercept ~== interceptR relTol 1E-5)
assert(model1.weights ~= weightsR absTol 1E-6) assert(model1.weights ~== weightsR absTol 1E-6)
}
test("evaluate on test set") {
  // Evaluating the fitted model on the training dataset must reproduce
  // exactly the metrics held by the training summary.
  val estimator = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(1.0)
    .setThreshold(0.6)
  val fitted = estimator.fit(dataset)
  val trainingMetrics = fitted.summary.asInstanceOf[BinaryLogisticRegressionSummary]
  val evaluatedMetrics = fitted.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]

  assert(trainingMetrics.areaUnderROC === evaluatedMetrics.areaUnderROC)
  assert(trainingMetrics.roc.collect() === evaluatedMetrics.roc.collect())
  assert(trainingMetrics.pr.collect() === evaluatedMetrics.pr.collect())
  assert(
    trainingMetrics.fMeasureByThreshold.collect() ===
      evaluatedMetrics.fMeasureByThreshold.collect())
  assert(
    trainingMetrics.recallByThreshold.collect() === evaluatedMetrics.recallByThreshold.collect())
  assert(
    trainingMetrics.precisionByThreshold.collect() ===
      evaluatedMetrics.precisionByThreshold.collect())
}
test("statistics on training data") {
// Test that the objective (loss) never increases from one iteration to the next.
// The comparison below is `>=`, so equal consecutive values are accepted.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
val model = lr.fit(dataset)
// NOTE(review): if objectiveHistory ever held a single entry, sliding(2) would presumably
// yield one window of length 1 and x(1) would be out of bounds -- confirm this cannot happen.
assert(
model.summary
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment