diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 31282220775375b9e6ec417954d912d2f094bf2a..ad2fb8aa5f24c30164066844f94129ff8b128bcf 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -17,25 +17,29 @@
 
 package org.apache.spark.streaming.kafka
 
+import java.io.OutputStream
 import java.lang.{Integer => JInt, Long => JLong}
 import java.util.{List => JList, Map => JMap, Set => JSet}
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
+import com.google.common.base.Charsets.UTF_8
 import kafka.common.TopicAndPartition
 import kafka.message.MessageAndMetadata
-import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
+import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
+import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}
 
 import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.streaming.util.WriteAheadLogUtils
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
-import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.streaming.api.java._
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}
 
 object KafkaUtils {
   /**
@@ -184,6 +188,27 @@ object KafkaUtils {
     }
   }
 
+  private[kafka] def getFromOffsets(
+      kc: KafkaCluster,
+      kafkaParams: Map[String, String],
+      topics: Set[String]
+    ): Map[TopicAndPartition, Long] = {
+    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
+    val result = for {
+      topicPartitions <- kc.getPartitions(topics).right
+      leaderOffsets <- (if (reset == Some("smallest")) {
+        kc.getEarliestLeaderOffsets(topicPartitions)
+      } else {
+        kc.getLatestLeaderOffsets(topicPartitions)
+      }).right
+    } yield {
+      leaderOffsets.map { case (tp, lo) =>
+          (tp, lo.offset)
+      }
+    }
+    KafkaCluster.checkErrors(result)
+  }
+
   /**
    * Create a RDD from Kafka using offset ranges for each topic and partition.
    *
@@ -246,7 +271,7 @@ object KafkaUtils {
       // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
       leaders.map {
         case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
-      }.toMap
+      }
     }
     val cleanedHandler = sc.clean(messageHandler)
     checkOffsets(kc, offsetRanges)
@@ -406,23 +431,9 @@ object KafkaUtils {
   ): InputDStream[(K, V)] = {
     val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
     val kc = new KafkaCluster(kafkaParams)
-    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
-
-    val result = for {
-      topicPartitions <- kc.getPartitions(topics).right
-      leaderOffsets <- (if (reset == Some("smallest")) {
-        kc.getEarliestLeaderOffsets(topicPartitions)
-      } else {
-        kc.getLatestLeaderOffsets(topicPartitions)
-      }).right
-    } yield {
-      val fromOffsets = leaderOffsets.map { case (tp, lo) =>
-          (tp, lo.offset)
-      }
-      new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
-        ssc, kafkaParams, fromOffsets, messageHandler)
-    }
-    KafkaCluster.checkErrors(result)
+    val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
+    new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
+      ssc, kafkaParams, fromOffsets, messageHandler)
   }
 
   /**
@@ -550,6 +561,8 @@ object KafkaUtils {
  * takes care of known parameters instead of passing them from Python
  */
 private[kafka] class KafkaUtilsPythonHelper {
+  import KafkaUtilsPythonHelper._
+
   def createStream(
       jssc: JavaStreamingContext,
       kafkaParams: JMap[String, String],
@@ -566,86 +579,92 @@ private[kafka] class KafkaUtilsPythonHelper {
       storageLevel)
   }
 
-  def createRDD(
+  def createRDDWithoutMessageHandler(
       jsc: JavaSparkContext,
       kafkaParams: JMap[String, String],
       offsetRanges: JList[OffsetRange],
-      leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
-    val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
-      (Array[Byte], Array[Byte])] {
-      def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
-        (t1.key(), t1.message())
-    }
+      leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler))
+  }
 
