Commit 72e66bb2 authored by Yong Tang, committed by Sean Owen

[SPARK-14301][EXAMPLES] Java examples code merge and clean up.

## What changes were proposed in this pull request?

This fix tries to remove duplicate Java code in examples/mllib and examples/ml. The following changes have been made:

```
deleted: ml/JavaCrossValidatorExample.java (duplicate of JavaModelSelectionViaCrossValidationExample.java)
deleted: ml/JavaTrainValidationSplitExample.java (duplicate of JavaModelSelectionViaTrainValidationSplitExample.java)
deleted: mllib/JavaFPGrowthExample.java (duplicate of JavaSimpleFPGrowth.java)
deleted: mllib/JavaLDAExample.java (duplicate of JavaLatentDirichletAllocationExample.java)
deleted: mllib/JavaKMeans.java (merged with JavaKMeansExample.java)
deleted: mllib/JavaLR.java (duplicate of JavaLinearRegressionWithSGDExample.java)
updated: mllib/JavaKMeansExample.java (merged with mllib/JavaKMeans.java)
```

## How was this patch tested?
Existing tests passed.

Author: Yong Tang <yong.tang.github@outlook.com>

Closes #12143 from yongtang/SPARK-14301.
parent 00288ea2
Showing with 16 additions and 534 deletions
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.ml;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
/**
* A simple example demonstrating model selection using CrossValidator.
* This example also demonstrates how Pipelines are Estimators.
*
* This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and
* {@link org.apache.spark.examples.ml.Document} defined in the Scala example
* {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}.
*
* Run with
* <pre>
* bin/run-example ml.JavaCrossValidatorExample
* </pre>
*/
public class JavaCrossValidatorExample {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext jsql = new SQLContext(jsc);

    // Prepare training documents, which are labeled.
    List<LabeledDocument> localTraining = Lists.newArrayList(
      new LabeledDocument(0L, "a b c d e spark", 1.0),
      new LabeledDocument(1L, "b d", 0.0),
      new LabeledDocument(2L, "spark f g h", 1.0),
      new LabeledDocument(3L, "hadoop mapreduce", 0.0),
      new LabeledDocument(4L, "b spark who", 1.0),
      new LabeledDocument(5L, "g d a y", 0.0),
      new LabeledDocument(6L, "spark fly", 1.0),
      new LabeledDocument(7L, "was mapreduce", 0.0),
      new LabeledDocument(8L, "e spark program", 1.0),
      new LabeledDocument(9L, "a e c l", 0.0),
      new LabeledDocument(10L, "spark compile", 1.0),
      new LabeledDocument(11L, "hadoop software", 0.0));
    Dataset<Row> training = jsql.createDataFrame(
      jsc.parallelize(localTraining), LabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    CrossValidator crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator());
    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    ParamMap[] paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000})
      .addGrid(lr.regParam(), new double[]{0.1, 0.01})
      .build();
    crossval.setEstimatorParamMaps(paramGrid);
    crossval.setNumFolds(2); // Use 3+ in practice

    // Run cross-validation, and choose the best set of parameters.
    CrossValidatorModel cvModel = crossval.fit(training);

    // Prepare test documents, which are unlabeled.
    List<Document> localTest = Lists.newArrayList(
      new Document(4L, "spark i j k"),
      new Document(5L, "l m n"),
      new Document(6L, "mapreduce spark"),
      new Document(7L, "apache hadoop"));
    Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);

    // Make predictions on test documents. cvModel uses the best model found (lrModel).
    Dataset<Row> predictions = cvModel.transform(test);
    for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
          + ", prediction=" + r.get(3));
    }

    jsc.stop();
  }
}
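For reference, a hedged sketch of how the model chosen by the CrossValidator above could be inspected. The `CrossValidationInspection` class and `printBestModel` method are hypothetical helpers, not part of the Spark examples; `bestModel()`, `stages()`, and `avgMetrics()` are standard Spark ML accessors. The cast assumes the three-stage Pipeline used above, so the last stage of the best PipelineModel is the fitted LogisticRegressionModel.

```java
import java.util.Arrays;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;

// Hypothetical helper: report which model cross-validation selected.
class CrossValidationInspection {
  static void printBestModel(CrossValidatorModel cvModel) {
    // The best model is the fitted Pipeline; its last stage is the LogisticRegressionModel.
    PipelineModel best = (PipelineModel) cvModel.bestModel();
    LogisticRegressionModel lrModel =
      (LogisticRegressionModel) best.stages()[best.stages().length - 1];
    System.out.println("Best regParam: " + lrModel.getRegParam());
    System.out.println("Coefficients: " + lrModel.coefficients());
    // One averaged metric per ParamMap in the grid, in the same order as the grid.
    System.out.println("Avg metrics: " + Arrays.toString(cvModel.avgMetrics()));
  }
}
```

A call such as `CrossValidationInspection.printBestModel(cvModel)` would slot in right after `crossval.fit(training)` in the main method above.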
@@ -32,7 +32,15 @@ import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;

 /**
- * Java example for Model Selection via Train Validation Split.
+ * Java example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
+ * using linear regression.
+ *
+ * Run with
+ * {{{
+ *   bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample
+ * }}}
  */
 public class JavaModelSelectionViaTrainValidationSplitExample {
   public static void main(String[] args) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
/**
* A simple example demonstrating model selection using TrainValidationSplit.
*
* The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
* using linear regression.
*
* Run with
* {{{
* bin/run-example ml.JavaTrainValidationSplitExample
* }}}
*/
public class JavaTrainValidationSplitExample {

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext jsql = new SQLContext(jsc);

    Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

    // Prepare training and test data.
    Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];

    LinearRegression lr = new LinearRegression();

    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // TrainValidationSplit will try all combinations of values and determine best model using
    // the evaluator.
    ParamMap[] paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam(), new double[] {0.1, 0.01})
      .addGrid(lr.fitIntercept())
      .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
      .build();

    // In this case the estimator is simply the linear regression.
    // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new RegressionEvaluator())
      .setEstimatorParamMaps(paramGrid);

    // 80% of the data will be used for training and the remaining 20% for validation.
    trainValidationSplit.setTrainRatio(0.8);

    // Run train validation split, and choose the best set of parameters.
    TrainValidationSplitModel model = trainValidationSplit.fit(training);

    // Make predictions on test data. model is the model with combination of parameters
    // that performed best.
    model.transform(test)
      .select("features", "label", "prediction")
      .show();

    jsc.stop();
  }
}
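TrainValidationSplitModel likewise exposes the winning model. A brief hedged sketch follows; `TrainValidationSplitInspection` and `printBestModel` are hypothetical names, and the cast to LinearRegressionModel assumes the plain LinearRegression estimator used above.

```java
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;

// Hypothetical helper: report the parameters of the model chosen by the
// train/validation split.
class TrainValidationSplitInspection {
  static void printBestModel(TrainValidationSplitModel model) {
    LinearRegressionModel lrModel = (LinearRegressionModel) model.bestModel();
    System.out.println("Best regParam: " + lrModel.getRegParam());
    System.out.println("Best elasticNetParam: " + lrModel.getElasticNetParam());
    System.out.println("Coefficients: " + lrModel.coefficients()
      + ", intercept: " + lrModel.intercept());
  }
}
```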
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib;
import java.util.ArrayList;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;
/**
* Java example for mining frequent itemsets using FP-growth.
* Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt
*/
public class JavaFPGrowthExample {

  public static void main(String[] args) {
    String inputFile;
    double minSupport = 0.3;
    int numPartition = -1;
    if (args.length < 1) {
      System.err.println(
        "Usage: JavaFPGrowth <input_file> [minSupport] [numPartition]");
      System.exit(1);
    }
    inputFile = args[0];
    if (args.length >= 2) {
      minSupport = Double.parseDouble(args[1]);
    }
    if (args.length >= 3) {
      numPartition = Integer.parseInt(args[2]);
    }

    SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    JavaRDD<ArrayList<String>> transactions = sc.textFile(inputFile).map(
      new Function<String, ArrayList<String>>() {
        @Override
        public ArrayList<String> call(String s) {
          return Lists.newArrayList(s.split(" "));
        }
      }
    );

    FPGrowthModel<String> model = new FPGrowth()
      .setMinSupport(minSupport)
      .setNumPartitions(numPartition)
      .run(transactions);

    for (FPGrowth.FreqItemset<String> s: model.freqItemsets().toJavaRDD().collect()) {
      System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq());
    }

    sc.stop();
  }
}
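The deleted FP-growth example stops at frequent itemsets. If association rules are also wanted, the same FPGrowthModel can generate them via `generateAssociationRules`. A minimal hedged sketch; `FPGrowthRules` and `printRules` are hypothetical names, and the 0.8 confidence threshold is an arbitrary illustrative value.

```java
import org.apache.spark.mllib.fpm.AssociationRules;
import org.apache.spark.mllib.fpm.FPGrowthModel;

// Sketch only: derive association rules from a fitted FP-growth model.
class FPGrowthRules {
  static void printRules(FPGrowthModel<String> model) {
    double minConfidence = 0.8; // illustrative threshold
    for (AssociationRules.Rule<String> rule
        : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) {
      System.out.println(
        rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence());
    }
  }
}
```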
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib;
import java.util.regex.Pattern;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
/**
* Example using MLlib KMeans from Java.
*/
public final class JavaKMeans {

  private static class ParsePoint implements Function<String, Vector> {
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public Vector call(String line) {
      String[] tok = SPACE.split(line);
      double[] point = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        point[i] = Double.parseDouble(tok[i]);
      }
      return Vectors.dense(point);
    }
  }

  public static void main(String[] args) {
    if (args.length < 3) {
      System.err.println(
        "Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
      System.exit(1);
    }
    String inputFile = args[0];
    int k = Integer.parseInt(args[1]);
    int iterations = Integer.parseInt(args[2]);
    int runs = 1;

    if (args.length >= 4) {
      runs = Integer.parseInt(args[3]);
    }

    SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(inputFile);

    JavaRDD<Vector> points = lines.map(new ParsePoint());

    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());

    System.out.println("Cluster centers:");
    for (Vector center : model.clusterCenters()) {
      System.out.println(" " + center);
    }
    double cost = model.computeCost(points.rdd());
    System.out.println("Cost: " + cost);

    sc.stop();
  }
}
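Beyond printing centers and cost, a fitted KMeansModel can be persisted and reloaded with its save/load API. A brief hedged sketch; `KMeansPersistence` and `saveAndReload` are hypothetical names, and the target path is a placeholder.

```java
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeansModel;

// Sketch only: save the fitted model and load it back.
class KMeansPersistence {
  static KMeansModel saveAndReload(JavaSparkContext sc, KMeansModel model) {
    // Placeholder path; any writable location works.
    String path = "target/JavaKMeansExample/KMeansModel";
    model.save(sc.sc(), path);
    return KMeansModel.load(sc.sc(), path);
  }
}
```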
@@ -58,6 +58,13 @@ public class JavaKMeansExample {
     int numIterations = 20;
     KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

+    System.out.println("Cluster centers:");
+    for (Vector center: clusters.clusterCenters()) {
+      System.out.println(" " + center);
+    }
+    double cost = clusters.computeCost(parsedData.rdd());
+    System.out.println("Cost: " + cost);
+
     // Evaluate clustering by computing Within Set Sum of Squared Errors
     double WSSSE = clusters.computeCost(parsedData.rdd());
     System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib;
import scala.Tuple2;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.SparkConf;
public class JavaLDAExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("LDA Example");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // Load and parse the data
    String path = "data/mllib/sample_lda_data.txt";
    JavaRDD<String> data = sc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
      new Function<String, Vector>() {
        public Vector call(String s) {
          String[] sarray = s.trim().split(" ");
          double[] values = new double[sarray.length];
          for (int i = 0; i < sarray.length; i++) {
            values[i] = Double.parseDouble(sarray[i]);
          }
          return Vectors.dense(values);
        }
      }
    );
    // Index documents with unique IDs
    JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
      new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
        public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
          return doc_id.swap();
        }
      }
    ));
    corpus.cache();

    // Cluster the documents into three topics using LDA
    DistributedLDAModel ldaModel = (DistributedLDAModel) new LDA().setK(3).run(corpus);

    // Output topics. Each is a distribution over words (matching word count vectors)
    System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
      + " words):");
    Matrix topics = ldaModel.topicsMatrix();
    for (int topic = 0; topic < 3; topic++) {
      System.out.print("Topic " + topic + ":");
      for (int word = 0; word < ldaModel.vocabSize(); word++) {
        System.out.print(" " + topics.apply(word, topic));
      }
      System.out.println();
    }

    sc.stop();
  }
}
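`topicsMatrix()` above prints the weight of every vocabulary word for each topic. The model can also report only the top-weighted terms per topic via `describeTopics`. A short hedged sketch; `LDATopTerms` and `printTopTerms` are hypothetical names, and the choice of 5 terms is illustrative.

```java
import scala.Tuple2;

import org.apache.spark.mllib.clustering.DistributedLDAModel;

// Sketch only: print the five highest-weighted vocabulary indices per topic.
// describeTopics returns, per topic, term indices and weights sorted by weight.
class LDATopTerms {
  static void printTopTerms(DistributedLDAModel ldaModel) {
    Tuple2<int[], double[]>[] topics = ldaModel.describeTopics(5);
    for (int t = 0; t < topics.length; t++) {
      int[] terms = topics[t]._1();
      double[] weights = topics[t]._2();
      System.out.print("Topic " + t + ":");
      for (int i = 0; i < terms.length; i++) {
        System.out.print(" " + terms[i] + "=" + weights[i]);
      }
      System.out.println();
    }
  }
}
```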
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib;
import java.util.regex.Pattern;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
/**
* Logistic regression based classification using ML Lib.
*/
public final class JavaLR {

  static class ParsePoint implements Function<String, LabeledPoint> {
    private static final Pattern COMMA = Pattern.compile(",");
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public LabeledPoint call(String line) {
      String[] parts = COMMA.split(line);
      double y = Double.parseDouble(parts[0]);
      String[] tok = SPACE.split(parts[1]);
      double[] x = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        x[i] = Double.parseDouble(tok[i]);
      }
      return new LabeledPoint(y, Vectors.dense(x));
    }
  }

  public static void main(String[] args) {
    if (args.length != 3) {
      System.err.println("Usage: JavaLR <input_dir> <step_size> <niters>");
      System.exit(1);
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaLR");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(args[0]);
    JavaRDD<LabeledPoint> points = lines.map(new ParsePoint()).cache();
    double stepSize = Double.parseDouble(args[1]);
    int iterations = Integer.parseInt(args[2]);

    // Another way to configure LogisticRegression
    //
    // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();
    // lr.optimizer().setNumIterations(iterations)
    //   .setStepSize(stepSize)
    //   .setMiniBatchFraction(1.0);
    // lr.setIntercept(true);
    // LogisticRegressionModel model = lr.train(points.rdd());

    LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
        iterations, stepSize);

    System.out.print("Final w: " + model.weights());

    sc.stop();
  }
}
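The deleted JavaLR only prints the learned weights. A hedged sketch of a simple training-set evaluation using the mllib model's `predict` API follows; `LogisticRegressionEval` and `trainingAccuracy` are hypothetical names, and the anonymous Function classes mirror the style of the example above.

```java
import scala.Tuple2;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;

// Sketch only: fraction of training points the model classifies correctly.
class LogisticRegressionEval {
  static double trainingAccuracy(final LogisticRegressionModel model,
                                 JavaRDD<LabeledPoint> points) {
    JavaRDD<Tuple2<Double, Double>> predictionAndLabel = points.map(
      new Function<LabeledPoint, Tuple2<Double, Double>>() {
        @Override
        public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<>(model.predict(p.features()), p.label());
        }
      });
    long correct = predictionAndLabel.filter(
      new Function<Tuple2<Double, Double>, Boolean>() {
        @Override
        public Boolean call(Tuple2<Double, Double> t) {
          return t._1().equals(t._2());
        }
      }).count();
    return (double) correct / points.count();
  }
}
```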