Skip to content
Snippets Groups Projects
Commit d60f6f62 authored by sueann's avatar sueann Committed by Joseph K. Bradley
Browse files

[SPARK-18194][ML] Log instrumentation in OneVsRest, CrossValidator, TrainValidationSplit

## What changes were proposed in this pull request?

Added instrumentation logging for OneVsRest classifier, CrossValidator, TrainValidationSplit fit() functions.

## How was this patch tested?

Ran unit tests and checked the log file (see output in comments).

Author: sueann <sueann@databricks.com>

Closes #16480 from sueann/SPARK-18194.
parent b59cddab
No related branches found
No related tags found
No related merge requests found
......@@ -308,6 +308,10 @@ final class OneVsRest @Since("1.4.0") (
override def fit(dataset: Dataset[_]): OneVsRestModel = {
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
......@@ -316,6 +320,7 @@ final class OneVsRest @Since("1.4.0") (
maxLabelIndex.toInt + 1
}
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
instr.logNumClasses(numClasses)
val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
......@@ -339,6 +344,7 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)
if (handlePersistence) {
multiclassLabeled.unpersist()
......@@ -352,6 +358,7 @@ final class OneVsRest @Since("1.4.0") (
case attr: Attribute => attr
}
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
instr.logSuccess(model)
copyValues(model)
}
......
......@@ -457,8 +457,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
val instrLog = Instrumentation.create(this, ratings)
instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
val instr = Instrumentation.create(this, ratings)
instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
userCol, itemCol, ratingCol, predictionCol, maxIter,
regParam, nonnegative, checkpointInterval, seed)
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
......@@ -471,7 +471,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
instrLog.logSuccess(model)
instr.logSuccess(model)
copyValues(model)
}
......
......@@ -101,6 +101,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val instr = Instrumentation.create(this, dataset)
instr.logParams(numFolds, seed)
logTuningParams(instr)
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
......@@ -127,6 +132,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
......
......@@ -97,6 +97,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val instr = Instrumentation.create(this, dataset)
instr.logParams(trainRatio, seed)
logTuningParams(instr)
val Array(trainingDataset, validationDataset) =
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
trainingDataset.cache()
......@@ -123,6 +127,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
}
......
......@@ -26,7 +26,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
......@@ -76,6 +76,15 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
}
est.copy(firstEstimatorParamMap).transformSchema(schema)
}
/**
* Instrumentation logging for tuning params including the inner estimator and evaluator info.
*/
protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = {
instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName)
instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
}
}
private[ml] object ValidatorParams {
......
......@@ -87,8 +87,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
/**
* Logs the value with customized name field.
*/
def logNamedValue(name: String, num: Long): Unit = {
log(compact(render(name -> num)))
def logNamedValue(name: String, value: String): Unit = {
log(compact(render(name -> value)))
}
def logNamedValue(name: String, value: Long): Unit = {
log(compact(render(name -> value)))
}
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment