Commit 0903a185 authored by Dongjoon Hyun, committed by Andrew Or

[SPARK-15084][PYTHON][SQL] Use builder pattern to create SparkSession in PySpark.

## What changes were proposed in this pull request?

This is a Python port of the corresponding Scala builder-pattern code. `sql.py` is modified as a target example case.
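
For illustration, a minimal usage sketch of the builder chain this port exposes (the master URL, app name, and config values below are placeholders, not part of this patch):

```python
from pyspark.sql import SparkSession

# Hypothetical example values; any builder options can be chained before getOrCreate().
spark = SparkSession.builder \
    .master("local[2]") \
    .appName("BuilderExample") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

df = spark.createDataFrame([("Alice", 1)], ["name", "count"])
df.printSchema()
spark.stop()
```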

## How was this patch tested?

Manual.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12860 from dongjoon-hyun/SPARK-15084.
parent c1839c99
examples/src/main/python/sql.py
@@ -20,33 +20,28 @@ from __future__ import print_function
 import os
 import sys
 
-from pyspark import SparkContext
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession
 from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType
 
 
 if __name__ == "__main__":
-    sc = SparkContext(appName="PythonSQL")
-    sqlContext = SQLContext(sc)
+    spark = SparkSession.builder.appName("PythonSQL").getOrCreate()
 
-    # RDD is created from a list of rows
-    some_rdd = sc.parallelize([Row(name="John", age=19),
-                              Row(name="Smith", age=23),
-                              Row(name="Sarah", age=18)])
-    # Infer schema from the first row, create a DataFrame and print the schema
-    some_df = sqlContext.createDataFrame(some_rdd)
+    # A list of Rows. Infer schema from the first row, create a DataFrame and print the schema
+    rows = [Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]
+    some_df = spark.createDataFrame(rows)
     some_df.printSchema()
 
-    # Another RDD is created from a list of tuples
-    another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)])
+    # A list of tuples
+    tuples = [("John", 19), ("Smith", 23), ("Sarah", 18)]
     # Schema with two fields - person_name and person_age
     schema = StructType([StructField("person_name", StringType(), False),
                         StructField("person_age", IntegerType(), False)])
     # Create a DataFrame by applying the schema to the RDD and print the schema
-    another_df = sqlContext.createDataFrame(another_rdd, schema)
+    another_df = spark.createDataFrame(tuples, schema)
     another_df.printSchema()
     # root
-    #  |-- age: integer (nullable = true)
+    #  |-- age: long (nullable = true)
     #  |-- name: string (nullable = true)
 
     # A JSON dataset is pointed to by path.
@@ -57,7 +52,7 @@ if __name__ == "__main__":
     else:
         path = sys.argv[1]
     # Create a DataFrame from the file(s) pointed to by path
-    people = sqlContext.jsonFile(path)
+    people = spark.read.json(path)
     # root
     #  |-- person_name: string (nullable = false)
     #  |-- person_age: integer (nullable = false)
@@ -65,16 +60,16 @@ if __name__ == "__main__":
     # The inferred schema can be visualized using the printSchema() method.
     people.printSchema()
     # root
-    #  |-- age: IntegerType
-    #  |-- name: StringType
+    #  |-- age: long (nullable = true)
+    #  |-- name: string (nullable = true)
 
     # Register this DataFrame as a table.
-    people.registerAsTable("people")
+    people.registerTempTable("people")
 
     # SQL statements can be run by using the sql methods provided by sqlContext
-    teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+    teenagers = spark.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
 
     for each in teenagers.collect():
         print(each[0])
 
-    sc.stop()
+    spark.stop()
python/pyspark/sql/session.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 import sys
 import warnings
 from functools import reduce
+from threading import RLock
 
 if sys.version >= '3':
     basestring = unicode = str
@@ -58,16 +59,98 @@ def _monkey_patch_RDD(sparkSession):
 
 class SparkSession(object):
-    """Main entry point for Spark SQL functionality.
+    """The entry point to programming Spark with the Dataset and DataFrame API.
 
     A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
     tables, execute SQL over tables, cache tables, and read parquet files.
 
+    To create a SparkSession, use the following builder pattern:
+
+    >>> spark = SparkSession.builder \
+            .master("local") \
+            .appName("Word Count") \
+            .config("spark.some.config.option", "some-value") \
+            .getOrCreate()
+
     :param sparkContext: The :class:`SparkContext` backing this SparkSession.
     :param jsparkSession: An optional JVM Scala SparkSession. If set, we do not instantiate a new
         SparkSession in the JVM, instead we make all calls to this object.
     """
 
+    class Builder(object):
+        """Builder for :class:`SparkSession`.
+        """
+
+        _lock = RLock()
+        _options = {}
+
+        @since(2.0)
+        def config(self, key=None, value=None, conf=None):
+            """Sets a config option. Options set using this method are automatically propagated to
+            both :class:`SparkConf` and :class:`SparkSession`'s own configuration.
+
+            For an existing SparkConf, use `conf` parameter.
+            >>> from pyspark.conf import SparkConf
+            >>> SparkSession.builder.config(conf=SparkConf())
+            <pyspark.sql.session...
+
+            For a (key, value) pair, you can omit parameter names.
+            >>> SparkSession.builder.config("spark.some.config.option", "some-value")
+            <pyspark.sql.session...
+
+            :param key: a key name string for configuration property
+            :param value: a value for configuration property
+            :param conf: an instance of :class:`SparkConf`
+            """
+            with self._lock:
+                if conf is None:
+                    self._options[key] = str(value)
+                else:
+                    for (k, v) in conf.getAll():
+                        self._options[k] = v
+                return self
+
+        @since(2.0)
+        def master(self, master):
+            """Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]"
+            to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone
+            cluster.
+
+            :param master: a url for spark master
+            """
+            return self.config("spark.master", master)
+
+        @since(2.0)
+        def appName(self, name):
+            """Sets a name for the application, which will be shown in the Spark web UI.
+
+            :param name: an application name
+            """
+            return self.config("spark.app.name", name)
+
+        @since(2.0)
+        def enableHiveSupport(self):
+            """Enables Hive support, including connectivity to a persistent Hive metastore, support
+            for Hive serdes, and Hive user-defined functions.
+            """
+            return self.config("spark.sql.catalogImplementation", "hive")
+
+        @since(2.0)
+        def getOrCreate(self):
+            """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
+            one based on the options set in this builder.
+            """
+            with self._lock:
+                from pyspark.conf import SparkConf
+                from pyspark.context import SparkContext
+                from pyspark.sql.context import SQLContext
+                sparkConf = SparkConf()
+                for key, value in self._options.items():
+                    sparkConf.set(key, value)
+                sparkContext = SparkContext.getOrCreate(sparkConf)
+                return SQLContext.getOrCreate(sparkContext).sparkSession
+
+    builder = Builder()
+
     _instantiatedContext = None
 
     @ignore_unicode_prefix
@@ -445,6 +528,12 @@ class SparkSession(object):
         """
         return DataFrameReader(self._wrapped)
 
+    @since(2.0)
+    def stop(self):
+        """Stop the underlying :class:`SparkContext`.
+        """
+        self._sc.stop()
+
 
 def _test():
     import os
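
As a usage note, a small sketch of the reuse semantics implied by `getOrCreate()` above: because it delegates to `SparkContext.getOrCreate()` and `SQLContext.getOrCreate()`, a second builder call attaches to the already-running context (assumes a local PySpark environment with this patch applied; app names and the master URL are placeholders):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("FirstApp").getOrCreate()

# A second getOrCreate() reuses the existing SparkContext rather than starting a new one.
spark2 = SparkSession.builder.appName("SecondApp").getOrCreate()
assert spark2.sparkContext is spark.sparkContext

spark.stop()  # stops the shared SparkContext for both sessions
```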