Commit cdce4e62 authored by Dongjoon Hyun, committed by Andrew Or

[SPARK-15031][EXAMPLE] Use SparkSession in Scala/Python/Java example.

## What changes were proposed in this pull request?

This PR aims to update the Scala/Python/Java examples by replacing `SQLContext` with the newly added `SparkSession`.

- Use the **SparkSession Builder Pattern** in 154 files (Scala 55, Java 52, Python 47); see the first sketch after this list.
- Add `getConf` to the Python `SparkContext` class: `python/pyspark/context.py`
- Replace the **SQLContext Singleton Pattern** with the **SparkSession Singleton Pattern** (see the second sketch below) in:
  - `SqlNetworkWordCount.scala`
  - `JavaSqlNetworkWordCount.java`
  - `sql_network_wordcount.py`
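
For reference, the boilerplate change applied across the Java examples in the diff below has the following before/after shape. This is a minimal sketch; the class and app name are placeholders, not one of the files touched by this PR.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class BuilderPatternSketch {
  public static void main(String[] args) {
    // Before (removed throughout this PR):
    //   SparkConf conf = new SparkConf().setAppName("ExampleApp");
    //   JavaSparkContext jsc = new JavaSparkContext(conf);
    //   SQLContext sqlContext = new SQLContext(jsc);
    // After: a single entry point created with the builder pattern.
    SparkSession spark = SparkSession.builder().appName("ExampleApp").getOrCreate();

    // Reading data and creating DataFrames now goes through the session.
    Dataset<Row> data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
    data.show();

    spark.stop();
  }
}
```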

Now, `SQLContext` is used only in the R examples and in the following two Python examples. The Python examples are left untouched in this PR because they already fail with an unknown issue:
- `simple_params_example.py`
- `aft_survival_regression.py`
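
For the streaming examples listed above, the lazily instantiated `SQLContext` singleton helper becomes a `SparkSession` singleton along the following lines. This is a sketch of the pattern only; the helper class name is illustrative rather than the exact committed code.

```java
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;

/** Lazily instantiated singleton SparkSession (sketch; the helper name is illustrative). */
class SparkSessionSingleton {
  private static transient SparkSession instance = null;

  static SparkSession getInstance(SparkConf sparkConf) {
    if (instance == null) {
      // Build the session from the streaming job's SparkConf so both share the same settings.
      instance = SparkSession.builder().config(sparkConf).getOrCreate();
    }
    return instance;
  }
}
```

Inside each batch (e.g. in `foreachRDD`) an example would then call something like `SparkSessionSingleton.getInstance(rdd.context().getConf())` instead of constructing a new `SQLContext`, which is presumably also why `getConf` is added to the Python `SparkContext` in this PR.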

## How was this patch tested?

Manual.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12809 from dongjoon-hyun/SPARK-15031.
parent cf2e9da6
Showing changed files with 117 additions and 179 deletions
@@ -21,23 +21,19 @@ package org.apache.spark.examples.ml;
 import java.util.Arrays;
 import java.util.List;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.regression.AFTSurvivalRegression;
 import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
 import org.apache.spark.mllib.linalg.*;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.*;
 // $example off$
 public class JavaAFTSurvivalRegressionExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaAFTSurvivalRegressionExample").getOrCreate();
     // $example on$
     List<Row> data = Arrays.asList(
@@ -52,7 +48,7 @@ public class JavaAFTSurvivalRegressionExample {
       new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
-    Dataset<Row> training = jsql.createDataFrame(data, schema);
+    Dataset<Row> training = spark.createDataFrame(data, schema);
     double[] quantileProbabilities = new double[]{0.3, 0.6};
     AFTSurvivalRegression aft = new AFTSurvivalRegression()
       .setQuantileProbabilities(quantileProbabilities)
@@ -66,6 +62,6 @@ public class JavaAFTSurvivalRegressionExample {
     model.transform(training).show(false);
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,11 +17,9 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.io.Serializable;
@@ -83,18 +81,17 @@ public class JavaALSExample {
   // $example off$
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaALSExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaALSExample").getOrCreate();
     // $example on$
-    JavaRDD<Rating> ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+    JavaRDD<Rating> ratingsRDD = spark
+      .read().text("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
       .map(new Function<String, Rating>() {
        public Rating call(String str) {
          return Rating.parseRating(str);
        }
      });
-    Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
+    Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
     Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
     Dataset<Row> training = splits[0];
     Dataset<Row> test = splits[1];
@@ -121,6 +118,6 @@ public class JavaALSExample {
     Double rmse = evaluator.evaluate(predictions);
     System.out.println("Root-mean-square error = " + rmse);
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -20,10 +20,11 @@ package org.apache.spark.examples.ml;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.Binarizer;
@@ -37,21 +38,19 @@ import org.apache.spark.sql.types.StructType;
 public class JavaBinarizerExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaBinarizerExample").getOrCreate();
     // $example on$
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0, 0.1),
       RowFactory.create(1, 0.8),
       RowFactory.create(2, 0.2)
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
     });
-    Dataset<Row> continuousDataFrame = jsql.createDataFrame(jrdd, schema);
+    Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);
     Binarizer binarizer = new Binarizer()
       .setInputCol("feature")
       .setOutputCol("binarized_feature")
@@ -63,6 +62,6 @@ public class JavaBinarizerExample {
       System.out.println(binarized_value);
     }
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -18,12 +18,10 @@
 package org.apache.spark.examples.ml;
 import java.util.Arrays;
+import java.util.List;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import org.apache.spark.ml.clustering.BisectingKMeans;
 import org.apache.spark.ml.clustering.BisectingKMeansModel;
@@ -44,25 +42,23 @@ import org.apache.spark.sql.types.StructType;
 public class JavaBisectingKMeansExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaBisectingKMeansExample").getOrCreate();
     // $example on$
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)),
       RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)),
       RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)),
       RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)),
       RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)),
       RowFactory.create(Vectors.dense(18.9, 20.0, 19.7))
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("features", new VectorUDT(), false, Metadata.empty()),
     });
-    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
     BisectingKMeans bkm = new BisectingKMeans().setK(2);
     BisectingKMeansModel model = bkm.fit(dataset);
@@ -76,6 +72,6 @@ public class JavaBisectingKMeansExample {
     }
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,14 +17,12 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
-import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.Bucketizer;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -37,23 +35,21 @@ import org.apache.spark.sql.types.StructType;
 public class JavaBucketizerExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaBucketizerExample").getOrCreate();
     // $example on$
     double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(-0.5),
       RowFactory.create(-0.3),
       RowFactory.create(0.0),
       RowFactory.create(0.2)
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
     });
-    Dataset<Row> dataFrame = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
     Bucketizer bucketizer = new Bucketizer()
       .setInputCol("features")
@@ -64,7 +60,7 @@ public class JavaBucketizerExample {
     Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
     bucketedData.show();
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -21,10 +21,11 @@ import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
 import org.apache.spark.ml.feature.ChiSqSelector;
 import org.apache.spark.mllib.linalg.VectorUDT;
@@ -39,23 +40,21 @@ import org.apache.spark.sql.types.StructType;
 public class JavaChiSqSelectorExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaChiSqSelectorExample").getOrCreate();
     // $example on$
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
       RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
       RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
       new StructField("features", new VectorUDT(), false, Metadata.empty()),
       new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty())
     });
-    Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
     ChiSqSelector selector = new ChiSqSelector()
       .setNumTopFeatures(1)
@@ -66,6 +65,6 @@ public class JavaChiSqSelectorExample {
     Dataset<Row> result = selector.fit(df).transform(df);
     result.show();
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -19,36 +19,31 @@ package org.apache.spark.examples.ml;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.feature.CountVectorizer;
 import org.apache.spark.ml.feature.CountVectorizerModel;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.*;
 // $example off$
 public class JavaCountVectorizerExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaCountVectorizerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaCountVectorizerExample").getOrCreate();
     // $example on$
     // Input data: Each row is a bag of words from a sentence or document.
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Arrays.asList("a", "b", "c")),
       RowFactory.create(Arrays.asList("a", "b", "b", "c", "a"))
-    ));
+    );
     StructType schema = new StructType(new StructField [] {
       new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
     });
-    Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
     // fit a CountVectorizerModel from the corpus
     CountVectorizerModel cvModel = new CountVectorizer()
@@ -66,6 +61,6 @@ public class JavaCountVectorizerExample {
     cvModel.transform(df).show();
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -20,10 +20,11 @@ package org.apache.spark.examples.ml;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.DCT;
@@ -38,20 +39,18 @@ import org.apache.spark.sql.types.StructType;
 public class JavaDCTExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaDCTExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaDCTExample").getOrCreate();
     // $example on$
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)),
      RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)),
      RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0))
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("features", new VectorUDT(), false, Metadata.empty()),
     });
-    Dataset<Row> df = jsql.createDataFrame(data, schema);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
     DCT dct = new DCT()
       .setInputCol("features")
       .setOutputCol("featuresDCT")
@@ -59,7 +58,7 @@ public class JavaDCTExample {
     Dataset<Row> dctDf = dct.transform(df);
     dctDf.select("featuresDCT").show(3);
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,8 +17,6 @@
 // scalastyle:off println
 package org.apache.spark.examples.ml;
 // $example on$
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
 import org.apache.spark.ml.PipelineStage;
@@ -28,18 +26,17 @@ import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
 import org.apache.spark.ml.feature.*;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example off$
 public class JavaDecisionTreeClassificationExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaDecisionTreeClassificationExample").getOrCreate();
     // $example on$
     // Load the data stored in LIBSVM format as a DataFrame.
-    Dataset<Row> data = sqlContext
+    Dataset<Row> data = spark
       .read()
       .format("libsvm")
       .load("data/mllib/sample_libsvm_data.txt");
@@ -100,6 +97,6 @@ public class JavaDecisionTreeClassificationExample {
     System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,8 +17,6 @@
 // scalastyle:off println
 package org.apache.spark.examples.ml;
 // $example on$
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
 import org.apache.spark.ml.PipelineStage;
@@ -29,17 +27,16 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
 import org.apache.spark.ml.regression.DecisionTreeRegressor;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example off$
 public class JavaDecisionTreeRegressionExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaDecisionTreeRegressionExample").getOrCreate();
     // $example on$
     // Load the data stored in LIBSVM format as a DataFrame.
-    Dataset<Row> data = sqlContext.read().format("libsvm")
+    Dataset<Row> data = spark.read().format("libsvm")
       .load("data/mllib/sample_libsvm_data.txt");
     // Automatically identify categorical features, and index them.
@@ -85,6 +82,6 @@ public class JavaDecisionTreeRegressionExample {
     System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -21,9 +21,7 @@ import java.util.List;
 import com.google.common.collect.Lists;
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.Classifier;
 import org.apache.spark.ml.classification.ClassificationModel;
 import org.apache.spark.ml.param.IntParam;
@@ -35,7 +33,7 @@ import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 /**
@@ -51,9 +49,7 @@ import org.apache.spark.sql.SQLContext;
 public class JavaDeveloperApiExample {
   public static void main(String[] args) throws Exception {
-    SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaDeveloperApiExample").getOrCreate();
     // Prepare training data.
     List<LabeledPoint> localTraining = Lists.newArrayList(
@@ -61,8 +57,7 @@ public class JavaDeveloperApiExample {
       new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
       new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
       new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
-    Dataset<Row> training = jsql.createDataFrame(
-      jsc.parallelize(localTraining), LabeledPoint.class);
+    Dataset<Row> training = spark.createDataFrame(localTraining, LabeledPoint.class);
     // Create a LogisticRegression instance. This instance is an Estimator.
     MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
@@ -80,7 +75,7 @@ public class JavaDeveloperApiExample {
       new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
-    Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
+    Dataset<Row> test = spark.createDataFrame(localTest, LabeledPoint.class);
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     Dataset<Row> results = model.transform(test);
@@ -93,7 +88,7 @@ public class JavaDeveloperApiExample {
         " even though all coefficients are 0!");
     }
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -20,7 +20,7 @@ package org.apache.spark.examples.ml;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.ArrayList;
@@ -41,16 +41,15 @@ import org.apache.spark.sql.types.StructType;
 public class JavaElementwiseProductExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaElementwiseProductExample").getOrCreate();
     // $example on$
     // Create some vector data; also works for sparse vectors
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)),
       RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0))
-    ));
+    );
     List<StructField> fields = new ArrayList<>(2);
     fields.add(DataTypes.createStructField("id", DataTypes.StringType, false));
@@ -58,7 +57,7 @@ public class JavaElementwiseProductExample {
     StructType schema = DataTypes.createStructType(fields);
-    Dataset<Row> dataFrame = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
     Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
@@ -70,6 +69,6 @@ public class JavaElementwiseProductExample {
     // Batch transform the vectors to create new column:
     transformer.transform(dataFrame).show();
     // $example off$
-    jsc.stop();
+    spark.stop();
  }
 }
@@ -21,8 +21,6 @@ package org.apache.spark.examples.ml;
 import java.util.Arrays;
 // $example off$
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
 // $example on$
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.classification.LogisticRegressionModel;
@@ -32,23 +30,21 @@ import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 // $example off$
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 /**
  * Java example for Estimator, Transformer, and Param.
  */
 public class JavaEstimatorTransformerParamExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf()
-      .setAppName("JavaEstimatorTransformerParamExample");
-    SparkContext sc = new SparkContext(conf);
-    SQLContext sqlContext = new SQLContext(sc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaEstimatorTransformerParamExample").getOrCreate();
     // $example on$
     // Prepare training data.
     // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into
     // DataFrames, where it uses the bean metadata to infer the schema.
-    Dataset<Row> training = sqlContext.createDataFrame(
+    Dataset<Row> training = spark.createDataFrame(
       Arrays.asList(
         new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
         new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -89,7 +85,7 @@ public class JavaEstimatorTransformerParamExample {
     System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
     // Prepare test documents.
-    Dataset<Row> test = sqlContext.createDataFrame(Arrays.asList(
+    Dataset<Row> test = spark.createDataFrame(Arrays.asList(
       new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
@@ -107,6 +103,6 @@ public class JavaEstimatorTransformerParamExample {
     }
     // $example off$
-    sc.stop();
+    spark.stop();
   }
 }
@@ -29,18 +29,17 @@ import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
 import org.apache.spark.ml.feature.*;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example off$
 public class JavaGradientBoostedTreeClassifierExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaGradientBoostedTreeClassifierExample").getOrCreate();
     // $example on$
     // Load and parse the data file, converting it to a DataFrame.
-    Dataset<Row> data = sqlContext.read().format("libsvm")
+    Dataset<Row> data = spark.read().format("libsvm")
       .load("data/mllib/sample_libsvm_data.txt");
     // Index labels, adding metadata to the label column.
@@ -99,6 +98,6 @@ public class JavaGradientBoostedTreeClassifierExample {
     System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,8 +17,6 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 // $example on$
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
@@ -30,19 +28,17 @@ import org.apache.spark.ml.regression.GBTRegressionModel;
 import org.apache.spark.ml.regression.GBTRegressor;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example off$
 public class JavaGradientBoostedTreeRegressorExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaGradientBoostedTreeRegressorExample").getOrCreate();
     // $example on$
     // Load and parse the data file, converting it to a DataFrame.
-    Dataset<Row> data =
-      sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
     // Automatically identify categorical features, and index them.
     // Set maxCategories so features with > 4 distinct values are treated as continuous.
@@ -87,6 +83,6 @@ public class JavaGradientBoostedTreeRegressorExample {
     System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString());
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,14 +17,12 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example on$
 import java.util.Arrays;
+import java.util.List;
 import org.apache.spark.ml.feature.IndexToString;
 import org.apache.spark.ml.feature.StringIndexer;
@@ -39,24 +37,22 @@ import org.apache.spark.sql.types.StructType;
 public class JavaIndexToStringExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaIndexToStringExample").getOrCreate();
     // $example on$
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0, "a"),
       RowFactory.create(1, "b"),
       RowFactory.create(2, "c"),
       RowFactory.create(3, "a"),
       RowFactory.create(4, "a"),
       RowFactory.create(5, "c")
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
       new StructField("category", DataTypes.StringType, false, Metadata.empty())
     });
-    Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
     StringIndexerModel indexer = new StringIndexer()
       .setInputCol("category")
@@ -70,6 +66,6 @@ public class JavaIndexToStringExample {
     Dataset<Row> converted = converter.transform(indexed);
     converted.select("id", "originalCategory").show();
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -19,12 +19,10 @@ package org.apache.spark.examples.ml;
 import java.util.regex.Pattern;
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
 // $example on$
 import org.apache.spark.ml.clustering.KMeansModel;
@@ -72,16 +70,14 @@ public class JavaKMeansExample {
     int k = Integer.parseInt(args[1]);
     // Parses the arguments
-    SparkConf conf = new SparkConf().setAppName("JavaKMeansExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaKMeansExample").getOrCreate();
     // $example on$
     // Loads data
-    JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
+    JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
     StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
     StructType schema = new StructType(fields);
-    Dataset<Row> dataset = sqlContext.createDataFrame(points, schema);
+    Dataset<Row> dataset = spark.createDataFrame(points, schema);
     // Trains a k-means model
     KMeans kmeans = new KMeans()
@@ -96,6 +92,6 @@ public class JavaKMeansExample {
     }
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -19,9 +19,7 @@ package org.apache.spark.examples.ml;
 // $example on$
 import java.util.regex.Pattern;
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.ml.clustering.LDA;
 import org.apache.spark.ml.clustering.LDAModel;
@@ -30,7 +28,7 @@ import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
@@ -67,15 +65,13 @@ public class JavaLDAExample {
     String inputFile = "data/mllib/sample_lda_data.txt";
     // Parses the arguments
-    SparkConf conf = new SparkConf().setAppName("JavaLDAExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession.builder().appName("JavaLDAExample").getOrCreate();
     // Loads data
-    JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParseVector());
+    JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParseVector());
     StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
     StructType schema = new StructType(fields);
-    Dataset<Row> dataset = sqlContext.createDataFrame(points, schema);
+    Dataset<Row> dataset = spark.createDataFrame(points, schema);
     // Trains a LDA model
     LDA lda = new LDA()
@@ -91,7 +87,7 @@ public class JavaLDAExample {
     topics.show(false);
     model.transform(dataset).show(false);
-    jsc.stop();
+    spark.stop();
   }
   // $example off$
 }
@@ -17,8 +17,6 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 // $example on$
 import org.apache.spark.ml.regression.LinearRegression;
 import org.apache.spark.ml.regression.LinearRegressionModel;
@@ -26,18 +24,17 @@ import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 // $example off$
 public class JavaLinearRegressionWithElasticNetExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithElasticNetExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaLinearRegressionWithElasticNetExample").getOrCreate();
     // $example on$
     // Load training data
-    Dataset<Row> training = sqlContext.read().format("libsvm")
+    Dataset<Row> training = spark.read().format("libsvm")
       .load("data/mllib/sample_linear_regression_data.txt");
     LinearRegression lr = new LinearRegression()
@@ -61,6 +58,6 @@ public class JavaLinearRegressionWithElasticNetExample {
     System.out.println("r2: " + trainingSummary.r2());
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }
@@ -17,8 +17,6 @@
 package org.apache.spark.examples.ml;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 // $example on$
 import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
 import org.apache.spark.ml.classification.LogisticRegression;
@@ -26,18 +24,17 @@ import org.apache.spark.ml.classification.LogisticRegressionModel;
 import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.functions;
 // $example off$
 public class JavaLogisticRegressionSummaryExample {
   public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionSummaryExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
+    SparkSession spark = SparkSession
+      .builder().appName("JavaLogisticRegressionSummaryExample").getOrCreate();
     // Load training data
-    Dataset<Row> training = sqlContext.read().format("libsvm")
+    Dataset<Row> training = spark.read().format("libsvm")
       .load("data/mllib/sample_libsvm_data.txt");
     LogisticRegression lr = new LogisticRegression()
@@ -80,6 +77,6 @@ public class JavaLogisticRegressionSummaryExample {
     lrModel.setThreshold(bestThreshold);
     // $example off$
-    jsc.stop();
+    spark.stop();
   }
 }