diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 3deed52be0be2e2d66cc0285b814c32ae1a48421..5cc4bbde3995869d15f06bf686fa981a3aa6297d 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -98,8 +98,28 @@ class StreamingContext(object):
 
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
-        cls._transformerSerializer = TransformFunctionSerializer(
-            SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+        if cls._transformerSerializer is None:
+            transformer_serializer = TransformFunctionSerializer()
+            transformer_serializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+            # SPARK-12511: a streaming driver with checkpointing is unable to finalize,
+            # leading to OOM. The root cause is that Py4J's PythonProxyHandler.finalize
+            # blocks forever (https://github.com/bartdag/py4j/pull/184).
+            #
+            # Py4J creates a PythonProxyHandler in Java for "transformer_serializer" when
+            # "registerSerializer" is called. If we call "registerSerializer" twice, the
+            # second PythonProxyHandler overrides the first one; the first one is then GCed
+            # and triggers "PythonProxyHandler.finalize". To avoid that, we should not call
+            # "registerSerializer" more than once, so that the "PythonProxyHandler" on the
+            # Java side won't be GCed.
+            #
+            # TODO Once Py4J fixes this issue, we should upgrade Py4J to the latest version.
+            transformer_serializer.gateway.jvm.PythonDStream.registerSerializer(
+                transformer_serializer)
+            cls._transformerSerializer = transformer_serializer
+        else:
+            cls._transformerSerializer.init(
+                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
 
     @classmethod
     def getOrCreate(cls, checkpointPath, setupFunc):
@@ -116,16 +136,13 @@ class StreamingContext(object):
         gw = SparkContext._gateway
 
         # Check whether valid checkpoint information exists in the given path
-        if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty():
+        ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
+        if ssc_option.isEmpty():
             ssc = setupFunc()
             ssc.checkpoint(checkpointPath)
             return ssc
 
-        try:
-            jssc = gw.jvm.JavaStreamingContext(checkpointPath)
-        except Exception:
-            print("failed to load StreamingContext from checkpoint", file=sys.stderr)
-            raise
+        jssc = gw.jvm.JavaStreamingContext(ssc_option.get())
 
         # If there is already an active instance of Python SparkContext use it, or create a new one
         if not SparkContext._active_spark_context:
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index abbbf6eb9394fc263b9c001f06eb7206fc4ff05c..e617fc9ce9eec1c189a7d8e25afd827330868a8e 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -89,11 +89,10 @@ class TransformFunctionSerializer(object):
     it uses this class to invoke Python, which returns the serialized function
     as a byte array.
     """
-    def __init__(self, ctx, serializer, gateway=None):
+    def init(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
         self.gateway = gateway or self.ctx._gateway
-        self.gateway.jvm.PythonDStream.registerSerializer(self)
         self.failure = None
 
     def dumps(self, id):
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index c4a10aa2dd3b96228db9381f15888a8671e1d0f6..a5ab66697589b1f9bc965a42919b551a8ee378e8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -902,3 +902,15 @@ object StreamingContext extends Logging {
     result
   }
 }
+
+private class StreamingContextPythonHelper {
+
+  /**
+   * This is a private method used only by Python to implement `getOrCreate`.
+   */
+  def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = {
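+    // The "false" argument tells CheckpointReader not to ignore read errors, so a corrupt
+    // checkpoint surfaces as an exception rather than an empty Option. The null arguments
+    // to the StreamingContext constructor are the SparkContext and batch duration, both of
+    // which are reconstructed from the recovered checkpoint.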
+    val checkpointOption = CheckpointReader.read(
+      checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false)
+    checkpointOption.map(new StreamingContext(null, _, null))
+  }
+}
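
For reference, a minimal PySpark driver sketch of the code path this patch touches: on the first
run, StreamingContext.getOrCreate calls the setup function and enables checkpointing; on a driver
restart it now recovers the context via StreamingContextPythonHelper.tryRecoverFromCheckpoint
instead of constructing JavaStreamingContext from the path directly, and the serializer is
re-initialized without a second registerSerializer call. The host, port, checkpoint directory,
and app name below are illustrative only.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    CHECKPOINT_DIR = "/tmp/streaming-checkpoint"  # illustrative path

    def create_context():
        # Runs only when no valid checkpoint exists in CHECKPOINT_DIR.
        sc = SparkContext(appName="CheckpointedWordCount")
        ssc = StreamingContext(sc, 1)  # 1-second batch interval
        lines = ssc.socketTextStream("localhost", 9999)
        counts = lines.flatMap(lambda line: line.split(" ")) \
                      .map(lambda word: (word, 1)) \
                      .reduceByKey(lambda a, b: a + b)
        counts.pprint()
        ssc.checkpoint(CHECKPOINT_DIR)
        return ssc

    # On restart, the context is rebuilt from the checkpoint data instead of create_context().
    ssc = StreamingContext.getOrCreate(CHECKPOINT_DIR, create_context)
    ssc.start()
    ssc.awaitTermination()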