diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 0e33362d34acdaf23cf0c4685f3b860ea46e3e55..f3b01bd60b178e5c962cb1c2db0a8a189983dbc7 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper {
     TopicAndPartition(topic, partition)
 
   def createBroker(host: String, port: JInt): Broker = Broker(host, port)
+
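+  // Return the offset ranges of the single KafkaRDD found among this RDD's narrow ancestors.
+  // Fails unless exactly one KafkaRDD is found (e.g. zero after a shuffle, several after a union).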
+  def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = {
+    val parentRDDs = rdd.getNarrowAncestors
+    val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]])
+
+    require(
+      kafkaRDDs.length == 1,
+      "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" +
+        "with this RDD, please call this method only on a Kafka RDD.")
+
+    val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]]
+    kafkaRDD.offsetRanges.toSeq
+  }
 }
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 10a859a532e280c41a482457556ab58a5dbbaa42..33dd596335b47cdad8179199cd964ff501a76624 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -21,6 +21,8 @@ from pyspark.rdd import RDD
 from pyspark.storagelevel import StorageLevel
 from pyspark.serializers import PairDeserializer, NoOpSerializer
 from pyspark.streaming import DStream
+from pyspark.streaming.dstream import TransformedDStream
+from pyspark.streaming.util import TransformFunction
 
 __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
 
@@ -122,8 +124,9 @@ class KafkaUtils(object):
             raise e
 
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        stream = DStream(jstream, ssc, ser)
-        return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        stream = DStream(jstream, ssc, ser) \
+            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
 
     @staticmethod
     def createRDD(sc, kafkaParams, offsetRanges, leaders={},
@@ -161,8 +164,8 @@ class KafkaUtils(object):
             raise e
 
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        rdd = RDD(jrdd, sc, ser)
-        return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -200,14 +203,30 @@ class OffsetRange(object):
         :param fromOffset: Inclusive starting offset.
         :param untilOffset: Exclusive ending offset.
         """
-        self._topic = topic
-        self._partition = partition
-        self._fromOffset = fromOffset
-        self._untilOffset = untilOffset
+        self.topic = topic
+        self.partition = partition
+        self.fromOffset = fromOffset
+        self.untilOffset = untilOffset
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return (self.topic == other.topic
+                    and self.partition == other.partition
+                    and self.fromOffset == other.fromOffset
+                    and self.untilOffset == other.untilOffset)
+        else:
+            return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __str__(self):
+        return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
+               % (self.topic, self.partition, self.fromOffset, self.untilOffset)
 
     def _jOffsetRange(self, helper):
-        return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
-                                        self._untilOffset)
+        return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
+                                        self.untilOffset)
 
 
 class TopicAndPartition(object):
@@ -244,3 +263,87 @@ class Broker(object):
 
     def _jBroker(self, helper):
         return helper.createBroker(self._host, self._port)
+
+
+class KafkaRDD(RDD):
+    """
+    A Python wrapper of KafkaRDD that adds Kafka offset range information to a normal RDD.
+    """
+
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
+        RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
+
+    def offsetRanges(self):
+        """
+        Get the offset ranges of this KafkaRDD.
+        :return: A list of OffsetRange
+        """
+        try:
+            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
+            helper = helperClass.newInstance()
+            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
+        except Py4JJavaError as e:
+            if 'ClassNotFoundException' in str(e.java_exception):
+                KafkaUtils._printErrorMsg(self.ctx)
+            raise e
+
+        ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
+                  for o in joffsetRanges]
+        return ranges
+
+
+class KafkaDStream(DStream):
+    """
+    A Python wrapper of KafkaDStream whose RDDs expose Kafka offset ranges.
+    """
+
+    def __init__(self, jdstream, ssc, jrdd_deserializer):
+        DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
+
+    def foreachRDD(self, func):
+        """
+        Apply a function to each RDD in this DStream.
+        """
+        if func.__code__.co_argcount == 1:
+            old_func = func
+            func = lambda r, rdd: old_func(rdd)
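+        # Wrap the Java RDDs handed to `func` as KafkaRDDs so that offsetRanges() is available.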
+        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
+            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+        api = self._ssc._jvm.PythonDStream
+        api.callForeachRDD(self._jdstream, jfunc)
+
+    def transform(self, func):
+        """
+        Return a new DStream in which each RDD is generated by applying a function
+        on each RDD of this DStream.
+
+        `func` can take a single `rdd` argument, or two arguments of
+        (`time`, `rdd`).
+        """
+        if func.__code__.co_argcount == 1:
+            oldfunc = func
+            func = lambda t, rdd: oldfunc(rdd)
+        assert func.__code__.co_argcount == 2, "func should take one or two arguments"
+
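+        # Return a Kafka-aware TransformedDStream so that `func` is applied to KafkaRDDs
+        # instead of plain RDDs.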
+        return KafkaTransformedDStream(self, func)
+
+
+class KafkaTransformedDStream(TransformedDStream):
+    """
+    Kafka-specific wrapper of TransformedDStream that transforms KafkaRDDs.
+    """
+
+    def __init__(self, prev, func):
+        TransformedDStream.__init__(self, prev, func)
+
+    @property
+    def _jdstream(self):
+        if self._jdstream_val is not None:
+            return self._jdstream_val
+
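+        # Same as TransformedDStream._jdstream, except that the TransformFunction wraps the
+        # Java RDDs as KafkaRDDs before calling the user function.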
+        jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
+            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+        dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+        self._jdstream_val = dstream.asJavaDStream()
+        return self._jdstream_val
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 188c8ff12067e0d6f8df427076e2cff9296af01f..4ecae1e4bf28276baf801e2d33ab8a6b11c565f8 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -678,6 +678,70 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
         self._validateRddResult(sendData, rdd)
 
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_rdd_get_offsetRanges(self):
+        """Test Python direct Kafka RDD get OffsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 3, "b": 4, "c": 5}
+        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
+        self.assertEqual(offsetRanges, rdd.offsetRanges())
+
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_foreach_get_offsetRanges(self):
+        """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+        offsetRanges = []
+
+        def getOffsetRanges(_, rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+
+        stream.foreachRDD(getOffsetRanges)
+        self.ssc.start()
+        self.wait_for(offsetRanges, 1)
+
+        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_transform_get_offsetRanges(self):
+        """Test the Python direct Kafka stream transform get offsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count())
+        self.ssc.start()
+        self.wait_for(offsetRanges, 1)
+
+        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
 
 class FlumeStreamTests(PySparkStreamingTestCase):
     timeout = 20  # seconds
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index a9bfec2aab8fc1bee7fcd3af1bae5be7d8c1a123..b20613b1283bd3f17816392486e6fbc6ac00213e 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,6 +37,11 @@ class TransformFunction(object):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
+        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+
+    def rdd_wrapper(self, func):
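+        """Set the function used to wrap Java RDDs passed to call() into Python RDD objects."""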
+        self._rdd_wrapper = func
+        return self
 
     def call(self, milliseconds, jrdds):
         try:
@@ -51,7 +56,7 @@ class TransformFunction(object):
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)