From 127bf1bb07967e2e4f99ad7abaa7f6fab3b3f407 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng <ruifengz@foxmail.com> Date: Fri, 20 May 2016 16:40:33 -0700 Subject: [PATCH] [SPARK-15031][EXAMPLE] Use SparkSession in examples ## What changes were proposed in this pull request? Use `SparkSession` according to [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031) `MLLLIB` is not recommended to use now, so examples in `MLLIB` are ignored in this PR. `StreamingContext` can not be directly obtained from `SparkSession`, so example in `Streaming` are ignored too. cc andrewor14 ## How was this patch tested? manual tests with spark-submit Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #13164 from zhengruifeng/use_sparksession_ii. --- .../org/apache/spark/examples/JavaHdfsLR.java | 14 ++++++++------ .../apache/spark/examples/JavaLogQuery.java | 12 ++++++++---- .../apache/spark/examples/JavaPageRank.java | 13 +++++++------ .../apache/spark/examples/JavaSparkPi.java | 12 ++++++++---- .../spark/examples/JavaStatusTrackerDemo.java | 19 ++++++++++++------- .../org/apache/spark/examples/JavaTC.java | 15 ++++++++++----- .../apache/spark/examples/JavaWordCount.java | 15 +++++++++------ examples/src/main/python/als.py | 12 +++++++++--- examples/src/main/python/avro_inputformat.py | 12 +++++++++--- examples/src/main/python/kmeans.py | 12 ++++++++---- .../src/main/python/logistic_regression.py | 13 +++++++++---- examples/src/main/python/pagerank.py | 11 +++++++---- .../src/main/python/parquet_inputformat.py | 12 +++++++++--- examples/src/main/python/pi.py | 12 +++++++++--- examples/src/main/python/sort.py | 13 +++++++++---- .../src/main/python/transitive_closure.py | 12 +++++++++--- examples/src/main/python/wordcount.py | 13 +++++++++---- .../apache/spark/examples/BroadcastTest.scala | 16 ++++++++++++---- .../spark/examples/DFSReadWriteTest.scala | 13 ++++++++----- .../examples/ExceptionHandlingTest.scala | 12 ++++++++---- .../apache/spark/examples/GroupByTest.scala | 12 ++++++++---- .../org/apache/spark/examples/HdfsTest.scala | 12 +++++++----- .../spark/examples/MultiBroadcastTest.scala | 13 +++++++++---- .../examples/SimpleSkewedGroupByTest.scala | 13 ++++++++----- .../spark/examples/SkewedGroupByTest.scala | 13 +++++++++---- .../org/apache/spark/examples/SparkALS.scala | 12 ++++++++---- .../apache/spark/examples/SparkHdfsLR.scala | 16 +++++++++------- .../apache/spark/examples/SparkKMeans.scala | 15 +++++++++------ .../org/apache/spark/examples/SparkLR.scala | 13 +++++++++---- .../apache/spark/examples/SparkPageRank.scala | 13 ++++++++----- .../org/apache/spark/examples/SparkPi.scala | 11 +++++++---- .../org/apache/spark/examples/SparkTC.scala | 11 +++++++---- .../examples/sql/hive/HiveFromSpark.scala | 2 +- .../apache/spark/sql/JavaDataFrameSuite.java | 6 +++--- 34 files changed, 279 insertions(+), 146 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index f64155ce3c..ded442096c 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -17,11 +17,10 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SparkSession; import java.io.Serializable; import java.util.Arrays; @@ -122,9 +121,12 @@ public final class JavaHdfsLR { showWarning(); - SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD<String> lines = sc.textFile(args[0]); + SparkSession spark = SparkSession + .builder() + .appName("JavaHdfsLR") + .getOrCreate(); + + JavaRDD<String> lines = spark.read().text(args[0]).javaRDD(); JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache(); int ITERATIONS = Integer.parseInt(args[1]); @@ -152,6 +154,6 @@ public final class JavaHdfsLR { System.out.print("Final w: "); printWeights(w); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index ebb0687b14..7775443861 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -20,12 +20,13 @@ package org.apache.spark.examples; import com.google.common.collect.Lists; import scala.Tuple2; import scala.Tuple3; -import org.apache.spark.SparkConf; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; import java.io.Serializable; import java.util.List; @@ -99,9 +100,12 @@ public final class JavaLogQuery { } public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaLogQuery") + .getOrCreate(); - SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery"); - JavaSparkContext jsc = new JavaSparkContext(sparkConf); + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); JavaRDD<String> dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs); @@ -123,6 +127,6 @@ public final class JavaLogQuery { for (Tuple2<?,?> t : output) { System.out.println(t._1() + "\t" + t._2()); } - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index 229d123441..128b5ab17c 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -26,14 +26,13 @@ import scala.Tuple2; import com.google.common.collect.Iterables; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; /** * Computes the PageRank of URLs from an input file. Input file should @@ -73,15 +72,17 @@ public final class JavaPageRank { showWarning(); - SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); - JavaSparkContext ctx = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaPageRank") + .getOrCreate(); // Loads in input file. It should be in format of: // URL neighbor URL // URL neighbor URL // URL neighbor URL // ... - JavaRDD<String> lines = ctx.textFile(args[0], 1); + JavaRDD<String> lines = spark.read().text(args[0]).javaRDD(); // Loads all URLs from input file and initialize their neighbors. JavaPairRDD<String, Iterable<String>> links = lines.mapToPair( @@ -132,6 +133,6 @@ public final class JavaPageRank { System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); } - ctx.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 04a57a6bfb..7df145e311 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -17,11 +17,11 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SparkSession; import java.util.ArrayList; import java.util.List; @@ -33,8 +33,12 @@ import java.util.List; public final class JavaSparkPi { public static void main(String[] args) throws Exception { - SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi"); - JavaSparkContext jsc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaSparkPi") + .getOrCreate(); + + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2; int n = 100000 * slices; @@ -61,6 +65,6 @@ public final class JavaSparkPi { System.out.println("Pi is roughly " + 4.0 * count / n); - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index e68ec74c3e..6f899c772e 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -17,13 +17,14 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.SparkJobInfo; import org.apache.spark.SparkStageInfo; import org.apache.spark.api.java.JavaFutureAction; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SparkSession; + import java.util.Arrays; import java.util.List; @@ -44,11 +45,15 @@ public final class JavaStatusTrackerDemo { } public static void main(String[] args) throws Exception { - SparkConf sparkConf = new SparkConf().setAppName(APP_NAME); - final JavaSparkContext sc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName(APP_NAME) + .getOrCreate(); + + final JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); // Example of implementing a progress reporter for a simple job. - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( + JavaRDD<Integer> rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( new IdentityWithDelay<Integer>()); JavaFutureAction<List<Integer>> jobFuture = rdd.collectAsync(); while (!jobFuture.isDone()) { @@ -58,13 +63,13 @@ public final class JavaStatusTrackerDemo { continue; } int currentJobId = jobIds.get(jobIds.size() - 1); - SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); - SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); + SparkJobInfo jobInfo = jsc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = jsc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + " active, " + stageInfo.numCompletedTasks() + " complete"); } System.out.println("Job results are: " + jobFuture.get()); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index ca10384212..f12ca77ed1 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -25,10 +25,10 @@ import java.util.Set; import scala.Tuple2; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; /** * Transitive closure on a graph, implemented in Java. @@ -64,10 +64,15 @@ public final class JavaTC { } public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaTC") + .getOrCreate(); + + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); + Integer slices = (args.length > 0) ? Integer.parseInt(args[0]): 2; - JavaPairRDD<Integer, Integer> tc = sc.parallelizePairs(generateGraph(), slices).cache(); + JavaPairRDD<Integer, Integer> tc = jsc.parallelizePairs(generateGraph(), slices).cache(); // Linear transitive closure: each round grows paths by one edge, // by joining the graph's edges with the already-discovered paths. @@ -94,6 +99,6 @@ public final class JavaTC { } while (nextCount != oldCount); System.out.println("TC has " + tc.count() + " edges."); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 3ff5412b93..1caee60e34 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -18,13 +18,13 @@ package org.apache.spark.examples; import scala.Tuple2; -import org.apache.spark.SparkConf; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; import java.util.Arrays; import java.util.Iterator; @@ -41,9 +41,12 @@ public final class JavaWordCount { System.exit(1); } - SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount"); - JavaSparkContext ctx = new JavaSparkContext(sparkConf); - JavaRDD<String> lines = ctx.textFile(args[0], 1); + SparkSession spark = SparkSession + .builder() + .appName("JavaWordCount") + .getOrCreate(); + + JavaRDD<String> lines = spark.read().text(args[0]).javaRDD(); JavaRDD<String> words = lines.flatMap(new FlatMapFunction<String, String>() { @Override @@ -72,6 +75,6 @@ public final class JavaWordCount { for (Tuple2<?,?> tuple : output) { System.out.println(tuple._1() + ": " + tuple._2()); } - ctx.stop(); + spark.stop(); } } diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index f07020b503..81562e20a9 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -28,7 +28,7 @@ import sys import numpy as np from numpy.random import rand from numpy import matrix -from pyspark import SparkContext +from pyspark.sql import SparkSession LAMBDA = 0.01 # regularization np.random.seed(42) @@ -62,7 +62,13 @@ if __name__ == "__main__": example. Please use pyspark.ml.recommendation.ALS for more conventional use.""", file=sys.stderr) - sc = SparkContext(appName="PythonALS") + spark = SparkSession\ + .builder\ + .appName("PythonALS")\ + .getOrCreate() + + sc = spark._sc + M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 U = int(sys.argv[2]) if len(sys.argv) > 2 else 500 F = int(sys.argv[3]) if len(sys.argv) > 3 else 10 @@ -99,4 +105,4 @@ if __name__ == "__main__": print("Iteration %d:" % i) print("\nRMSE: %5.4f\n" % error) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index da368ac628..3f65e8f79a 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -19,8 +19,8 @@ from __future__ import print_function import sys -from pyspark import SparkContext from functools import reduce +from pyspark.sql import SparkSession """ Read data file users.avro in local Spark distro: @@ -64,7 +64,13 @@ if __name__ == "__main__": exit(-1) path = sys.argv[1] - sc = SparkContext(appName="AvroKeyInputFormat") + + spark = SparkSession\ + .builder\ + .appName("AvroKeyInputFormat")\ + .getOrCreate() + + sc = spark._sc conf = None if len(sys.argv) == 3: @@ -82,4 +88,4 @@ if __name__ == "__main__": for k in output: print(k) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 3426e491dc..92e0a3ae2e 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -27,7 +27,7 @@ from __future__ import print_function import sys import numpy as np -from pyspark import SparkContext +from pyspark.sql import SparkSession def parseVector(line): @@ -55,8 +55,12 @@ if __name__ == "__main__": as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an example on how to use ML's KMeans implementation.""", file=sys.stderr) - sc = SparkContext(appName="PythonKMeans") - lines = sc.textFile(sys.argv[1]) + spark = SparkSession\ + .builder\ + .appName("PythonKMeans")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) data = lines.map(parseVector).cache() K = int(sys.argv[2]) convergeDist = float(sys.argv[3]) @@ -79,4 +83,4 @@ if __name__ == "__main__": print("Final centers: " + str(kPoints)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 7d33be7e81..01c938454b 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -27,7 +27,7 @@ from __future__ import print_function import sys import numpy as np -from pyspark import SparkContext +from pyspark.sql import SparkSession D = 10 # Number of dimensions @@ -55,8 +55,13 @@ if __name__ == "__main__": Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py to see how ML's implementation is used.""", file=sys.stderr) - sc = SparkContext(appName="PythonLR") - points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() + spark = SparkSession\ + .builder\ + .appName("PythonLR")\ + .getOrCreate() + + points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\ + .mapPartitions(readPointBatch).cache() iterations = int(sys.argv[2]) # Initialize w to a random value @@ -80,4 +85,4 @@ if __name__ == "__main__": print("Final w: " + str(w)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 2fdc9773d4..a399a9c37c 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -25,7 +25,7 @@ import re import sys from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession def computeContribs(urls, rank): @@ -51,14 +51,17 @@ if __name__ == "__main__": file=sys.stderr) # Initialize the spark context. - sc = SparkContext(appName="PythonPageRank") + spark = SparkSession\ + .builder\ + .appName("PythonPageRank")\ + .getOrCreate() # Loads in input file. It should be in format of: # URL neighbor URL # URL neighbor URL # URL neighbor URL # ... - lines = sc.textFile(sys.argv[1], 1) + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) # Loads all URLs from input file and initialize their neighbors. links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() @@ -79,4 +82,4 @@ if __name__ == "__main__": for (link, rank) in ranks.collect(): print("%s has rank: %s." % (link, rank)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index e1fd85b082..2f09f4d573 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -18,7 +18,7 @@ from __future__ import print_function import sys -from pyspark import SparkContext +from pyspark.sql import SparkSession """ Read data file users.parquet in local Spark distro: @@ -47,7 +47,13 @@ if __name__ == "__main__": exit(-1) path = sys.argv[1] - sc = SparkContext(appName="ParquetInputFormat") + + spark = SparkSession\ + .builder\ + .appName("ParquetInputFormat")\ + .getOrCreate() + + sc = spark._sc parquet_rdd = sc.newAPIHadoopFile( path, @@ -59,4 +65,4 @@ if __name__ == "__main__": for k in output: print(k) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 92e5cf45ab..5db03e4a21 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -20,14 +20,20 @@ import sys from random import random from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": """ Usage: pi [partitions] """ - sc = SparkContext(appName="PythonPi") + spark = SparkSession\ + .builder\ + .appName("PythonPi")\ + .getOrCreate() + + sc = spark._sc + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 n = 100000 * partitions @@ -39,4 +45,4 @@ if __name__ == "__main__": count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index b6c2916254..81898cf6d5 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -19,15 +19,20 @@ from __future__ import print_function import sys -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort <file>", file=sys.stderr) exit(-1) - sc = SparkContext(appName="PythonSort") - lines = sc.textFile(sys.argv[1], 1) + + spark = SparkSession\ + .builder\ + .appName("PythonSort")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) sortedCount = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ .sortByKey() @@ -37,4 +42,4 @@ if __name__ == "__main__": for (num, unitcount) in output: print(num) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 3d61250d8b..37c41dcd03 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -20,7 +20,7 @@ from __future__ import print_function import sys from random import Random -from pyspark import SparkContext +from pyspark.sql import SparkSession numEdges = 200 numVertices = 100 @@ -41,7 +41,13 @@ if __name__ == "__main__": """ Usage: transitive_closure [partitions] """ - sc = SparkContext(appName="PythonTransitiveClosure") + spark = SparkSession\ + .builder\ + .appName("PythonTransitiveClosure")\ + .getOrCreate() + + sc = spark._sc + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 tc = sc.parallelize(generateGraph(), partitions).cache() @@ -67,4 +73,4 @@ if __name__ == "__main__": print("TC has %i edges" % tc.count()) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 7c0143607b..3d5e44d5b2 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -20,15 +20,20 @@ from __future__ import print_function import sys from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount <file>", file=sys.stderr) exit(-1) - sc = SparkContext(appName="PythonWordCount") - lines = sc.textFile(sys.argv[1], 1) + + spark = SparkSession\ + .builder\ + .appName("PythonWordCount")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) counts = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (x, 1)) \ .reduceByKey(add) @@ -36,4 +41,4 @@ if __name__ == "__main__": for (word, count) in output: print("%s: %i" % (word, count)) - sc.stop() + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index af5a815f6e..c50f25d951 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -18,7 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession /** * Usage: BroadcastTest [slices] [numElem] [blockSize] @@ -28,9 +29,16 @@ object BroadcastTest { val blockSize = if (args.length > 2) args(2) else "4096" - val sparkConf = new SparkConf().setAppName("Broadcast Test") + val sparkConf = new SparkConf() .set("spark.broadcast.blockSize", blockSize) - val sc = new SparkContext(sparkConf) + + val spark = SparkSession + .builder + .config(sparkConf) + .appName("Broadcast Test") + .getOrCreate() + + val sc = spark.sparkContext val slices = if (args.length > 0) args(0).toInt else 2 val num = if (args.length > 1) args(1).toInt else 1000000 @@ -48,7 +56,7 @@ object BroadcastTest { println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 7bf023667d..4b5e36c736 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -22,7 +22,7 @@ import java.io.File import scala.io.Source._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Simple test for reading and writing to a distributed @@ -101,11 +101,14 @@ object DFSReadWriteTest { val fileContents = readFile(localFilePath.toString()) val localWordCount = runLocalWordCount(fileContents) - println("Creating SparkConf") - val conf = new SparkConf().setAppName("DFS Read Write Test") + println("Creating SparkSession") + val spark = SparkSession + .builder + .appName("DFS Read Write Test") + .getOrCreate() println("Creating SparkContext") - val sc = new SparkContext(conf) + val sc = spark.sparkContext println("Writing local file to DFS") val dfsFilename = dfsDirPath + "/dfs_read_write_test" @@ -124,7 +127,7 @@ object DFSReadWriteTest { .values .sum - sc.stop() + spark.stop() if (localWordCount == dfsWordCount) { println(s"Success! Local Word Count ($localWordCount) " + diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index d42f63e870..6a1bbed290 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -17,18 +17,22 @@ package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession object ExceptionHandlingTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("ExceptionHandlingTest") + .getOrCreate() + val sc = spark.sparkContext + sc.parallelize(0 until sc.defaultParallelism).foreach { i => if (math.random > 0.75) { throw new Exception("Testing exception handling") } } - sc.stop() + spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 4db229b5de..0cb61d7495 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -20,20 +20,24 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object GroupByTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("GroupBy Test") + val spark = SparkSession + .builder + .appName("GroupBy Test") + .getOrCreate() + var numMappers = if (args.length > 0) args(0).toInt else 2 var numKVPairs = if (args.length > 1) args(1).toInt else 1000 var valSize = if (args.length > 2) args(2).toInt else 1000 var numReducers = if (args.length > 3) args(3).toInt else numMappers - val sc = new SparkContext(sparkConf) + val sc = spark.sparkContext val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random @@ -50,7 +54,7 @@ object GroupByTest { println(pairs1.groupByKey(numReducers).count()) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index 124dc9af63..aa8de69839 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark._ +import org.apache.spark.sql.SparkSession object HdfsTest { @@ -29,9 +29,11 @@ object HdfsTest { System.err.println("Usage: HdfsTest <file>") System.exit(1) } - val sparkConf = new SparkConf().setAppName("HdfsTest") - val sc = new SparkContext(sparkConf) - val file = sc.textFile(args(0)) + val spark = SparkSession + .builder + .appName("HdfsTest") + .getOrCreate() + val file = spark.read.text(args(0)).rdd val mapped = file.map(s => s.length).cache() for (iter <- 1 to 10) { val start = System.currentTimeMillis() @@ -39,7 +41,7 @@ object HdfsTest { val end = System.currentTimeMillis() println("Iteration " + iter + " took " + (end-start) + " ms") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 3eb0c27723..961ab99200 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,8 +18,9 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession + /** * Usage: MultiBroadcastTest [slices] [numElem] @@ -27,8 +28,12 @@ import org.apache.spark.rdd.RDD object MultiBroadcastTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("Multi-Broadcast Test") + .getOrCreate() + + val sc = spark.sparkContext val slices = if (args.length > 0) args(0).toInt else 2 val num = if (args.length > 1) args(1).toInt else 1000000 @@ -51,7 +56,7 @@ object MultiBroadcastTest { // Collect the small RDD so we can print the observed sizes locally. observedSizes.collect().foreach(i => println(i)) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index ec07e6323e..255c2bfcee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -20,23 +20,26 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { + val spark = SparkSession + .builder + .appName("SimpleSkewedGroupByTest") + .getOrCreate() + + val sc = spark.sparkContext - val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest") var numMappers = if (args.length > 0) args(0).toInt else 2 var numKVPairs = if (args.length > 1) args(1).toInt else 1000 var valSize = if (args.length > 2) args(2).toInt else 1000 var numReducers = if (args.length > 3) args(3).toInt else numMappers var ratio = if (args.length > 4) args(4).toInt else 5.0 - val sc = new SparkContext(sparkConf) - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random var result = new Array[(Int, Array[Byte])](numKVPairs) @@ -64,7 +67,7 @@ object SimpleSkewedGroupByTest { // .map{case (k,v) => (k, v.size)} // .collectAsMap) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 8e4c2b6229..efd40147f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -20,20 +20,25 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object SkewedGroupByTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("GroupBy Test") + val spark = SparkSession + .builder + .appName("GroupBy Test") + .getOrCreate() + + val sc = spark.sparkContext + var numMappers = if (args.length > 0) args(0).toInt else 2 var numKVPairs = if (args.length > 1) args(1).toInt else 1000 var valSize = if (args.length > 2) args(2).toInt else 1000 var numReducers = if (args.length > 3) args(3).toInt else numMappers - val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random @@ -54,7 +59,7 @@ object SkewedGroupByTest { println(pairs1.groupByKey(numReducers).count()) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index b06c629802..8a3d08f459 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import org.apache.commons.math3.linear._ -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Alternating least squares matrix factorization. @@ -108,8 +108,12 @@ object SparkALS { println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") - val sparkConf = new SparkConf().setAppName("SparkALS") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkALS") + .getOrCreate() + + val sc = spark.sparkContext val R = generateR() @@ -135,7 +139,7 @@ object SparkALS { println() } - sc.stop() + spark.stop() } private def randomVector(n: Int): RealVector = diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index c514eb0fa5..84f133e011 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -23,9 +23,8 @@ import java.util.Random import scala.math.exp import breeze.linalg.{DenseVector, Vector} -import org.apache.hadoop.conf.Configuration -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. @@ -67,11 +66,14 @@ object SparkHdfsLR { showWarning() - val sparkConf = new SparkConf().setAppName("SparkHdfsLR") + val spark = SparkSession + .builder + .appName("SparkHdfsLR") + .getOrCreate() + val inputPath = args(0) - val conf = new Configuration() - val sc = new SparkContext(sparkConf) - val lines = sc.textFile(inputPath) + val lines = spark.read.text(inputPath).rdd + val points = lines.map(parsePoint).cache() val ITERATIONS = args(1).toInt @@ -88,7 +90,7 @@ object SparkHdfsLR { } println("Final w: " + w) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 676164806e..aa93c93c44 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import breeze.linalg.{squaredDistance, DenseVector, Vector} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * K-means clustering. @@ -66,14 +66,17 @@ object SparkKMeans { showWarning() - val sparkConf = new SparkConf().setAppName("SparkKMeans") - val sc = new SparkContext(sparkConf) - val lines = sc.textFile(args(0)) + val spark = SparkSession + .builder + .appName("SparkKMeans") + .getOrCreate() + + val lines = spark.read.text(args(0)).rdd val data = lines.map(parseVector _).cache() val K = args(1).toInt val convergeDist = args(2).toDouble - val kPoints = data.takeSample(withReplacement = false, K, 42).toArray + val kPoints = data.takeSample(withReplacement = false, K, 42) var tempDist = 1.0 while(tempDist > convergeDist) { @@ -97,7 +100,7 @@ object SparkKMeans { println("Final centers:") kPoints.foreach(println) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 718f84f645..8ef3aab657 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -24,7 +24,7 @@ import scala.math.exp import breeze.linalg.{DenseVector, Vector} -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. @@ -63,8 +63,13 @@ object SparkLR { showWarning() - val sparkConf = new SparkConf().setAppName("SparkLR") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkLR") + .getOrCreate() + + val sc = spark.sparkContext + val numSlices = if (args.length > 0) args(0).toInt else 2 val points = sc.parallelize(generateData, numSlices).cache() @@ -82,7 +87,7 @@ object SparkLR { println("Final w: " + w) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 2664ddbb87..b7c363c7d4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Computes the PageRank of URLs from an input file. Input file should @@ -50,10 +50,13 @@ object SparkPageRank { showWarning() - val sparkConf = new SparkConf().setAppName("PageRank") + val spark = SparkSession + .builder + .appName("SparkPageRank") + .getOrCreate() + val iters = if (args.length > 1) args(1).toInt else 10 - val ctx = new SparkContext(sparkConf) - val lines = ctx.textFile(args(0), 1) + val lines = spark.read.text(args(0)).rdd val links = lines.map{ s => val parts = s.split("\\s+") (parts(0), parts(1)) @@ -71,7 +74,7 @@ object SparkPageRank { val output = ranks.collect() output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) - ctx.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 818d4f2b81..5be8f3b073 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -20,16 +20,19 @@ package org.apache.spark.examples import scala.math.random -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** Computes an approximation to pi */ object SparkPi { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("Spark Pi") - val spark = new SparkContext(conf) + val spark = SparkSession + .builder + .appName("Spark Pi") + .getOrCreate() + val sc = spark.sparkContext val slices = if (args.length > 0) args(0).toInt else 2 val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow - val count = spark.parallelize(1 until n, slices).map { i => + val count = sc.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 if (x*x + y*y < 1) 1 else 0 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index fc7a1f859f..46aa68b8b8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import scala.collection.mutable import scala.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Transitive closure on a graph. @@ -42,10 +42,13 @@ object SparkTC { } def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("SparkTC") - val spark = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkTC") + .getOrCreate() + val sc = spark.sparkContext val slices = if (args.length > 0) args(0).toInt else 2 - var tc = spark.parallelize(generateGraph, slices).cache() + var tc = sc.parallelize(generateGraph, slices).cache() // Linear transitive closure: each round grows paths by one edge, // by joining the graph's edges with the already-discovered paths. diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 7293cb51b2..59bdfa09ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -22,7 +22,7 @@ import java.io.File import com.google.common.io.{ByteStreams, Files} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql._ object HiveFromSpark { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 35a9f44fec..1e8f1062c5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -168,8 +168,8 @@ public class JavaDataFrameSuite { Assert.assertEquals( new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), schema.apply("d")); - Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), - schema.apply("e")); + Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, + Metadata.empty()), schema.apply("e")); Row first = df.select("a", "b", "c", "d", "e").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, @@ -189,7 +189,7 @@ public class JavaDataFrameSuite { for (int i = 0; i < d.length(); i++) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } - // Java.math.BigInteger is equavient to Spark Decimal(38,0) + // Java.math.BigInteger is equavient to Spark Decimal(38,0) Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); } -- GitLab