-    val jrdd = KafkaUtils.createRDD[
-      Array[Byte],
-      Array[Byte],
-      DefaultDecoder,
-      DefaultDecoder,
-      (Array[Byte], Array[Byte])](
-        jsc,
-        classOf[Array[Byte]],
-        classOf[Array[Byte]],
-        classOf[DefaultDecoder],
-        classOf[DefaultDecoder],
-        classOf[(Array[Byte], Array[Byte])],
-        kafkaParams,
-        offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
-        leaders,
-        messageHandler
-      )
-    new JavaPairRDD(jrdd.rdd)
+  def createRDDWithMessageHandler(
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = {
+    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+      new PythonMessageAndMetadata(
+        mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler).
+        mapPartitions(picklerIterator)
+    new JavaRDD(rdd)
   }
 
-  def createDirectStream(
+  private def createRDD[V: ClassTag](
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = {
+    KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jsc.sc,
+      kafkaParams.asScala.toMap,
+      offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
+      leaders.asScala.toMap,
+      messageHandler
+    )
+  }
+
+  def createDirectStreamWithoutMessageHandler(
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
+  }
+
+  def createDirectStreamWithMessageHandler(
       jssc: JavaStreamingContext,
       kafkaParams: JMap[String, String],
       topics: JSet[String],
-      fromOffsets: JMap[TopicAndPartition, JLong]
-    ): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
+      fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = {
+    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+      new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
+      mapPartitions(picklerIterator)
+    new JavaDStream(stream)
+  }
 
-    if (!fromOffsets.isEmpty) {
+  private def createDirectStream[V: ClassTag](
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = {
+
+    val currentFromOffsets = if (!fromOffsets.isEmpty) {
       val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
       if (topicsFromOffsets != topics.asScala.toSet) {
         throw new IllegalStateException(
           s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
           s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
       }
-    }
-
-    if (fromOffsets.isEmpty) {
-      KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
-        jssc,
-        classOf[Array[Byte]],
-        classOf[Array[Byte]],
-        classOf[DefaultDecoder],
-        classOf[DefaultDecoder],
-        kafkaParams,
-        topics)
+      Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
     } else {
-      val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
-        (Array[Byte], Array[Byte])] {
-        def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
-          (t1.key(), t1.message())
-      }
-
-      val jstream = KafkaUtils.createDirectStream[
-        Array[Byte],
-        Array[Byte],
-        DefaultDecoder,
-        DefaultDecoder,
-        (Array[Byte], Array[Byte])](
-          jssc,
-          classOf[Array[Byte]],
-          classOf[Array[Byte]],
-          classOf[DefaultDecoder],
-          classOf[DefaultDecoder],
-          classOf[(Array[Byte], Array[Byte])],
-          kafkaParams,
-          fromOffsets,
-          messageHandler)
-      new JavaPairInputDStream(jstream.inputDStream)
+      val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
+      KafkaUtils.getFromOffsets(
+        kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
     }
+
+    KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jssc.ssc,
+      Map(kafkaParams.asScala.toSeq: _*),
+      Map(currentFromOffsets.toSeq: _*),
+      messageHandler)
   }
 
   def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
@@ -669,3 +688,57 @@ private[kafka] class KafkaUtilsPythonHelper {
     kafkaRDD.offsetRanges.toSeq.asJava
   }
 }
+
+private object KafkaUtilsPythonHelper {
+  private var initialized = false
+
+  def initialize(): Unit = {
+    SerDeUtil.initialize()
+    synchronized {
+      if (!initialized) {
+        new PythonMessageAndMetadataPickler().register()
+        initialized = true
+      }
+    }
+  }
+
+  initialize()
+
+  def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
+    new SerDeUtil.AutoBatchedPickler(iter)
+  }
+
+  case class PythonMessageAndMetadata(
+      topic: String,
+      partition: JInt,
+      offset: JLong,
+      key: Array[Byte],
+      message: Array[Byte])
+
+  class PythonMessageAndMetadataPickler extends IObjectPickler {
+    private val module = "pyspark.streaming.kafka"
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
+      Pickler.registerCustomPickler(this.getClass, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
+      if (obj == this) {
+        out.write(Opcodes.GLOBAL)
+        out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8))
+      } else {
+        pickler.save(this)
+        val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
+        out.write(Opcodes.MARK)
+        pickler.save(msgAndMetaData.topic)
+        pickler.save(msgAndMetaData.partition)
+        pickler.save(msgAndMetaData.offset)
+        pickler.save(msgAndMetaData.key)
+        pickler.save(msgAndMetaData.message)
+        out.write(Opcodes.TUPLE)
+        out.write(Opcodes.REDUCE)
+      }
+    }
+  }
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8b3bc96801e20f52488803f50437ba6e444057f0..eb70d27c34c20453e5b0739c158221a972ad2df8 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -136,6 +136,12 @@ object MimaExcludes {
         // SPARK-11766 add toJson to Vector
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.mllib.linalg.Vector.toJson")
+      ) ++ Seq(
+        // SPARK-9065 Support message handler in Kafka Python API
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD")
       )
     case v if v.startsWith("1.5") =>
       Seq(
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 06e159172ab51ad3b2c0fbaf9aa3d00e4453f93c..cdf97ec73aaf9f2b69c16ad4c5cdf7c9728cca23 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -19,12 +19,14 @@ from py4j.protocol import Py4JJavaError
 
 from pyspark.rdd import RDD
 from pyspark.storagelevel import StorageLevel
-from pyspark.serializers import PairDeserializer, NoOpSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \
+    NoOpSerializer
 from pyspark.streaming import DStream
 from pyspark.streaming.dstream import TransformedDStream
 from pyspark.streaming.util import TransformFunction
 
-__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
+__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange',
+           'TopicAndPartition', 'utf8_decoder']
 
 
 def utf8_decoder(s):
@@ -82,7 +84,8 @@ class KafkaUtils(object):
 
     @staticmethod
     def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
-                           keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+                           keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+                           messageHandler=None):
         """
         .. note:: Experimental
 
@@ -107,6 +110,8 @@ class KafkaUtils(object):
                             point of the stream.
         :param keyDecoder:  A function used to decode key (default is utf8_decoder).
         :param valueDecoder:  A function used to decode value (default is utf8_decoder).
+        :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess
+                               meta using messageHandler (default is None).
         :return: A DStream object
         """
         if fromOffsets is None:
@@ -116,6 +121,14 @@ class KafkaUtils(object):
         if not isinstance(kafkaParams, dict):
             raise TypeError("kafkaParams should be dict")
 
