diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a647e6bf395811fe4a2e9e11198119748cd41c76..d50c6b8d4a4285a276d5b09d193dab00f98c40a1 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1149,6 +1149,59 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
 
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_transform_with_checkpoint(self):
+        """Test the Python direct Kafka stream transform with checkpoint correctly recovered."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
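+        # Stop the StreamingContext created by the test fixture; this test manages
+        # its own checkpointed contexts below.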
+        self.ssc.stop(False)
+        self.ssc = None
+        tmpdir = "checkpoint-test-%d" % random.randint(0, 10000)
+
+        def setup():
+            ssc = StreamingContext(self.sc, 0.5)
+            ssc.checkpoint(tmpdir)
+            stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams)
+            stream.transform(transformWithOffsetRanges).count().pprint()
+            return ssc
+
+        try:
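+            # First run: no checkpoint exists in tmpdir yet, so getOrCreate() falls
+            # back to setup() and builds a fresh context.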
+            ssc1 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc1.start()
+            self.wait_for(offsetRanges, 1)
+            self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+            # Sleep briefly to make sure some checkpoint data has been written
+            time.sleep(3)
+            ssc1.stop(False)
+            ssc1 = None
+
+            # Restart again to make sure the checkpoint is recovered correctly
+            ssc2 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc2.start()
+            ssc2.awaitTermination(3)
+            ssc2.stop(stopSparkContext=False, stopGraceFully=True)
+            ssc2 = None
+        finally:
+            shutil.rmtree(tmpdir)
+
     @unittest.skipIf(sys.version >= "3", "long type not support")
     def test_kafka_rdd_message_handler(self):
         """Test Python direct Kafka RDD MessageHandler."""
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index c7f02bca2ae389882c91dccda8394c8b13fe6f8c..abbbf6eb9394fc263b9c001f06eb7206fc4ff05c 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,11 +37,11 @@ class TransformFunction(object):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
         self.failure = None
 
     def rdd_wrapper(self, func):
-        self._rdd_wrapper = func
+        self.rdd_wrap_func = func
         return self
 
     def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ class TransformFunction(object):
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+            rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
@@ -101,7 +101,10 @@
         self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
-            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
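+            # Also serialize rdd_wrap_func so DStreams that wrap their RDDs (e.g. the
+            # Kafka direct stream) are rebuilt correctly after checkpoint recovery.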
+            return bytearray(self.serializer.dumps((
+                func.func, func.rdd_wrap_func, func.deserializers)))
         except:
             self.failure = traceback.format_exc()
 
@@ -109,8 +112,10 @@
         # Clear the failure
         self.failure = None
         try:
-            f, deserializers = self.serializer.loads(bytes(data))
-            return TransformFunction(self.ctx, f, *deserializers)
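+            # Restore the RDD wrapper as well, re-attaching it via rdd_wrapper() so
+            # the recovered TransformFunction produces the same RDD type.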
+            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
         except:
             self.failure = traceback.format_exc()