Skip to content
Snippets Groups Projects
Commit 1018a1c1 authored by Timothy Hunter's avatar Timothy Hunter Committed by Joseph K. Bradley
Browse files

[SPARK-14568][ML] Instrumentation framework for logistic regression

## What changes were proposed in this pull request?

This adds extra logging information about a `LogisticRegression` estimator when being fit on a dataset. With this PR, you see the following extra lines when running the example in the documentation:

```
16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training: numPartitions=1 storageLevel=StorageLevel(disk=true, memory=true, offheap=false, deserialized=true, replication=1)
16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): {"regParam":0.3,"elasticNetParam":0.8,"maxIter":10}
...
16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numClasses=2
16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numFeatures=692
...
16/04/13 07:19:01 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training finished
```

## How was this patch tested?

This PR was manually tested.

Author: Timothy Hunter <timhunter@databricks.com>

Closes #12331 from thunterdb/1604-instrumentation.
parent 323e7390
No related branches found
No related tags found
No related merge requests found
......@@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val instr = Instrumentation.create(this, instances)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
......@@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
val (coefficients, intercept, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
......@@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
val m = model.setSummary(logRegSummary)
instr.logSuccess(m)
m
}
@Since("1.4.0")
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import java.util.concurrent.atomic.AtomicLong
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
/**
* A small wrapper that defines a training session for an estimator, and some methods to log
* useful information during this session.
*
* A new instance is expected to be created within fit().
*
* @param estimator the estimator that is being fit
* @param dataset the training dataset
* @tparam E the type of the estimator
*/
private[ml] class Instrumentation[E <: Estimator[_]] private (
estimator: E, dataset: RDD[_]) extends Logging {
private val id = Instrumentation.counter.incrementAndGet()
private val prefix = {
val className = estimator.getClass.getSimpleName
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
}
init()
private def init(): Unit = {
log(s"training: numPartitions=${dataset.partitions.length}" +
s" storageLevel=${dataset.getStorageLevel}")
}
/**
* Logs a message with a prefix that uniquely identifies the training session.
*/
def log(msg: String): Unit = {
logInfo(prefix + msg)
}
/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
def logParams(params: Param[_]*): Unit = {
val pairs: Seq[(String, JValue)] = for {
p <- params
value <- estimator.get(p)
} yield {
val cast = p.asInstanceOf[Param[Any]]
p.name -> parse(cast.jsonEncode(value))
}
log(compact(render(map2jvalue(pairs.toMap))))
}
def logNumFeatures(num: Long): Unit = {
log(compact(render("numFeatures" -> num)))
}
def logNumClasses(num: Long): Unit = {
log(compact(render("numClasses" -> num)))
}
/**
* Logs the successful completion of the training session and the value of the learned model.
*/
def logSuccess(model: Model[_]): Unit = {
log(s"training finished")
}
}
/**
* Some common methods for logging information about a training session.
*/
private[ml] object Instrumentation {
private val counter = new AtomicLong(0)
/**
* Creates an instrumentation object for a training session.
*/
def create[E <: Estimator[_]](
estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
create[E](estimator, dataset.rdd)
}
/**
* Creates an instrumentation object for a training session.
*/
def create[E <: Estimator[_]](
estimator: E, dataset: RDD[_]): Instrumentation[E] = {
new Instrumentation[E](estimator, dataset)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment