Skip to content
Snippets Groups Projects
Commit cc12a86f authored by Ram Sriharsha's avatar Ram Sriharsha Committed by Joseph K. Bradley
Browse files

[SPARK-7575] [ML] [DOC] Example code for OneVsRest

Java and Scala examples for OneVsRest. Fixes the base classifier to be Logistic Regression and accepts the configuration parameters of the base classifier.

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6115 from harsha2010/SPARK-7575 and squashes the following commits:

87ad3c7 [Ram Sriharsha] extra line
f5d9891 [Ram Sriharsha] Merge branch 'master' into SPARK-7575
7076084 [Ram Sriharsha] cleanup
dfd660c [Ram Sriharsha] cleanup
8703e4f [Ram Sriharsha] update doc
cb23995 [Ram Sriharsha] fix commandline options for JavaOneVsRestExample
69e91f8 [Ram Sriharsha] cleanup
7f4e127 [Ram Sriharsha] cleanup
d4c40d0 [Ram Sriharsha] Code Review fixes
461eb38 [Ram Sriharsha] cleanup
e0106d9 [Ram Sriharsha] Fix typo
935cf56 [Ram Sriharsha] Try to match Java and Scala Example Commandline options
5323ff9 [Ram Sriharsha] cleanup
196a59a [Ram Sriharsha] cleanup
6adfa0c [Ram Sriharsha] Style Fix
8cfc5d5 [Ram Sriharsha] [SPARK-7575] Example code for OneVsRest
parent 2c04c8a1
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.ml;
import org.apache.commons.cli.*;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.util.MetadataUtils;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;
/**
* An example runner for Multiclass to Binary Reduction with One Vs Rest.
* The example uses Logistic Regression as the base classifier. All parameters that
* can be specified on the base classifier can be passed in to the runner options.
* Run with
* <pre>
* bin/run-example ml.JavaOneVsRestExample [options]
* </pre>
*/
public class JavaOneVsRestExample {
private static class Params {
String input;
String testInput = null;
Integer maxIter = 100;
double tol = 1E-6;
boolean fitIntercept = true;
Double regParam = null;
Double elasticNetParam = null;
double fracTest = 0.2;
}
public static void main(String[] args) {
// parse the arguments
Params params = parse(args);
SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);
// configure the base classifier
LogisticRegression classifier = new LogisticRegression()
.setMaxIter(params.maxIter)
.setTol(params.tol)
.setFitIntercept(params.fitIntercept);
if (params.regParam != null) {
classifier.setRegParam(params.regParam);
}
if (params.elasticNetParam != null) {
classifier.setElasticNetParam(params.elasticNetParam);
}
// instantiate the One Vs Rest Classifier
OneVsRest ovr = new OneVsRest().setClassifier(classifier);
String input = params.input;
RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
RDD<LabeledPoint> train;
RDD<LabeledPoint> test;
// compute the train/ test split: if testInput is not provided use part of input
String testInput = params.testInput;
if (testInput != null) {
train = inputData;
// compute the number of features in the training set.
int numFeatures = inputData.first().features().size();
test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
} else {
double f = params.fracTest;
RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
train = tmp[0];
test = tmp[1];
}
// train the multiclass model
DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
// score the model on test data
DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
DataFrame predictions = ovrModel.transform(testDataFrame.cache())
.select("prediction", "label");
// obtain metrics
MulticlassMetrics metrics = new MulticlassMetrics(predictions);
StructField predictionColSchema = predictions.schema().apply("prediction");
Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();
// compute the false positive rate per label
StringBuilder results = new StringBuilder();
results.append("label\tfpr\n");
for (int label = 0; label < numClasses; label++) {
results.append(label);
results.append("\t");
results.append(metrics.falsePositiveRate((double) label));
results.append("\n");
}
Matrix confusionMatrix = metrics.confusionMatrix();
// output the Confusion Matrix
System.out.println("Confusion Matrix");
System.out.println(confusionMatrix);
System.out.println();
System.out.println(results);
jsc.stop();
}
private static Params parse(String[] args) {
Options options = generateCommandlineOptions();
CommandLineParser parser = new PosixParser();
Params params = new Params();
try {
CommandLine cmd = parser.parse(options, args);
String value;
if (cmd.hasOption("input")) {
params.input = cmd.getOptionValue("input");
}
if (cmd.hasOption("maxIter")) {
value = cmd.getOptionValue("maxIter");
params.maxIter = Integer.parseInt(value);
}
if (cmd.hasOption("tol")) {
value = cmd.getOptionValue("tol");
params.tol = Double.parseDouble(value);
}
if (cmd.hasOption("fitIntercept")) {
value = cmd.getOptionValue("fitIntercept");
params.fitIntercept = Boolean.parseBoolean(value);
}
if (cmd.hasOption("regParam")) {
value = cmd.getOptionValue("regParam");
params.regParam = Double.parseDouble(value);
}
if (cmd.hasOption("elasticNetParam")) {
value = cmd.getOptionValue("elasticNetParam");
params.elasticNetParam = Double.parseDouble(value);
}
if (cmd.hasOption("testInput")) {
value = cmd.getOptionValue("testInput");
params.testInput = value;
}
if (cmd.hasOption("fracTest")) {
value = cmd.getOptionValue("fracTest");
params.fracTest = Double.parseDouble(value);
}
} catch (ParseException e) {
printHelpAndQuit(options);
}
return params;
}
private static Options generateCommandlineOptions() {
Option input = OptionBuilder.withArgName("input")
.hasArg()
.isRequired()
.withDescription("input path to labeled examples. This path must be specified")
.create("input");
Option testInput = OptionBuilder.withArgName("testInput")
.hasArg()
.withDescription("input path to test examples")
.create("testInput");
Option fracTest = OptionBuilder.withArgName("testInput")
.hasArg()
.withDescription("fraction of data to hold out for testing." +
" If given option testInput, this option is ignored. default: 0.2")
.create("fracTest");
Option maxIter = OptionBuilder.withArgName("maxIter")
.hasArg()
.withDescription("maximum number of iterations for Logistic Regression. default:100")
.create("maxIter");
Option tol = OptionBuilder.withArgName("tol")
.hasArg()
.withDescription("the convergence tolerance of iterations " +
"for Logistic Regression. default: 1E-6")
.create("tol");
Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
.hasArg()
.withDescription("fit intercept for logistic regression. default true")
.create("fitIntercept");
Option regParam = OptionBuilder.withArgName( "regParam" )
.hasArg()
.withDescription("the regularization parameter for Logistic Regression.")
.create("regParam");
Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" )
.hasArg()
.withDescription("the ElasticNet mixing parameter for Logistic Regression.")
.create("elasticNetParam");
Options options = new Options()
.addOption(input)
.addOption(testInput)
.addOption(fracTest)
.addOption(maxIter)
.addOption(tol)
.addOption(fitIntercept)
.addOption(regParam)
.addOption(elasticNetParam);
return options;
}
private static void printHelpAndQuit(Options options) {
HelpFormatter formatter = new HelpFormatter();
formatter.printHelp("JavaOneVsRestExample", options);
System.exit(-1);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.ml
import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}
import scopt.OptionParser
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
/**
* An example runner for Multiclass to Binary Reduction with One Vs Rest.
* The example uses Logistic Regression as the base classifier. All parameters that
* can be specified on the base classifier can be passed in to the runner options.
* Run with
* {{{
* ./bin/run-example ml.OneVsRestExample [options]
* }}}
* For local mode, run
* {{{
* ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g
* [examples JAR path] [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object OneVsRestExample {
case class Params private[ml] (
input: String = null,
testInput: Option[String] = None,
maxIter: Int = 100,
tol: Double = 1E-6,
fitIntercept: Boolean = true,
regParam: Option[Double] = None,
elasticNetParam: Option[Double] = None,
fracTest: Double = 0.2) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
val parser = new OptionParser[Params]("OneVsRest Example") {
head("OneVsRest Example: multiclass to binary reduction using OneVsRest")
opt[String]("input")
.text("input path to labeled examples. This path must be specified")
.required()
.action((x, c) => c.copy(input = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
opt[String]("testInput")
.text("input path to test dataset. If given, option fracTest is ignored")
.action((x,c) => c.copy(testInput = Some(x)))
opt[Int]("maxIter")
.text(s"maximum number of iterations for Logistic Regression." +
s" default: ${defaultParams.maxIter}")
.action((x, c) => c.copy(maxIter = x))
opt[Double]("tol")
.text(s"the convergence tolerance of iterations for Logistic Regression." +
s" default: ${defaultParams.tol}")
.action((x, c) => c.copy(tol = x))
opt[Boolean]("fitIntercept")
.text(s"fit intercept for Logistic Regression." +
s" default: ${defaultParams.fitIntercept}")
.action((x, c) => c.copy(fitIntercept = x))
opt[Double]("regParam")
.text(s"the regularization parameter for Logistic Regression.")
.action((x,c) => c.copy(regParam = Some(x)))
opt[Double]("elasticNetParam")
.text(s"the ElasticNet mixing parameter for Logistic Regression.")
.action((x,c) => c.copy(elasticNetParam = Some(x)))
checkConfig { params =>
if (params.fracTest < 0 || params.fracTest >= 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
}
}
parser.parse(args, defaultParams).map { params =>
run(params)
}.getOrElse {
sys.exit(1)
}
}
private def run(params: Params) {
val conf = new SparkConf().setAppName(s"OneVsRestExample with $params")
val sc = new SparkContext(conf)
val inputData = MLUtils.loadLibSVMFile(sc, params.input)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
// compute the train/test split: if testInput is not provided use part of input.
val data = params.testInput match {
case Some(t) => {
// compute the number of features in the training set.
val numFeatures = inputData.first().features.size
val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures)
Array[RDD[LabeledPoint]](inputData, testData)
}
case None => {
val f = params.fracTest
inputData.randomSplit(Array(1 - f, f), seed = 12345)
}
}
val Array(train, test) = data.map(_.toDF().cache())
// instantiate the base classifier
val classifier = new LogisticRegression()
.setMaxIter(params.maxIter)
.setTol(params.tol)
.setFitIntercept(params.fitIntercept)
// Set regParam, elasticNetParam if specified in params
params.regParam.foreach(classifier.setRegParam)
params.elasticNetParam.foreach(classifier.setElasticNetParam)
// instantiate the One Vs Rest Classifier.
val ovr = new OneVsRest()
ovr.setClassifier(classifier)
// train the multiclass model.
val (trainingDuration, ovrModel) = time(ovr.fit(train))
// score the model on test data.
val (predictionDuration, predictions) = time(ovrModel.transform(test))
// evaluate the model
val predictionsAndLabels = predictions.select("prediction", "label")
.map(row => (row.getDouble(0), row.getDouble(1)))
val metrics = new MulticlassMetrics(predictionsAndLabels)
val confusionMatrix = metrics.confusionMatrix
// compute the false positive rate per label
val predictionColSchema = predictions.schema("prediction")
val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get
val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble)))
println(s" Training Time ${trainingDuration} sec\n")
println(s" Prediction Time ${predictionDuration} sec\n")
println(s" Confusion Matrix\n ${confusionMatrix.toString}\n")
println("label\tfpr")
println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))
sc.stop()
}
private def time[R](block: => R): (Long, R) = {
val t0 = System.nanoTime()
val result = block // call-by-name
val t1 = System.nanoTime()
(NANO.toSeconds(t1 - t0), result)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment