Commit 460da878 authored by Josh Rosen

Improve Java API examples

- Replace JavaLR example with JavaHdfsLR example.
- Use anonymous classes in JavaWordCount; add options.
- Remove @Override annotations.
parent 01dce3f5
--- JavaLR.java
+++ JavaHdfsLR.java
 package spark.examples;
 
 import scala.util.Random;
 import spark.api.java.JavaRDD;
 import spark.api.java.JavaSparkContext;
 import spark.api.java.function.Function;
 import spark.api.java.function.Function2;
 
 import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
+import java.util.StringTokenizer;
 
-public class JavaLR {
+public class JavaHdfsLR {
 
-  static int N = 10000;   // Number of data points
   static int D = 10;      // Number of dimensions
-  static double R = 0.7;  // Scaling factor
-  static int ITERATIONS = 5;
   static Random rand = new Random(42);
 
   static class DataPoint implements Serializable {
-    public DataPoint(double[] x, int y) {
+    public DataPoint(double[] x, double y) {
       this.x = x;
       this.y = y;
     }
     double[] x;
-    int y;
+    double y;
   }
 
-  static DataPoint generatePoint(int i) {
-    int y = (i % 2 == 0) ? -1 : 1;
-    double[] x = new double[D];
-    for (int j = 0; j < D; j++) {
-      x[j] = rand.nextGaussian() + y * R;
-    }
-    return new DataPoint(x, y);
-  }
-
-  static List<DataPoint> generateData() {
-    List<DataPoint> points = new ArrayList<DataPoint>(N);
-    for (int i = 0; i < N; i++) {
-      points.add(generatePoint(i));
-    }
-    return points;
-  }
+  static class ParsePoint extends Function<String, DataPoint> {
+    public DataPoint apply(String line) {
+      StringTokenizer tok = new StringTokenizer(line, " ");
+      double y = Double.parseDouble(tok.nextToken());
+      double[] x = new double[D];
+      int i = 0;
+      while (i < D) {
+        x[i] = Double.parseDouble(tok.nextToken());
+        i += 1;
+      }
+      return new DataPoint(x, y);
+    }
+  }
 
   static class VectorSum extends Function2<double[], double[], double[]> {
-    @Override
     public double[] apply(double[] a, double[] b) {
       double[] result = new double[D];
       for (int j = 0; j < D; j++) {
@@ -64,7 +59,6 @@ public class JavaLR {
     this.weights = weights;
   }
 
-  @Override
   public double[] apply(DataPoint p) {
     double[] gradient = new double[D];
     for (int i = 0; i < D; i++) {
@@ -89,39 +83,40 @@ public class JavaLR {
 
   public static void main(String[] args) {
-    if (args.length == 0) {
-      System.err.println("Usage: JavaLR <host> [<slices>]");
-      System.exit(1);
-    }
-    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR");
-    Integer numSlices = (args.length > 1) ? Integer.parseInt(args[1]) : 2;
-    List<DataPoint> data = generateData();
-
-    // Initialize w to a random value
-    double[] w = new double[D];
-    for (int i = 0; i < D; i++) {
-      w[i] = 2 * rand.nextDouble() - 1;
-    }
-
-    System.out.print("Initial w: ");
-    printWeights(w);
-
-    for (int i = 1; i <= ITERATIONS; i++) {
-      System.out.println("On iteration " + i);
-
-      double[] gradient = sc.parallelize(data, numSlices).map(
-        new ComputeGradient(w)
-      ).reduce(new VectorSum());
-
-      for (int j = 0; j < D; j++) {
-        w[j] -= gradient[j];
-      }
-    }
-
-    System.out.print("Final w: ");
-    printWeights(w);
-    System.exit(0);
+    if (args.length < 3) {
+      System.err.println("Usage: JavaHdfsLR <master> <file> <iters>");
+      System.exit(1);
+    }
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR");
+    JavaRDD<String> lines = sc.textFile(args[1]);
+    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
+    int ITERATIONS = Integer.parseInt(args[2]);
+
+    // Initialize w to a random value
+    double[] w = new double[D];
+    for (int i = 0; i < D; i++) {
+      w[i] = 2 * rand.nextDouble() - 1;
+    }
+
+    System.out.print("Initial w: ");
+    printWeights(w);
+
+    for (int i = 1; i <= ITERATIONS; i++) {
+      System.out.println("On iteration " + i);
+
+      double[] gradient = points.map(
+        new ComputeGradient(w)
+      ).reduce(new VectorSum());
+
+      for (int j = 0; j < D; j++) {
+        w[j] -= gradient[j];
+      }
+    }
+
+    System.out.print("Final w: ");
+    printWeights(w);
+    System.exit(0);
   }
 }
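
The reworked example reads its training data from a file rather than generating it in memory, and ParsePoint expects each line to be the label followed by D space-separated features ("y x1 x2 ... xD"). As a minimal sketch, not part of the commit (the class name GenerateLRData and the output path lr_data.txt are illustrative), such a file can be produced with the same distribution the removed generatePoint used:

import java.io.PrintWriter;
import java.util.Random;

// Hypothetical helper, not in the commit: writes N points in the
// "y x1 x2 ... xD" format that ParsePoint tokenizes.
public class GenerateLRData {
  public static void main(String[] args) throws Exception {
    int N = 10000, D = 10;
    double R = 0.7;
    Random rand = new Random(42);
    PrintWriter out = new PrintWriter("lr_data.txt");
    for (int i = 0; i < N; i++) {
      int y = (i % 2 == 0) ? -1 : 1;  // same labeling rule as the old generatePoint
      StringBuilder line = new StringBuilder(String.valueOf(y));
      for (int j = 0; j < D; j++) {
        line.append(' ').append(rand.nextGaussian() + y * R);
      }
      out.println(line);
    }
    out.close();
  }
}

Running JavaHdfsLR over a file in this format reproduces the old JavaLR behavior, except that the input can now live on HDFS instead of being regenerated by the driver on every run.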
--- JavaTC.java
+++ JavaTC.java
@@ -31,7 +31,7 @@ public class JavaTC {
   static class ProjectFn extends PairFunction<Tuple2<Integer, Tuple2<Integer, Integer>>,
       Integer, Integer> {
     static ProjectFn INSTANCE = new ProjectFn();
-    @Override
     public Tuple2<Integer, Integer> apply(Tuple2<Integer, Tuple2<Integer, Integer>> triple) {
       return new Tuple2<Integer, Integer>(triple._2()._2(), triple._2()._1());
     }
...
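
For context, ProjectFn is the projection step of JavaTC's transitive-closure loop, which this diff elides. A hedged sketch of how such a function is typically applied, patterned on Spark's Scala SparkTC example (the method name, variable names, and surrounding loop are assumptions, not the commit's contents):

// Sketch only. Assumes tc holds the closure-so-far as (from, to) pairs and
// edges holds the input edges keyed by destination, i.e. (to, from).
static JavaPairRDD<Integer, Integer> transitiveClosure(
    JavaPairRDD<Integer, Integer> tc,
    JavaPairRDD<Integer, Integer> edges) {
  long oldCount;
  do {
    oldCount = tc.count();
    // The join yields (k, (b, a)) for an existing path k -> b and an edge a -> k;
    // ProjectFn rewrites each such triple as the extended path (a, b).
    tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache();
  } while (tc.count() != oldCount);
  return tc;
}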
--- JavaTest.java
+++ /dev/null
-package spark.examples;
-
-import spark.api.java.JavaDoubleRDD;
-import spark.api.java.JavaRDD;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.DoubleFunction;
-
-import java.util.List;
-
-public class JavaTest {
-  public static class MapFunction extends DoubleFunction<String> {
-    @Override
-    public Double apply(String s) {
-      return java.lang.Double.parseDouble(s);
-    }
-  }
-
-  public static void main(String[] args) throws Exception {
-    JavaSparkContext ctx = new JavaSparkContext("local", "JavaTest");
-    JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
-    List<String> lineArr = lines.collect();
-    for (String line : lineArr) {
-      System.out.println(line);
-    }
-    JavaDoubleRDD data = lines.map(new MapFunction()).cache();
-    System.out.println("output");
-    List<Double> output = data.collect();
-    for (Double num : output) {
-      System.out.println(num);
-    }
-    System.exit(0);
-  }
-}
--- JavaWordCount.java
+++ JavaWordCount.java
@@ -14,43 +14,31 @@ import java.util.List;
 public class JavaWordCount {
 
-  public static class SplitFunction extends FlatMapFunction<String, String> {
-    @Override
-    public Iterable<String> apply(String s) {
-      StringOps op = new StringOps(s);
-      return Arrays.asList(op.split(' '));
-    }
-  }
-
-  public static class MapFunction extends PairFunction<String, String, Integer> {
-    @Override
-    public Tuple2<String, Integer> apply(String s) {
-      return new Tuple2(s, 1);
-    }
-  }
-
-  public static class ReduceFunction extends Function2<Integer, Integer, Integer> {
-    @Override
-    public Integer apply(Integer i1, Integer i2) {
-      return i1 + i2;
-    }
-  }
-
   public static void main(String[] args) throws Exception {
-    JavaSparkContext ctx = new JavaSparkContext("local", "JavaWordCount");
-    JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
-    List<String> lineArr = lines.collect();
-    for (String line : lineArr) {
-      System.out.println(line);
+    if (args.length < 2) {
+      System.err.println("Usage: JavaWordCount <master> <file>");
+      System.exit(1);
     }
-    JavaRDD<String> words = lines.flatMap(new SplitFunction());
-    JavaPairRDD<String, Integer> splits = words.map(new MapFunction());
-    JavaPairRDD<String, Integer> counts = splits.reduceByKey(new ReduceFunction());
+    JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount");
+    JavaRDD<String> lines = ctx.textFile(args[1], 1);
+    JavaPairRDD<String, Integer> counts = lines.flatMap(new FlatMapFunction<String, String>() {
+      public Iterable<String> apply(String s) {
+        StringOps op = new StringOps(s);
+        return Arrays.asList(op.split(' '));
+      }
+    }).map(new PairFunction<String, String, Integer>() {
+      public Tuple2<String, Integer> apply(String s) {
+        return new Tuple2(s, 1);
+      }
+    }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+      public Integer apply(Integer i1, Integer i2) {
+        return i1 + i2;
+      }
+    });
     System.out.println("output");
     List<Tuple2<String, Integer>> output = counts.collect();
     for (Tuple2 tuple : output) {
       System.out.print(tuple._1 + ": ");
...
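
Since the reworked JavaWordCount now requires a master URL and an input file, a self-contained variant is convenient for trying the anonymous-class style without either. A minimal sketch under those assumptions (the class name JavaWordCountLocal is hypothetical and plain String.split stands in for StringOps; the API calls mirror the ones shown in the diffs above):

package spark.examples;

import scala.Tuple2;
import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.FlatMapFunction;
import spark.api.java.function.Function2;
import spark.api.java.function.PairFunction;

import java.util.Arrays;
import java.util.List;

// Hypothetical example, not part of the commit.
public class JavaWordCountLocal {
  public static void main(String[] args) throws Exception {
    JavaSparkContext ctx = new JavaSparkContext("local", "JavaWordCountLocal");
    // Parallelize a small in-memory list instead of reading a file from HDFS
    JavaRDD<String> lines = ctx.parallelize(Arrays.asList("a b", "b c", "c"), 2);
    JavaPairRDD<String, Integer> counts = lines.flatMap(new FlatMapFunction<String, String>() {
      public Iterable<String> apply(String s) {
        return Arrays.asList(s.split(" "));
      }
    }).map(new PairFunction<String, String, Integer>() {
      public Tuple2<String, Integer> apply(String s) {
        return new Tuple2<String, Integer>(s, 1);
      }
    }).reduceByKey(new Function2<Integer, Integer, Integer>() {
      public Integer apply(Integer i1, Integer i2) {
        return i1 + i2;
      }
    });
    List<Tuple2<String, Integer>> output = counts.collect();
    for (Tuple2<String, Integer> tuple : output) {
      System.out.println(tuple._1() + ": " + tuple._2());
    }
    System.exit(0);
  }
}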