Skip to content
Snippets Groups Projects
Commit b92bd5a2 authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-3841] [mllib] Pretty-print params for ML examples

Provide a parent class for the Params case classes used in many MLlib examples, where the parent class pretty-prints the case class fields:
Param1Name	Param1Value
Param2Name	Param2Value
...
Using this class will make it easier to print test settings to logs.

Also, updated DecisionTreeRunner to print a little more info.

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2700 from jkbradley/dtrunner-update and squashes the following commits:

cff873f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
7a08ae4 [Joseph K. Bradley] code review comment updates
b4d2043 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
d8228a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
0fc9c64 [Joseph K. Bradley] Added abstract TestParams class for mllib example parameters
12b7798 [Joseph K. Bradley] Added abstract class TestParams for pretty-printing Params values
5f84f03 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
f7441b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
19eb6fc [Joseph K. Bradley] Updated DecisionTreeRunner to print training time.
parent bc441872
No related branches found
No related tags found
No related merge requests found
Showing
with 75 additions and 7 deletions
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib
import scala.reflect.runtime.universe._
/**
 * Base class for the parameter case classes used by the MLlib examples.
 *
 * Overrides [[toString]] so that every case-class field of the concrete
 * subclass is printed by name and value in a JSON-style block:
 * {
 *  [field name]:\t[field value],
 *  [field name]:\t[field value],
 *  ...
 * }
 *
 * @tparam T Concrete parameter case class extending this base.
 */
abstract class AbstractParams[T: TypeTag] {

  // Captures the concrete subclass's type so its fields can be reflected on.
  private def tag: TypeTag[T] = typeTag[T]

  /**
   * Renders all case-class fields of this instance, one "name:\tvalue"
   * line per field, wrapped in braces.
   */
  override def toString: String = {
    // Collect the auto-generated accessor methods of the case class;
    // these identify exactly the constructor fields we want to print.
    val caseAccessors = tag.tpe.declarations.collect {
      case method: MethodSymbol if method.isCaseAccessor => method
    }
    // Reflect this instance so each field's current value can be read.
    val reflected = runtimeMirror(getClass.getClassLoader).reflect(this)
    val renderedFields = for (accessor <- caseAccessors) yield {
      val fieldValue = reflected.reflectField(accessor).get
      s" ${accessor.name.toString}:\t$fieldValue"
    }
    renderedFields.mkString("{\n", ",\n", "\n}")
  }
}
......@@ -55,7 +55,7 @@ object BinaryClassification {
stepSize: Double = 1.0,
algorithm: Algorithm = LR,
regType: RegType = L2,
regParam: Double = 0.1)
regParam: Double = 0.1) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
......@@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object Correlations {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
extends AbstractParams[Params]
def main(args: Array[String]) {
......
......@@ -43,6 +43,7 @@ import org.apache.spark.{SparkConf, SparkContext}
*/
object CosineSimilarity {
case class Params(inputFile: String = null, threshold: Double = 0.1)
extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
......@@ -62,7 +62,7 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2)
fracTest: Double = 0.2) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......@@ -138,9 +138,11 @@ object DecisionTreeRunner {
def run(params: Params) {
val conf = new SparkConf().setAppName("DecisionTreeRunner")
val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
val sc = new SparkContext(conf)
println(s"DecisionTreeRunner with parameters:\n$params")
// Load training data and cache it.
val origExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
......@@ -235,7 +237,10 @@ object DecisionTreeRunner {
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.numNodes < 20) {
println(model.toDebugString) // Print full model.
} else {
......@@ -259,8 +264,11 @@ object DecisionTreeRunner {
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
val startTime = System.nanoTime()
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
......@@ -275,8 +283,11 @@ object DecisionTreeRunner {
println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
val startTime = System.nanoTime()
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
......
......@@ -44,7 +44,7 @@ object DenseKMeans {
input: String = null,
k: Int = -1,
numIterations: Int = 10,
initializationMode: InitializationMode = Parallel)
initializationMode: InitializationMode = Parallel) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
......@@ -47,7 +47,7 @@ object LinearRegression extends App {
numIterations: Int = 100,
stepSize: Double = 1.0,
regType: RegType = L2,
regParam: Double = 0.1)
regParam: Double = 0.1) extends AbstractParams[Params]
val defaultParams = Params()
......
......@@ -55,7 +55,7 @@ object MovieLensALS {
rank: Int = 10,
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
implicitPrefs: Boolean = false)
implicitPrefs: Boolean = false) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
......@@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object MultivariateSummarizer {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
extends AbstractParams[Params]
def main(args: Array[String]) {
......
......@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._
object SampledRDDs {
case class Params(input: String = "data/mllib/sample_binary_classification_data.txt")
extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
......@@ -37,7 +37,7 @@ object SparseNaiveBayes {
input: String = null,
minPartitions: Int = 0,
numFeatures: Int = -1,
lambda: Double = 1.0)
lambda: Double = 1.0) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment