Skip to content
Snippets Groups Projects
Commit 1e6648d6 authored by Shixiong Zhu's avatar Shixiong Zhu
Browse files

[SPARK-12617][PYSPARK] Move Py4jCallbackConnectionCleaner to Streaming

Move Py4jCallbackConnectionCleaner to Streaming because the callback server starts only in StreamingContext.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10621 from zsxwing/SPARK-12617-2.
parent f82ebb15
No related branches found
No related tags found
No related merge requests found
......@@ -54,64 +54,6 @@ DEFAULT_CONFIGS = {
}
class Py4jCallbackConnectionCleaner(object):
"""
A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
It will scan all callback connections every 30 seconds and close the dead connections.
"""
def __init__(self, gateway):
self._gateway = gateway
self._stopped = False
self._timer = None
self._lock = RLock()
def start(self):
if self._stopped:
return
def clean_closed_connections():
from py4j.java_gateway import quiet_close, quiet_shutdown
callback_server = self._gateway._callback_server
with callback_server.lock:
try:
closed_connections = []
for connection in callback_server.connections:
if not connection.isAlive():
quiet_close(connection.input)
quiet_shutdown(connection.socket)
quiet_close(connection.socket)
closed_connections.append(connection)
for closed_connection in closed_connections:
callback_server.connections.remove(closed_connection)
except Exception:
import traceback
traceback.print_exc()
self._start_timer(clean_closed_connections)
self._start_timer(clean_closed_connections)
def _start_timer(self, f):
from threading import Timer
with self._lock:
if not self._stopped:
self._timer = Timer(30.0, f)
self._timer.daemon = True
self._timer.start()
def stop(self):
with self._lock:
self._stopped = True
if self._timer:
self._timer.cancel()
self._timer = None
class SparkContext(object):
"""
......@@ -126,7 +68,6 @@ class SparkContext(object):
_active_spark_context = None
_lock = RLock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
_py4j_cleaner = None
PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
......@@ -303,8 +244,6 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
_py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway)
_py4j_cleaner.start()
if instance:
if (SparkContext._active_spark_context and
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import os
import sys
from threading import RLock, Timer
from py4j.java_gateway import java_import, JavaObject
......@@ -32,6 +33,63 @@ from pyspark.streaming.util import TransformFunction, TransformFunctionSerialize
__all__ = ["StreamingContext"]
class Py4jCallbackConnectionCleaner(object):
"""
A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
It will scan all callback connections every 30 seconds and close the dead connections.
"""
def __init__(self, gateway):
self._gateway = gateway
self._stopped = False
self._timer = None
self._lock = RLock()
def start(self):
if self._stopped:
return
def clean_closed_connections():
from py4j.java_gateway import quiet_close, quiet_shutdown
callback_server = self._gateway._callback_server
if callback_server:
with callback_server.lock:
try:
closed_connections = []
for connection in callback_server.connections:
if not connection.isAlive():
quiet_close(connection.input)
quiet_shutdown(connection.socket)
quiet_close(connection.socket)
closed_connections.append(connection)
for closed_connection in closed_connections:
callback_server.connections.remove(closed_connection)
except Exception:
import traceback
traceback.print_exc()
self._start_timer(clean_closed_connections)
self._start_timer(clean_closed_connections)
def _start_timer(self, f):
with self._lock:
if not self._stopped:
self._timer = Timer(30.0, f)
self._timer.daemon = True
self._timer.start()
def stop(self):
with self._lock:
self._stopped = True
if self._timer:
self._timer.cancel()
self._timer = None
class StreamingContext(object):
"""
Main entry point for Spark Streaming functionality. A StreamingContext
......@@ -47,6 +105,9 @@ class StreamingContext(object):
# Reference to a currently active StreamingContext
_activeContext = None
# A cleaner to clean leak sockets of callback server every 30 seconds
_py4j_cleaner = None
def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
......@@ -95,6 +156,8 @@ class StreamingContext(object):
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
_py4j_cleaner = Py4jCallbackConnectionCleaner(gw)
_py4j_cleaner.start()
# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment