[SPARK-3030] [PySpark] Reuse Python worker (commit 2aea0da8)
    Davies Liu authored
    Reuse the Python worker to avoid the overhead of forking a Python process for each task. It also tracks the broadcasts sent to each worker, to avoid sending repeated broadcasts.
    
    This can reduce the time for a dummy task from 22ms to 13ms (-40%), which helps reduce latency for Spark Streaming.
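    
    The idea can be sketched as follows. This is an illustrative toy only, not Spark's actual daemon/worker protocol: one worker process stays alive across tasks and caches broadcasts it has already received, so each broadcast is shipped to that worker at most once.
    ```
        # Hypothetical illustration only -- NOT Spark's actual implementation.
        broadcast_cache = {}          # broadcast id -> value, survives across tasks

        def run_task(func, new_broadcasts, dropped_broadcast_ids):
            for bid, value in new_broadcasts:    # only broadcasts this worker has
                broadcast_cache[bid] = value     # not seen yet are sent to it
            for bid in dropped_broadcast_ids:    # the driver may ask a reused worker
                broadcast_cache.pop(bid, None)   # to forget a destroyed broadcast
            return func(broadcast_cache)

        # A reused worker would loop over run_task() for many tasks instead of
        # being fork()'ed once per task.
    ```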
    
    For a job with a broadcast variable (43 MB after compression):
    ```
        b = sc.broadcast(set(range(30000000)))
        print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
    ```
    It finishes in 281s without a reused worker and in 65s with reused workers (4 CPUs). With reuse, about 9 seconds of broadcast transfer and deserialization are saved for each task.
    
    Worker reuse is enabled by default and can be disabled by setting `spark.python.worker.reuse = false`.
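    
    For example, the flag could be set when constructing the context (a minimal sketch; only the `spark.python.worker.reuse` key comes from this patch, the rest is ordinary PySpark setup):
    ```
        from pyspark import SparkConf, SparkContext

        # Turn off Python worker reuse (it is on by default after this change).
        conf = SparkConf().set("spark.python.worker.reuse", "false")
        sc = SparkContext(appName="no-worker-reuse", conf=conf)
    ```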
    
    Author: Davies Liu <davies.liu@gmail.com>
    
    Closes #2259 from davies/reuse-worker and squashes the following commits:
    
    f11f617 [Davies Liu] Merge branch 'master' into reuse-worker
    3939f20 [Davies Liu] fix bug in serializer in mllib
    cf1c55e [Davies Liu] address comments
    3133a60 [Davies Liu] fix accumulator with reused worker
    760ab1f [Davies Liu] do not reuse worker if there are any exceptions
    7abb224 [Davies Liu] refactor: synchronized with itself
    ac3206e [Davies Liu] renaming
    8911f44 [Davies Liu] synchronized getWorkerBroadcasts()
    6325fc1 [Davies Liu] bugfix: bid >= 0
    e0131a2 [Davies Liu] fix name of config
    583716e [Davies Liu] only reuse completed and not interrupted worker
    ace2917 [Davies Liu] kill python worker after timeout
    6123d0f [Davies Liu] track broadcasts for each worker
    8d2f08c [Davies Liu] reuse python worker
worker.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Worker that receives input from Piped RDD.
"""
import os
import sys
import time
import socket
import traceback
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
    CompressedSerializer


pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()


def report_times(outfile, boot, init, finish):
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(1000 * boot, outfile)
    write_long(1000 * init, outfile)
    write_long(1000 * finish, outfile)


def main(infile, outfile):
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            return

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            sys.path.append(os.path.join(spark_files_dir, filename))

        # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
        ser = CompressedSerializer(pickleSer)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
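            # A non-negative id carries a new broadcast value; a negative id,
            # encoded as -(bid + 1), tells a reused worker to drop a broadcast
            # cached from an earlier task.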
            if bid >= 0:
                value = ser._read_with_length(infile)
                _broadcastRegistry[bid] = Broadcast(bid, value)
            else:
                bid = - bid - 1
                # _broadcastRegistry is a dict, so drop the entry with pop()
                # rather than a nonexistent remove() method.
                _broadcastRegistry.pop(bid, None)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        (func, deserializer, serializer) = command
        init_time = time.time()
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
            outfile.flush()
        except IOError:
            # The JVM closed the socket
            pass
        except Exception:
            # Write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        exit(-1)
    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)


if __name__ == '__main__':
    # Read a local port to connect to from stdin
    java_port = int(sys.stdin.readline())
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect(("127.0.0.1", java_port))
    sock_file = sock.makefile("a+", 65536)
    main(sock_file, sock_file)