diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index 9b135a5c54cf376bccfb9f334ccbd070ddb34924..e0c3555f21b2990e0cc7228b5253bb9bfae01252 100644
--- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -37,7 +37,7 @@ object KafkaWordCount {
     ssc.checkpoint("checkpoint")
 
     val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
-    val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
+    val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
     val words = lines.flatMap(_.split(" "))
     val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
     wordCounts.print()
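For context, a minimal sketch of how the updated example drives the new API: the `StringDecoder`-backed convenience overload now returns `DStream[String]` directly, so the explicit `[String]` type argument is gone. The placeholder values below stand in for the command-line arguments the full example parses; the snippet is illustrative, not the complete file.

```scala
import spark.streaming.{Seconds, StreamingContext}
import spark.streaming.StreamingContext._

// Placeholders for the command-line arguments KafkaWordCount normally parses.
val (zkQuorum, group, topics, numThreads) = ("localhost:2181", "my-group", "my-topic", "1")

val ssc = new StreamingContext("local[2]", "KafkaWordCount", Seconds(2))
ssc.checkpoint("checkpoint")

val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap
// No type parameter needed: this overload is fixed to String via StringDecoder.
val lines = ssc.kafkaStream(zkQuorum, group, topicMap)
val words = lines.flatMap(_.split(" "))
words.map(x => (x, 1L)).reduceByKey(_ + _).print()

ssc.start()
```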
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index bb7f216ca7dcd9b2c5e5e43a12b5850c3fadcd08..2c6326943dc99495ee98b62537de0377225037aa 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path
 import java.util.UUID
 import twitter4j.Status
 
+
 /**
  * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
  * information (such as, cluster URL and job name) to internally create a SparkContext, it provides
@@ -207,14 +208,14 @@ class StreamingContext private (
    * @param storageLevel  Storage level to use for storing the received objects
    *                      (default: StorageLevel.MEMORY_AND_DISK_SER_2)
    */
-  def kafkaStream[T: ClassManifest](
+  def kafkaStream(
       zkQuorum: String,
       groupId: String,
       topics: Map[String, Int],
       storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
-    ): DStream[T] = {
+    ): DStream[String] = {
     val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000");
-    kafkaStream[T](kafkaParams, topics, storageLevel)
+    kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
   }
 
   /**
@@ -224,12 +225,12 @@ class StreamingContext private (
    * in its own thread.
    * @param storageLevel  Storage level to use for storing the received objects
    */
-  def kafkaStream[T: ClassManifest](
+  def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
       kafkaParams: Map[String, String],
       topics: Map[String, Int],
       storageLevel: StorageLevel
     ): DStream[T] = {
-    val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel)
+    val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
     registerInputStream(inputStream)
     inputStream
   }
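A hedged sketch of the two call shapes this file now exposes: the convenience overload, hard-wired to `String` payloads through `kafka.serializer.StringDecoder`, and the general overload, where the caller supplies both the element type and the decoder type. Host names and group ids below are placeholders.

```scala
import spark.storage.StorageLevel
import spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext("local[2]", "KafkaStreamSketch", Seconds(2))
val topics = Map("my-topic" -> 1)

// 1) Convenience form: always DStream[String]; ZooKeeper settings are built for you.
val strings = ssc.kafkaStream("localhost:2181", "my-group", topics)

// 2) General form: raw Kafka parameters plus an explicit decoder type.
val kafkaParams = Map(
  "zk.connect" -> "localhost:2181",
  "groupid" -> "my-group",
  "zk.connectiontimeout.ms" -> "10000")
val decoded = ssc.kafkaStream[String, kafka.serializer.StringDecoder](
  kafkaParams, topics, StorageLevel.MEMORY_ONLY_SER_2)
```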
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
index 7a8864614c95f8a39b68cc9709ed402a58bcdb44..b35d9032f1d63bc5749de407fc32412c47d7f065 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -68,33 +68,54 @@ class JavaStreamingContext(val ssc: StreamingContext) {
    * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
    * in its own thread.
    */
-  def kafkaStream[T](
+  def kafkaStream(
     zkQuorum: String,
     groupId: String,
     topics: JMap[String, JInt])
-  : JavaDStream[T] = {
-    implicit val cmt: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
+  : JavaDStream[String] = {
+    implicit val cmt: ClassManifest[String] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+    ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2)
   }
 
   /**
    * Create an input stream that pulls messages form a Kafka Broker.
-   * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html
    * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
    * @param groupId The group id for this consumer.
    * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+   * in its own thread.
+   * @param storageLevel RDD storage level.
+   */
+  def kafkaStream(
+    zkQuorum: String,
+    groupId: String,
+    topics: JMap[String, JInt],
+    storageLevel: StorageLevel)
+  : JavaDStream[String] = {
+    implicit val cmt: ClassManifest[String] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+    ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel)
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kafka Broker.
+   * @param typeClass Type of the RDD elements produced by this stream
+   * @param decoderClass Type of Kafka decoder used to deserialize messages
+   * @param kafkaParams Map of Kafka configuration parameters. See: http://kafka.apache.org/configuration.html
+   * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
    * in its own thread.
    * @param storageLevel RDD storage level. Defaults to memory-only
    */
-  def kafkaStream[T](
-	kafkaParams: JMap[String, String],
+  def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
+    typeClass: Class[T],
+    decoderClass: Class[D],
+    kafkaParams: JMap[String, String],
     topics: JMap[String, JInt],
     storageLevel: StorageLevel)
   : JavaDStream[T] = {
-    implicit val cmt: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    ssc.kafkaStream[T](
+    implicit val cmt: ClassManifest[T] = ClassManifest.fromClass(typeClass)
+    implicit val cmd: Manifest[D] = Manifest.classType[D](decoderClass)
+    ssc.kafkaStream[T, D](
       kafkaParams.toMap,
       Map(topics.mapValues(_.intValue()).toSeq: _*),
       storageLevel)
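Java callers cannot supply Scala implicits, so the `Class`-based overload has to build the `ClassManifest`/`Manifest` pair itself before delegating to the Scala `StreamingContext`. A standalone sketch of that bridging pattern (the helper name is ours, not part of the patch):

```scala
import kafka.serializer.{Decoder, StringDecoder}

// Build Scala manifests from plain java.lang.Class values, which is what a
// Java-facing wrapper must do because Java code cannot provide implicits.
def manifestsFor[T, D <: Decoder[_]](typeClass: Class[T], decoderClass: Class[D]):
    (ClassManifest[T], Manifest[D]) =
  (ClassManifest.fromClass(typeClass), Manifest.classType[D](decoderClass))

// The manifests a Java call such as
//   jssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, level)
// needs in order to reach ssc.kafkaStream[String, StringDecoder](...).
val (cmt, cmd) = manifestsFor(classOf[String], classOf[StringDecoder])
println(cmd.erasure.getName) // kafka.serializer.StringDecoder
```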
diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
index 17a5be3420d41a051e6aef9c70c802ad9f31c373..7bd53fb6ddcb92789bd29fb53bd8e661e4908370 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -9,7 +9,7 @@ import java.util.concurrent.Executors
 
 import kafka.consumer._
 import kafka.message.{Message, MessageSet, MessageAndMetadata}
-import kafka.serializer.StringDecoder
+import kafka.serializer.Decoder
 import kafka.utils.{Utils, ZKGroupTopicDirs}
 import kafka.utils.ZkUtils._
 import kafka.utils.ZKStringSerializer
@@ -28,7 +28,7 @@ import scala.collection.JavaConversions._
  * @param storageLevel RDD storage level.
  */
 private[streaming]
-class KafkaInputDStream[T: ClassManifest](
+class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
     @transient ssc_ : StreamingContext,
     kafkaParams: Map[String, String],
     topics: Map[String, Int],
@@ -37,15 +37,17 @@ class KafkaInputDStream[T: ClassManifest](
 
 
   def getReceiver(): NetworkReceiver[T] = {
-    new KafkaReceiver(kafkaParams, topics, storageLevel)
+    new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
         .asInstanceOf[NetworkReceiver[T]]
   }
 }
 
 private[streaming]
-class KafkaReceiver(kafkaParams: Map[String, String],
+class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
+  kafkaParams: Map[String, String],
   topics: Map[String, Int],
-  storageLevel: StorageLevel) extends NetworkReceiver[Any] {
+  storageLevel: StorageLevel
+  ) extends NetworkReceiver[Any] {
 
   // Handles pushing data into the BlockManager
   lazy protected val blockGenerator = new BlockGenerator(storageLevel)
@@ -82,7 +84,8 @@ class KafkaReceiver(kafkaParams: Map[String, String],
     }
 
     // Create Threads for each Topic/Message Stream we are listening
-    val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+    val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
+    val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
 
     // Start the messages handler for each partition
     topicMessageStreams.values.foreach { streams =>
@@ -91,7 +94,7 @@ class KafkaReceiver(kafkaParams: Map[String, String],
   }
 
   // Handles Kafka Messages
-  private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+  private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
     def run() {
       logInfo("Starting MessageHandler.")
       for (msgAndMetadata <- stream) {
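The receiver now builds its decoder reflectively from the `Manifest`, which quietly requires the decoder class to expose a no-argument constructor. A small standalone sketch of that pattern, assuming the Kafka 0.7-line `StringDecoder` used elsewhere in this patch:

```scala
import kafka.serializer.{Decoder, StringDecoder}

// Mirror of what KafkaReceiver does: instantiate the decoder named by the
// Manifest via reflection, then treat it as a Decoder for the element type.
def newDecoder[T, D <: Decoder[_]: Manifest]: Decoder[T] =
  manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]

val decoder = newDecoder[String, StringDecoder]
println(decoder.getClass.getName) // kafka.serializer.StringDecoder
```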
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index 3bed500f73e84548a2f470a414ea68cfbb3df1ac..e5fdbe1b7af75ad92a057eed6c2008a10c32b2fd 100644
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -4,6 +4,7 @@ import com.google.common.base.Optional;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.io.Files;
+import kafka.serializer.StringDecoder;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
 import org.junit.After;
 import org.junit.Assert;
@@ -23,7 +24,6 @@ import spark.streaming.api.java.JavaPairDStream;
 import spark.streaming.api.java.JavaStreamingContext;
 import spark.streaming.JavaTestUtils;
 import spark.streaming.JavaCheckpointTestUtils;
-import spark.streaming.dstream.KafkaPartitionKey;
 import spark.streaming.InputStreamsSuite;
 
 import java.io.*;
@@ -1203,10 +1203,14 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void testKafkaStream() {
     HashMap<String, Integer> topics = Maps.newHashMap();
-    HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
     JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
-    JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
-    JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
+    JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+      StorageLevel.MEMORY_AND_DISK());
+
+    HashMap<String, String> kafkaParams = Maps.newHashMap();
+    kafkaParams.put("zk.connect","localhost:12345");
+    kafkaParams.put("groupid","consumer-group");
+    JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
       StorageLevel.MEMORY_AND_DISK());
   }
 
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index 1024d3ac9790e801a060d90a9faab1d64237515b..595c766a216f03e6a19ead6f2e80cb59448a00a0 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -240,6 +240,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
       assert(output(i) === expectedOutput(i))
     }
   }
+
+  test("kafka input stream") {
+    val ssc = new StreamingContext(master, framework, batchDuration)
+    val topics = Map("my-topic" -> 1)
+    val test1 = ssc.kafkaStream("localhost:12345", "group", topics)
+    val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
+
+    // Test specifying decoder
+    val kafkaParams = Map("zk.connect" -> "localhost:12345", "groupid" -> "consumer-group")
+    val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+  }
 }
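Finally, to show what the decoder type parameter buys beyond strings, here is a hypothetical custom decoder wired through the generic overload. It is not part of this patch and assumes the Kafka 0.7-style `Decoder[T].toEvent(message: Message)` contract that `KafkaReceiver` relies on.

```scala
import kafka.message.Message
import kafka.serializer.Decoder
import spark.storage.StorageLevel
import spark.streaming.{Seconds, StreamingContext}

// Hypothetical decoder: maps each Kafka message to the size of its payload.
// It needs a no-arg constructor because KafkaReceiver instantiates it reflectively.
class PayloadLengthDecoder extends Decoder[Int] {
  def toEvent(message: Message): Int = message.payload.remaining
}

val ssc = new StreamingContext("local[2]", "CustomDecoderSketch", Seconds(2))
val kafkaParams = Map("zk.connect" -> "localhost:2181", "groupid" -> "length-group")
val lengths = ssc.kafkaStream[Int, PayloadLengthDecoder](
  kafkaParams, Map("my-topic" -> 1), StorageLevel.MEMORY_ONLY_SER_2)
lengths.print()
```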