+        def funcWithoutMessageHandler(k_v):
+            return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+        def funcWithMessageHandler(m):
+            m._set_key_decoder(keyDecoder)
+            m._set_value_decoder(valueDecoder)
+            return messageHandler(m)
+
         try:
             helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -123,20 +136,28 @@ class KafkaUtils(object):
 
             jfromOffsets = dict([(k._jTopicAndPartition(helper),
                                   v) for (k, v) in fromOffsets.items()])
-            jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+            if messageHandler is None:
+                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+                func = funcWithoutMessageHandler
+                jstream = helper.createDirectStreamWithoutMessageHandler(
+                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+            else:
+                ser = AutoBatchedSerializer(PickleSerializer())
+                func = funcWithMessageHandler
+                jstream = helper.createDirectStreamWithMessageHandler(
+                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(ssc.sparkContext)
             raise e
 
-        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        stream = DStream(jstream, ssc, ser) \
-            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        stream = DStream(jstream, ssc, ser).map(func)
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
 
     @staticmethod
     def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
-                  keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+                  keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
+                  messageHandler=None):
         """
         .. note:: Experimental
 
@@ -149,6 +170,8 @@ class KafkaUtils(object):
             map, in which case leaders will be looked up on the driver.
         :param keyDecoder:  A function used to decode key (default is utf8_decoder)
         :param valueDecoder:  A function used to decode value (default is utf8_decoder)
+        :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess
+                               meta using messageHandler (default is None).
         :return: A RDD object
         """
         if leaders is None:
@@ -158,6 +181,14 @@ class KafkaUtils(object):
         if not isinstance(offsetRanges, list):
             raise TypeError("offsetRanges should be list")
 
+        def funcWithoutMessageHandler(k_v):
+            return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+        def funcWithMessageHandler(m):
+            m._set_key_decoder(keyDecoder)
+            m._set_value_decoder(valueDecoder)
+            return messageHandler(m)
+
         try:
             helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
@@ -165,15 +196,21 @@ class KafkaUtils(object):
             joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
             jleaders = dict([(k._jTopicAndPartition(helper),
                               v._jBroker(helper)) for (k, v) in leaders.items()])
-            jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            if messageHandler is None:
+                jrdd = helper.createRDDWithoutMessageHandler(
+                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
+                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+            else:
+                jrdd = helper.createRDDWithMessageHandler(
+                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
+                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(sc)
             raise e
 
-        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
-        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
+        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -365,3 +402,53 @@ class KafkaTransformedDStream(TransformedDStream):
         dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
         self._jdstream_val = dstream.asJavaDStream()
         return self._jdstream_val
+
+
+class KafkaMessageAndMetadata(object):
+    """
+    Kafka message and metadata information. Including topic, partition, offset and message
+    """
+
+    def __init__(self, topic, partition, offset, key, message):
+        """
+        Python wrapper of Kafka MessageAndMetadata
+        :param topic: topic name of this Kafka message
+        :param partition: partition id of this Kafka message
+        :param offset: Offset of this Kafka message in the specific partition
+        :param key: key payload of this Kafka message, can be null if this Kafka message has no key
+                    specified, the return data is undecoded bytearry.
+        :param message: actual message payload of this Kafka message, the return data is
+                        undecoded bytearray.
+        """
+        self.topic = topic
+        self.partition = partition
+        self.offset = offset
+        self._rawKey = key
+        self._rawMessage = message
+        self._keyDecoder = utf8_decoder
+        self._valueDecoder = utf8_decoder
+
+    def __str__(self):
+        return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \
+               % (self.topic, self.partition, self.offset)
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __reduce__(self):
+        return (KafkaMessageAndMetadata,
+                (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage))
+
+    def _set_key_decoder(self, decoder):
+        self._keyDecoder = decoder
+
+    def _set_value_decoder(self, decoder):
+        self._valueDecoder = decoder
+
+    @property
+    def key(self):
+        return self._keyDecoder(self._rawKey)
+
+    @property
+    def message(self):
+        return self._valueDecoder(self._rawMessage)
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index ff95639146e59ee817be8fb3b7ef7529af761038..0bcd1f15532b56e6e8093b36c5a2058fff0fb8b0 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1042,6 +1042,41 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
 
+    @unittest.skipIf(sys.version >= "3", "long type not support")
+    def test_kafka_rdd_message_handler(self):
+        """Test Python direct Kafka RDD MessageHandler."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 1, "c": 2}
+        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+        def getKeyAndDoubleMessage(m):
+            return m and (m.key, m.message * 2)
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges,
+                                   messageHandler=getKeyAndDoubleMessage)
+        self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd)
+
+    @unittest.skipIf(sys.version >= "3", "long type not support")
+    def test_kafka_direct_stream_message_handler(self):
+        """Test the Python direct Kafka stream MessageHandler."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        def getKeyAndDoubleMessage(m):
+            return m and (m.key, m.message * 2)
+
+        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams,
+                                               messageHandler=getKeyAndDoubleMessage)
+        self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream)
+
 
 class FlumeStreamTests(PySparkStreamingTestCase):
     timeout = 20  # seconds