    [SPARK-2313] Use socket to communicate GatewayServer port back to Python driver · 0cfda846
    Josh Rosen authored
    This patch changes PySpark so that the GatewayServer's port is communicated back to the Python process that launches it over a local socket instead of a pipe. The old pipe-based approach was brittle and could fail if `spark-submit` printed unexpected output to stdout.
    
    To accomplish this, I wrote a custom `PythonGatewayServer.main()` function to use in place of Py4J's `GatewayServer.main()`.
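
    For illustration, here is a minimal Python sketch of the server side of that handshake; the actual implementation is the JVM-side `PythonGatewayServer`, and the helper function below is hypothetical:

        import os
        import socket
        import struct

        def report_gateway_port(gateway_port):
            # Hypothetical sketch: read the callback address that
            # launch_gateway() placed in the environment.
            host = os.environ["_PYSPARK_DRIVER_CALLBACK_HOST"]
            port = int(os.environ["_PYSPARK_DRIVER_CALLBACK_PORT"])
            # Connect back to the Python driver and send the Py4J port as a
            # 4-byte big-endian integer, the framing that read_int() expects.
            callback = socket.create_connection((host, port))
            try:
                callback.sendall(struct.pack("!i", gateway_port))
            finally:
                callback.close()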
    
    Closes #3424.
    
    Author: Josh Rosen <joshrosen@databricks.com>
    
    Closes #4603 from JoshRosen/SPARK-2313 and squashes the following commits:
    
    6a7740b [Josh Rosen] Remove EchoOutputThread since it's no longer needed
    0db501f [Josh Rosen] Use select() so that we don't block if GatewayServer dies.
    9bdb4b6 [Josh Rosen] Handle case where getListeningPort returns -1
    3fb7ed1 [Josh Rosen] Remove stdout=PIPE
    2458934 [Josh Rosen] Use underscore to mark env var. as private
    d12c95d [Josh Rosen] Use Logging and Utils.tryOrExit()
    e5f9730 [Josh Rosen] Wrap everything in a giant try-block
    2f70689 [Josh Rosen] Use stdin PIPE to share fate with driver
    8bf956e [Josh Rosen] Initial cut at passing Py4J gateway port back to driver via socket
java_gateway.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import atexit
import os
import select
import signal
import shlex
import socket
import platform
from subprocess import Popen, PIPE
from py4j.java_gateway import java_import, JavaGateway, GatewayClient

from pyspark.serializers import read_int


def launch_gateway():
    SPARK_HOME = os.environ["SPARK_HOME"]

    if "PYSPARK_GATEWAY_PORT" in os.environ:
        gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
    else:
        # Launch the Py4j gateway using Spark's run command so that we pick up the
        # proper classpath and settings from spark-env.sh
        on_windows = platform.system() == "Windows"
        script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
        submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS")
        submit_args = submit_args if submit_args is not None else ""
        submit_args = shlex.split(submit_args)
        command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"]

        # Start a socket that will be used by PythonGatewayServer to communicate its port to us
        callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        callback_socket.bind(('127.0.0.1', 0))
        callback_socket.listen(1)
        callback_host, callback_port = callback_socket.getsockname()
        env = dict(os.environ)
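        # The leading underscore marks these as environment variables private to
        # PySpark's launcher; the JVM-side PythonGatewayServer reads them to
        # locate this callback socket.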
        env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
        env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)

        # Launch the Java gateway.
        # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
        if not on_windows:
            # Don't send ctrl-c / SIGINT to the Java gateway:
            def preexec_func():
                signal.signal(signal.SIGINT, signal.SIG_IGN)
            env["IS_SUBPROCESS"] = "1"  # tell JVM to exit after python exits
            proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
        else:
            # preexec_fn not supported on Windows
            proc = Popen(command, stdin=PIPE, env=env)

        gateway_port = None
        # We use select() here in order to avoid blocking indefinitely if the subprocess dies
        # before connecting
        while gateway_port is None and proc.poll() is None:
            timeout = 1  # (seconds)
            readable, _, _ = select.select([callback_socket], [], [], timeout)
            if callback_socket in readable:
                gateway_connection = callback_socket.accept()[0]
                # Determine which ephemeral port the server started on:
                gateway_port = read_int(gateway_connection.makefile())
                gateway_connection.close()
                callback_socket.close()
        if gateway_port is None:
            raise Exception("Java gateway process exited before sending the driver its port number")

        # In Windows, ensure the Java child processes do not linger after Python has exited.
        # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
        # the parent process' stdin sends an EOF). In Windows, however, this is not possible
        # because java.lang.Process reads directly from the parent process' stdin, contending
        # with any opportunity to read an EOF from the parent. Note that this is only best
        # effort and will not take effect if the python process is violently terminated.
        if on_windows:
            # In Windows, the child process here is "spark-submit.cmd", not the JVM itself
            # (because the UNIX "exec" command is not available). This means we cannot simply
            # call proc.kill(), which kills only the "spark-submit.cmd" process but not the
            # JVMs. Instead, we use "taskkill" with the tree-kill option "/t" to terminate all
            # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
            def killChild():
                Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
            atexit.register(killChild)

    # Connect to the gateway
    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

    return gateway
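
A hedged usage sketch (not part of this patch): once launch_gateway() returns, driver-side code can build JVM objects through the gateway. The calls below are illustrative; the real wiring lives in pyspark/context.py.

    gateway = launch_gateway()
    jconf = gateway.jvm.SparkConf()  # resolvable thanks to the java_import calls above
    jconf.setAppName("gateway-demo")
    print(jconf.toDebugString())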