diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
index 9cc263930e8b9cb240e8d6946b27ea2f22a9974f..ec546c81907c1e6b203b403b29a7e64496c92113 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
@@ -100,8 +100,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
    * of the RDD.
    */
-  def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V])
-  : JavaPairDStream[K, V] = {
+  def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2])
+  : JavaPairDStream[K2, V2] = {
     def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
     new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType())
   }
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index ec4e5ae18b25771fe2acff51c0b36b6c212c1c93..67d82d546fb6d8eddc4e48dc5bf078c2070418d8 100644
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -507,7 +507,7 @@ public class JavaAPISuite implements Serializable {
           new Tuple2<String, Integer>("new york", 1)));
 
   @Test
-  public void testPairMap() { // Maps pair -> pair
+  public void testPairMap() { // Maps pair -> pair of different type
     List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
 
     List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
@@ -538,6 +538,43 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(expected, result);
   }
 
+  @Test
+  public void testPairMapPartitions() { // Maps pair -> pair of different type
+    List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+    List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+        Arrays.asList(
+            new Tuple2<Integer, String>(1, "california"),
+            new Tuple2<Integer, String>(3, "california"),
+            new Tuple2<Integer, String>(4, "new york"),
+            new Tuple2<Integer, String>(1, "new york")),
+        Arrays.asList(
+            new Tuple2<Integer, String>(5, "california"),
+            new Tuple2<Integer, String>(5, "california"),
+            new Tuple2<Integer, String>(3, "new york"),
+            new Tuple2<Integer, String>(1, "new york")));
+
+    JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+    JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+    JavaPairDStream<Integer, String> reversed = pairStream.mapPartitions(
+        new PairFlatMapFunction<Iterator<Tuple2<String, Integer>>, Integer, String>() {
+          @Override
+          public Iterable<Tuple2<Integer, String>> call(Iterator<Tuple2<String, Integer>> in) throws Exception {
+            LinkedList<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>();
+            while (in.hasNext()) {
+              Tuple2<String, Integer> next = in.next();
+              out.add(new Tuple2<Integer, String>(next._2(), next._1()));
+            }
+            return out;
+          }
+        });
+
+    JavaTestUtils.attachTestOutputStream(reversed);
+    List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+    Assert.assertEquals(expected, result);
+  }
+
   @Test
   public void testPairMap2() { // Maps pair -> single
     List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
@@ -588,16 +625,16 @@ public class JavaAPISuite implements Serializable {
         JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
     JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
     JavaPairDStream<Integer, String> flatMapped = pairStream.flatMap(
-      new PairFlatMapFunction<Tuple2<String, Integer>, Integer, String>() {
-      @Override
-      public Iterable<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) throws Exception {
-        List<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>();
-        for (Character s: in._1().toCharArray()) {
-          out.add(new Tuple2<Integer, String>(in._2(), s.toString()));
-        }
-        return out;
-      }
-    });
+        new PairFlatMapFunction<Tuple2<String, Integer>, Integer, String>() {
+          @Override
+          public Iterable<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) throws Exception {
+            List<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>();
+            for (Character s : in._1().toCharArray()) {
+              out.add(new Tuple2<Integer, String>(in._2(), s.toString()));
+            }
+            return out;
+          }
+        });
     JavaTestUtils.attachTestOutputStream(flatMapped);
     List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
 
@@ -668,7 +705,7 @@ public class JavaAPISuite implements Serializable {
 
     JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(
         new Function<Integer, Integer>() {
-          @Override
+        @Override
           public Integer call(Integer i) throws Exception {
             return i;
           }
@@ -766,19 +803,19 @@ public class JavaAPISuite implements Serializable {
     JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
 
     JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
-      new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>(){
+        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
         @Override
         public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
           int out = 0;
           if (state.isPresent()) {
             out = out + state.get();
           }
-          for (Integer v: values) {
+          for (Integer v : values) {
             out = out + v;
           }
           return Optional.of(out);
         }
-      });
+    });
     JavaTestUtils.attachTestOutputStream(updated);
     List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);