diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
index 279639af5d4302e2d1ce4e0ca57afc7feb37b500..07aebb75e8f4e85de20be89394efe2260b6e4f39 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
@@ -25,5 +25,5 @@ import java.util.Iterator;
  * Datasets.
  */
 public interface CoGroupFunction<K, V1, V2, R> extends Serializable {
-  Iterable<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception;
+  Iterator<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception;
 }
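
The change from `Iterable<R>` to `Iterator<R>` is deliberately source-incompatible: every implementation has to be touched once, and for most call sites the migration is mechanical. A minimal sketch of the required change in user code (the key/payload types and names here are illustrative, not from this patch):

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.function.CoGroupFunction;

// Migration sketch: build the output as before, then return .iterator() on it.
CoGroupFunction<Integer, String, String, String> cogroup =
  new CoGroupFunction<Integer, String, String, String>() {
    @Override
    public Iterator<String> call(Integer key, Iterator<String> left, Iterator<String> right) {
      List<String> out = new ArrayList<>();
      while (left.hasNext()) {
        out.add(key + "#" + left.next());
      }
      while (right.hasNext()) {
        out.add(key + "#" + right.next());
      }
      return out.iterator(); // was: return out;
    }
  };
```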
diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 57fd0a7a80494cb2fb91c5e1c299095621dad4f3..576087b6f428e847092216d2918b2dde9b03961d 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,10 +18,11 @@
 package org.apache.spark.api.java.function;
 
 import java.io.Serializable;
+import java.util.Iterator;
 
 /**
  * A function that returns zero or more records of type Double from each input record.
  */
 public interface DoubleFlatMapFunction<T> extends Serializable {
-  public Iterable<Double> call(T t) throws Exception;
+  Iterator<Double> call(T t) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
index ef0d1824121ec8f96aa539ff1973dd95668133a8..2d8ea6d1a5a7e03f049aecf9ceac78ecb50688b6 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
@@ -18,10 +18,11 @@
 package org.apache.spark.api.java.function;
 
 import java.io.Serializable;
+import java.util.Iterator;
 
 /**
  * A function that returns zero or more output records from each input record.
  */
 public interface FlatMapFunction<T, R> extends Serializable {
-  Iterable<R> call(T t) throws Exception;
+  Iterator<R> call(T t) throws Exception;
 }
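
Beyond matching Scala's `Iterator => Iterator` contract, returning an `Iterator` lets an implementation be lazy instead of materializing all output up front. A sketch of a `FlatMapFunction` that yields tokens on demand (illustrative only, not part of the patch):

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.apache.spark.api.java.function.FlatMapFunction;

// Lazy tokenizer sketch: words are produced on demand instead of being
// collected into an intermediate List first.
FlatMapFunction<String, String> tokenize = new FlatMapFunction<String, String>() {
  @Override
  public Iterator<String> call(final String line) {
    return new Iterator<String>() {
      private int pos = 0;

      @Override
      public boolean hasNext() {
        return pos < line.length();
      }

      @Override
      public String next() {
        if (!hasNext()) {
          throw new NoSuchElementException();
        }
        int space = line.indexOf(' ', pos);
        int end = (space == -1) ? line.length() : space;
        String word = line.substring(pos, end);
        pos = end + 1;
        return word;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }
};
```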
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
index 14a98a38ef5ab70085c7b41d053df14e6b7572cf..fc97b63f825d09dae22461977c2822874cc9b849 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
@@ -18,10 +18,11 @@
 package org.apache.spark.api.java.function;
 
 import java.io.Serializable;
+import java.util.Iterator;
 
 /**
  * A function that takes two inputs and returns zero or more output records.
  */
 public interface FlatMapFunction2<T1, T2, R> extends Serializable {
-  Iterable<R> call(T1 t1, T2 t2) throws Exception;
+  Iterator<R> call(T1 t1, T2 t2) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java
index d7a80e7b129b03ced260e686f6f1323d1fb5187f..bae574ab5755d77b70ef4b2654abbc193900f02b 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java
@@ -24,5 +24,5 @@ import java.util.Iterator;
  * A function that returns zero or more output records from each grouping key and its values.
  */
 public interface FlatMapGroupsFunction<K, V, R> extends Serializable {
-  Iterable<R> call(K key, Iterator<V> values) throws Exception;
+  Iterator<R> call(K key, Iterator<V> values) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java
index 6cb569ce0cb6b197ecc80570a7a332798ac9863d..cf9945a215affa9231ae875b3457f5bbb80e2cde 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java
@@ -24,5 +24,5 @@ import java.util.Iterator;
  * Base interface for function used in Dataset's mapPartitions.
  */
 public interface MapPartitionsFunction<T, U> extends Serializable {
-  Iterable<U> call(Iterator<T> input) throws Exception;
+  Iterator<U> call(Iterator<T> input) throws Exception;
 }
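
For `mapPartitions` this matters most: with `Iterable`, the natural implementation drained the whole partition into a `List`; with `Iterator`, a partition can be transformed element by element. A sketch (illustrative, not from the patch):

```java
import java.util.Iterator;

import org.apache.spark.api.java.function.MapPartitionsFunction;

// Wraps the input iterator instead of draining it, so the partition is
// processed in constant memory regardless of its size.
MapPartitionsFunction<String, Integer> lengths = new MapPartitionsFunction<String, Integer>() {
  @Override
  public Iterator<Integer> call(final Iterator<String> input) {
    return new Iterator<Integer>() {
      @Override
      public boolean hasNext() {
        return input.hasNext();
      }

      @Override
      public Integer next() {
        return input.next().length();
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }
};
```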
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
index 691ef2eceb1f66c87069679d116a1cf7f0e7d60e..51eed2e67b9fad8a7b4e0370e7a531f603567e6a 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -18,6 +18,7 @@
 package org.apache.spark.api.java.function;
 
 import java.io.Serializable;
+import java.util.Iterator;
 
 import scala.Tuple2;
 
@@ -26,5 +27,5 @@ import scala.Tuple2;
  * key-value pairs are represented as scala.Tuple2 objects.
  */
 public interface PairFlatMapFunction<T, K, V> extends Serializable {
-  public Iterable<Tuple2<K, V>> call(T t) throws Exception;
+  Iterator<Tuple2<K, V>> call(T t) throws Exception;
 }
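
The pair variant follows the same pattern, with `scala.Tuple2` for the key-value pairs. A word-count-style sketch (illustrative):

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.function.PairFlatMapFunction;

// Same one-line migration as above, now emitting Tuple2 pairs.
PairFlatMapFunction<String, String, Integer> ones =
  new PairFlatMapFunction<String, String, Integer>() {
    @Override
    public Iterator<Tuple2<String, Integer>> call(String s) {
      List<Tuple2<String, Integer>> out = new ArrayList<>();
      for (String word : s.split(" ")) {
        out.add(new Tuple2<>(word, 1));
      }
      return out.iterator();
    }
  };
```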
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 0f8d13cf5cc2f6c664fac9cce60c7b337195cae5..7340defabfe58af6e22bd148960ade9ec9346cd5 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -121,7 +121,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    *  RDD, and then flattening the results.
    */
   def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
-    def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala
+    def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala
     JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U])
   }
 
@@ -130,7 +130,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    *  RDD, and then flattening the results.
    */
   def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
-    def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala
+    def fn: (T) => Iterator[jl.Double] = (x: T) => f.call(x).asScala
     new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue()))
   }
 
@@ -139,7 +139,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    *  RDD, and then flattening the results.
    */
   def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
-    def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala
+    def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala
     def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]]
     JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
   }
@@ -149,7 +149,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    */
   def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
     def fn: (Iterator[T]) => Iterator[U] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U])
   }
@@ -160,7 +160,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U],
       preservesPartitioning: Boolean): JavaRDD[U] = {
     def fn: (Iterator[T]) => Iterator[U] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     JavaRDD.fromRDD(
       rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U])
@@ -171,7 +171,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    */
   def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
     def fn: (Iterator[T]) => Iterator[jl.Double] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue()))
   }
@@ -182,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
   JavaPairRDD[K2, V2] = {
     def fn: (Iterator[T]) => Iterator[(K2, V2)] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2])
   }
@@ -193,7 +193,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]],
       preservesPartitioning: Boolean): JavaDoubleRDD = {
     def fn: (Iterator[T]) => Iterator[jl.Double] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning)
       .map(x => x.doubleValue()))
@@ -205,7 +205,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2],
       preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = {
     def fn: (Iterator[T]) => Iterator[(K2, V2)] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     JavaPairRDD.fromRDD(
       rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2])
@@ -290,7 +290,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
       other: JavaRDDLike[U, _],
       f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
     def fn: (Iterator[T], Iterator[U]) => Iterator[V] = {
-      (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala
+      (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala
     }
     JavaRDD.fromRDD(
       rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
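
Note that `JavaRDDLike` reuses `FlatMapFunction` for `mapPartitions`, with the partition's `java.util.Iterator` as the *input* type; after this patch both sides of `call` are iterators, so the Scala wrappers drop the extra `.iterator()` hop and the whole pipeline can stay lazy. A sketch assuming an existing `JavaRDD<String> rdd` (the use of Guava's `Iterators.filter` is illustrative, not from the patch):

```java
import java.util.Iterator;

import com.google.common.base.Predicate;
import com.google.common.collect.Iterators;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;

// Returns a filtered view of the partition's iterator; nothing is buffered.
JavaRDD<String> nonEmpty = rdd.mapPartitions(
  new FlatMapFunction<Iterator<String>, String>() {
    @Override
    public Iterator<String> call(Iterator<String> it) {
      return Iterators.filter(it, new Predicate<String>() {
        @Override
        public boolean apply(String s) {
          return !s.isEmpty();
        }
      });
    }
  });
```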
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 44d5cac7c2de5a8aa64dc4b4eb5ecbe43245cfaa..8117ad9e606419c58be0fea23e293d9b6e507466 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -880,8 +880,8 @@ public class JavaAPISuite implements Serializable {
       "The quick brown fox jumps over the lazy dog."));
     JavaRDD<String> words = rdd.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Arrays.asList(x.split(" "));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(x.split(" ")).iterator();
       }
     });
     Assert.assertEquals("Hello", words.first());
@@ -890,12 +890,12 @@ public class JavaAPISuite implements Serializable {
     JavaPairRDD<String, String> pairsRDD = rdd.flatMapToPair(
       new PairFlatMapFunction<String, String, String>() {
         @Override
-        public Iterable<Tuple2<String, String>> call(String s) {
+        public Iterator<Tuple2<String, String>> call(String s) {
           List<Tuple2<String, String>> pairs = new LinkedList<>();
           for (String word : s.split(" ")) {
             pairs.add(new Tuple2<>(word, word));
           }
-          return pairs;
+          return pairs.iterator();
         }
       }
     );
@@ -904,12 +904,12 @@ public class JavaAPISuite implements Serializable {
 
     JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction<String>() {
       @Override
-      public Iterable<Double> call(String s) {
+      public Iterator<Double> call(String s) {
         List<Double> lengths = new LinkedList<>();
         for (String word : s.split(" ")) {
           lengths.add((double) word.length());
         }
-        return lengths;
+        return lengths.iterator();
       }
     });
     Assert.assertEquals(5.0, doubles.first(), 0.01);
@@ -930,8 +930,8 @@ public class JavaAPISuite implements Serializable {
     JavaPairRDD<String, Integer> swapped = pairRDD.flatMapToPair(
       new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
         @Override
-        public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) {
-          return Collections.singletonList(item.swap());
+        public Iterator<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) {
+          return Collections.singletonList(item.swap()).iterator();
         }
       });
     swapped.collect();
@@ -951,12 +951,12 @@ public class JavaAPISuite implements Serializable {
     JavaRDD<Integer> partitionSums = rdd.mapPartitions(
       new FlatMapFunction<Iterator<Integer>, Integer>() {
         @Override
-        public Iterable<Integer> call(Iterator<Integer> iter) {
+        public Iterator<Integer> call(Iterator<Integer> iter) {
           int sum = 0;
           while (iter.hasNext()) {
             sum += iter.next();
           }
-          return Collections.singletonList(sum);
+          return Collections.singletonList(sum).iterator();
         }
     });
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
@@ -1367,8 +1367,8 @@ public class JavaAPISuite implements Serializable {
     FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
       new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
         @Override
-        public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
-          return Arrays.asList(Iterators.size(i), Iterators.size(s));
+        public Iterator<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+          return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator();
         }
       };
 
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 93c34efb6662ded6b1f065779c7d2ea2909655fc..7e681b67cf0c2b9a0e56d7d7826f3a9d1e00cf4a 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -165,8 +165,8 @@ space into words.
 // Split each line into words
 JavaDStream<String> words = lines.flatMap(
   new FlatMapFunction<String, String>() {
-    @Override public Iterable<String> call(String x) {
-      return Arrays.asList(x.split(" "));
+    @Override public Iterator<String> call(String x) {
+      return Arrays.asList(x.split(" ")).iterator();
     }
   });
 {% endhighlight %}
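
With Java 8, the guide's anonymous class collapses to a lambda and the only visible cost of this API change is the trailing `.iterator()`. A sketch, assuming the guide's `lines` stream and a `java.util.Arrays` import:

```java
JavaDStream<String> words = lines.flatMap(x -> Arrays.asList(x.split(" ")).iterator());
```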
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index a5db8accdf1389fc4a3524cb55992193b0936099..635fb6a373c47d528831d32a4285186a4df725ee 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -17,7 +17,10 @@
 
 package org.apache.spark.examples;
 
-
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.regex.Pattern;
 
 import scala.Tuple2;
 
@@ -32,11 +35,6 @@ import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Iterator;
-import java.util.regex.Pattern;
-
 /**
  * Computes the PageRank of URLs from an input file. Input file should
  * be in format of:
@@ -108,13 +106,13 @@ public final class JavaPageRank {
       JavaPairRDD<String, Double> contribs = links.join(ranks).values()
         .flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {
           @Override
-          public Iterable<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
+          public Iterator<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
             int urlCount = Iterables.size(s._1);
-            List<Tuple2<String, Double>> results = new ArrayList<Tuple2<String, Double>>();
+            List<Tuple2<String, Double>> results = new ArrayList<>();
             for (String n : s._1) {
-              results.add(new Tuple2<String, Double>(n, s._2() / urlCount));
+              results.add(new Tuple2<>(n, s._2() / urlCount));
             }
-            return results;
+            return results.iterator();
           }
       });
 
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
index 9a6a944f7edef72e60b5f22f6d4d89c7728cadb8..d746a3d2b677353ddd24a173605cef4f3c19e344 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
@@ -27,6 +27,7 @@ import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
 
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
 import java.util.regex.Pattern;
 
@@ -46,8 +47,8 @@ public final class JavaWordCount {
 
     JavaRDD<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String s) {
-        return Arrays.asList(SPACE.split(s));
+      public Iterator<String> call(String s) {
+        return Arrays.asList(SPACE.split(s)).iterator();
       }
     });
 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java
index 62e563380a9e7eddd9b0c5679bc727635b6e9fef..cf774667f6c5f179f8847f7828e4bc5ef1f72e62 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java
@@ -18,6 +18,7 @@
 package org.apache.spark.examples.streaming;
 
 import java.util.Arrays;
+import java.util.Iterator;
 
 import scala.Tuple2;
 
@@ -120,8 +121,8 @@ public class JavaActorWordCount {
     // compute wordcount
     lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String s) {
-        return Arrays.asList(s.split("\\s+"));
+      public Iterator<String> call(String s) {
+        return Arrays.asList(s.split("\\s+")).iterator();
       }
     }).mapToPair(new PairFunction<String, String, Integer>() {
       @Override
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 4b50fbf59f80ecb3beeace52d70b597f27d8adc1..3d668adcf815f433eb992e4baa21ed6b5075bdcb 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.examples.streaming;
 
-import com.google.common.collect.Lists;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkConf;
@@ -37,6 +36,8 @@ import java.io.BufferedReader;
 import java.io.InputStreamReader;
 import java.net.ConnectException;
 import java.net.Socket;
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.regex.Pattern;
 
 /**
@@ -74,8 +75,8 @@ public class JavaCustomReceiver extends Receiver<String> {
       new JavaCustomReceiver(args[0], Integer.parseInt(args[1])));
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
     JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
index f9a5e7f69ffe129938c13b778c6ce6dc0ee81763..5107500a127c5c44e66736eb6405ae0569dce2e4 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -20,11 +20,11 @@ package org.apache.spark.examples.streaming;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.regex.Pattern;
 
 import scala.Tuple2;
 
-import com.google.common.collect.Lists;
 import kafka.serializer.StringDecoder;
 
 import org.apache.spark.SparkConf;
@@ -87,8 +87,8 @@ public final class JavaDirectKafkaWordCount {
     });
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
     JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
index 337f8ffb5bfb00b06517961bbdb13230f8ed07ef..0df4cb40a9a76e1e6c546534ee365284c5974a34 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
@@ -17,20 +17,19 @@
 
 package org.apache.spark.examples.streaming;
 
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.Map;
 import java.util.HashMap;
 import java.util.regex.Pattern;
 
-
 import scala.Tuple2;
 
-import com.google.common.collect.Lists;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.examples.streaming.StreamingExamples;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
@@ -88,8 +87,8 @@ public final class JavaKafkaWordCount {
 
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
index 3e9f0f4b8f1276df14a1d6cf1c20a730716b3bb5..b82b319acb735e4205fa614b9895dea36f7d048f 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
@@ -17,8 +17,11 @@
 
 package org.apache.spark.examples.streaming;
 
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.regex.Pattern;
+
 import scala.Tuple2;
-import com.google.common.collect.Lists;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.FlatMapFunction;
@@ -31,8 +34,6 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 
-import java.util.regex.Pattern;
-
 /**
  * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
  *
@@ -67,8 +68,8 @@ public final class JavaNetworkWordCount {
             args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER);
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
     JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
index bc963a02be6089e4981d20cdcca2600b7ae1b28a..bc8cbcdef7272211d8f858cf1891448caeca6a2b 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
@@ -21,11 +21,11 @@ import java.io.File;
 import java.io.IOException;
 import java.nio.charset.Charset;
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
 import java.util.regex.Pattern;
 
 import scala.Tuple2;
-import com.google.common.collect.Lists;
 import com.google.common.io.Files;
 
 import org.apache.spark.Accumulator;
@@ -138,8 +138,8 @@ public final class JavaRecoverableNetworkWordCount {
     JavaReceiverInputDStream<String> lines = ssc.socketTextStream(ip, port);
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
     JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
index 084f68a8be4375d2fd31b4ec79e764539365433a..f0228f5e6345049df78fcafcbb7aa8fdffdd0f2d 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
@@ -17,10 +17,10 @@
 
 package org.apache.spark.examples.streaming;
 
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.regex.Pattern;
 
-import com.google.common.collect.Lists;
-
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
@@ -72,8 +72,8 @@ public final class JavaSqlNetworkWordCount {
         args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER);
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index f52cc7c20576b0797f63805268a52bf40241170c..6beab90f086d81ba016873732ec56a913cca00e9 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -18,6 +18,7 @@
 package org.apache.spark.examples.streaming;
 
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
 import java.util.regex.Pattern;
 
@@ -73,8 +74,8 @@ public class JavaStatefulNetworkWordCount {
 
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Arrays.asList(SPACE.split(x));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(SPACE.split(x)).iterator();
       }
     });
 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java
index d869768026ae32fe60bdb6c30b1d9e1db5598636..f0ae9a99bae47d32580c14ef9d7d9dc27a2f10ba 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java
@@ -34,6 +34,7 @@ import scala.Tuple2;
 import twitter4j.Status;
 
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
 
 /**
@@ -70,8 +71,8 @@ public class JavaTwitterHashTagJoinSentiments {
 
     JavaDStream<String> words = stream.flatMap(new FlatMapFunction<Status, String>() {
       @Override
-      public Iterable<String> call(Status s) {
-        return Arrays.asList(s.getText().split(" "));
+      public Iterator<String> call(Status s) {
+        return Arrays.asList(s.getText().split(" ")).iterator();
       }
     });
 
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
index 27d494ce355f781eea2ada13a90baf170a1f2840..c0b58e713f642018193e994a851bacfb571da4ee 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
@@ -294,7 +294,7 @@ public class Java8APISuite implements Serializable {
           sizeS += 1;
           s.next();
         }
-        return Arrays.asList(sizeI, sizeS);
+        return Arrays.asList(sizeI, sizeS).iterator();
       };
     JavaRDD<Integer> sizes = rdd1.zipPartitions(rdd2, sizesFn);
     Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
index 06e0ff28afd953fc1f6dfedc605d737ff6c47a3f..64e044aa8e4a4bafa5a311cebfc10f0a491f8e43 100644
--- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
+++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.examples.streaming;
 
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
 import java.util.regex.Pattern;
 
@@ -38,7 +41,6 @@ import scala.Tuple2;
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
 import com.amazonaws.services.kinesis.AmazonKinesisClient;
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
-import com.google.common.collect.Lists;
 
 /**
 * Consumes messages from an Amazon Kinesis stream and does wordcount.
@@ -154,8 +156,9 @@ public final class JavaKinesisWordCountASL { // needs to be public for access fr
     // Convert each line of Array[Byte] to String, and split into words
     JavaDStream<String> words = unionStreams.flatMap(new FlatMapFunction<byte[], String>() {
       @Override
-      public Iterable<String> call(byte[] line) {
-        return Lists.newArrayList(WORD_SEPARATOR.split(new String(line)));
+      public Iterator<String> call(byte[] line) {
+        String s = new String(line, StandardCharsets.UTF_8);
+        return Arrays.asList(WORD_SEPARATOR.split(s)).iterator();
       }
     });
 
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 501456b0431701a43421cde66fe3c1d0d5b13163..643bee69694dfabc855733c8c282d8ae5150dc00 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -60,6 +60,37 @@ object MimaExcludes {
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory")
       ) ++
+      Seq(
+        // SPARK-3369 Fix Iterable/Iterator in Java API
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.FlatMapFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.FlatMapFunction.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.FlatMapFunction2.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.FlatMapFunction2.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.PairFlatMapFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.PairFlatMapFunction.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.CoGroupFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.CoGroupFunction.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.MapPartitionsFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.MapPartitionsFunction.call"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem](
+          "org.apache.spark.api.java.function.FlatMapGroupsFunction.call"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.function.FlatMapGroupsFunction.call")
+      ) ++
       Seq(
         // SPARK-4819 replace Guava Optional
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd99c399571c813969d8a646160b3e2f51e795d4..f182270a08729ce96df94b29b2f613798efa4e51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -346,7 +346,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
     mapPartitions(func)(encoder)
   }
 
@@ -366,7 +366,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    val func: (T) => Iterable[U] = x => f.call(x).asScala
+    val func: (T) => Iterator[U] = x => f.call(x).asScala
     flatMap(func)(encoder)
   }
 
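
The typed `Dataset` operators get the same treatment. From Java, a 1.6-era `flatMap` now looks like the following (sketch; `ds` is an assumed existing `Dataset<String>`):

```java
import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;

// Split each string element into words, returning an Iterator as required
// by the updated FlatMapFunction.
Dataset<String> words = ds.flatMap(new FlatMapFunction<String, String>() {
  @Override
  public Iterator<String> call(String s) {
    return Arrays.asList(s.split(" ")).iterator();
  }
}, Encoders.STRING());
```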
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3c0f25a5dc53531cd16e06d4da7b5e678449b5d9..a6fb62c17d59b08bf2317375c832b66ee867df67 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -111,24 +111,24 @@ public class JavaDatasetSuite implements Serializable {
 
     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
       @Override
-      public Iterable<String> call(Iterator<String> it) throws Exception {
-        List<String> ls = new LinkedList<String>();
+      public Iterator<String> call(Iterator<String> it) {
+        List<String> ls = new LinkedList<>();
         while (it.hasNext()) {
-          ls.add(it.next().toUpperCase());
+          ls.add(it.next().toUpperCase(Locale.ENGLISH));
         }
-        return ls;
+        return ls.iterator();
       }
     }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
 
     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String s) throws Exception {
-        List<String> ls = new LinkedList<String>();
+      public Iterator<String> call(String s) {
+        List<String> ls = new LinkedList<>();
         for (char c : s.toCharArray()) {
           ls.add(String.valueOf(c));
         }
-        return ls;
+        return ls.iterator();
       }
     }, Encoders.STRING());
     Assert.assertEquals(
@@ -192,12 +192,12 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> flatMapped = grouped.flatMapGroups(
       new FlatMapGroupsFunction<Integer, String, String>() {
         @Override
-        public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
+        public Iterator<String> call(Integer key, Iterator<String> values) {
           StringBuilder sb = new StringBuilder(key.toString());
           while (values.hasNext()) {
             sb.append(values.next());
           }
-          return Collections.singletonList(sb.toString());
+          return Collections.singletonList(sb.toString()).iterator();
         }
       },
       Encoders.STRING());
@@ -228,10 +228,7 @@ public class JavaDatasetSuite implements Serializable {
       grouped2,
       new CoGroupFunction<Integer, String, Integer, String>() {
         @Override
-        public Iterable<String> call(
-          Integer key,
-          Iterator<String> left,
-          Iterator<Integer> right) throws Exception {
+        public Iterator<String> call(Integer key, Iterator<String> left, Iterator<Integer> right) {
           StringBuilder sb = new StringBuilder(key.toString());
           while (left.hasNext()) {
             sb.append(left.next());
@@ -240,7 +237,7 @@ public class JavaDatasetSuite implements Serializable {
           while (right.hasNext()) {
             sb.append(right.next());
           }
-          return Collections.singletonList(sb.toString());
+          return Collections.singletonList(sb.toString()).iterator();
         }
       },
       Encoders.STRING());
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a791a474c673dd27d55cbbc4eaa187544de447fa..f10de485d0f75c6d66204d97c41b9d3425eeca26 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -166,8 +166,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
    * and then flattening the results
    */
   def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = {
-    import scala.collection.JavaConverters._
-    def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala
+    def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala
     new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U])
   }
 
@@ -176,8 +175,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
    * and then flattening the results
    */
   def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = {
-    import scala.collection.JavaConverters._
-    def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala
+    def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala
     def cm: ClassTag[(K2, V2)] = fakeClassTag
     new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
   }
@@ -189,7 +187,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
    */
   def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = {
     def fn: (Iterator[T]) => Iterator[U] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U])
   }
@@ -202,7 +200,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
   def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2])
   : JavaPairDStream[K2, V2] = {
     def fn: (Iterator[T]) => Iterator[(K2, V2)] = {
-      (x: Iterator[T]) => f.call(x.asJava).iterator().asScala
+      (x: Iterator[T]) => f.call(x.asJava).asScala
     }
     new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2])
   }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 1dfb4e7abc0ed7fe321c6fcd84275c916bbf0fe8..db79eeab9c0cc2e5868d3770962fbe84faf65418 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -550,7 +550,7 @@ abstract class DStream[T: ClassTag] (
    * Return a new DStream by applying a function to all elements of this DStream,
    * and then flattening the results
    */
-  def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope {
+  def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope {
     new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc))
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
index 96a444a7baa5e940dcf692f58cfbaaca961b3181..d60a6179782e03b91335a2a9acb5b1349accb31c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
@@ -25,7 +25,7 @@ import org.apache.spark.streaming.{Duration, Time}
 private[streaming]
 class FlatMappedDStream[T: ClassTag, U: ClassTag](
     parent: DStream[T],
-    flatMapFunc: T => Traversable[U]
+    flatMapFunc: T => TraversableOnce[U]
   ) extends DStream[U](parent.ssc) {
 
   override def dependencies: List[DStream[_]] = List(parent)
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 4dbcef293487c210e205ccc7f3a7c39ad2981cb6..806cea24caddbb5625cae15e9e5595434c8bf80d 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -271,12 +271,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaDStream<String> mapped = stream.mapPartitions(
         new FlatMapFunction<Iterator<String>, String>() {
           @Override
-          public Iterable<String> call(Iterator<String> in) {
+          public Iterator<String> call(Iterator<String> in) {
             StringBuilder out = new StringBuilder();
             while (in.hasNext()) {
               out.append(in.next().toUpperCase(Locale.ENGLISH));
             }
-            return Arrays.asList(out.toString());
+            return Arrays.asList(out.toString()).iterator();
           }
         });
     JavaTestUtils.attachTestOutputStream(mapped);
@@ -759,8 +759,8 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
     JavaDStream<String> flatMapped = stream.flatMap(new FlatMapFunction<String, String>() {
       @Override
-      public Iterable<String> call(String x) {
-        return Arrays.asList(x.split("(?!^)"));
+      public Iterator<String> call(String x) {
+        return Arrays.asList(x.split("(?!^)")).iterator();
       }
     });
     JavaTestUtils.attachTestOutputStream(flatMapped);
@@ -846,12 +846,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaPairDStream<Integer, String> flatMapped = stream.flatMapToPair(
       new PairFlatMapFunction<String, Integer, String>() {
         @Override
-        public Iterable<Tuple2<Integer, String>> call(String in) {
+        public Iterator<Tuple2<Integer, String>> call(String in) {
           List<Tuple2<Integer, String>> out = new ArrayList<>();
           for (String letter: in.split("(?!^)")) {
             out.add(new Tuple2<>(in.length(), letter));
           }
-          return out;
+          return out.iterator();
         }
       });
     JavaTestUtils.attachTestOutputStream(flatMapped);
@@ -1019,13 +1019,13 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaPairDStream<Integer, String> reversed = pairStream.mapPartitionsToPair(
         new PairFlatMapFunction<Iterator<Tuple2<String, Integer>>, Integer, String>() {
           @Override
-          public Iterable<Tuple2<Integer, String>> call(Iterator<Tuple2<String, Integer>> in) {
+          public Iterator<Tuple2<Integer, String>> call(Iterator<Tuple2<String, Integer>> in) {
             List<Tuple2<Integer, String>> out = new LinkedList<>();
             while (in.hasNext()) {
               Tuple2<String, Integer> next = in.next();
               out.add(next.swap());
             }
-            return out;
+            return out.iterator();
           }
         });
 
@@ -1089,12 +1089,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaPairDStream<Integer, String> flatMapped = pairStream.flatMapToPair(
         new PairFlatMapFunction<Tuple2<String, Integer>, Integer, String>() {
           @Override
-          public Iterable<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) {
+          public Iterator<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) {
             List<Tuple2<Integer, String>> out = new LinkedList<>();
             for (Character s : in._1().toCharArray()) {
               out.add(new Tuple2<>(in._2(), s.toString()));
             }
-            return out;
+            return out.iterator();
           }
         });
     JavaTestUtils.attachTestOutputStream(flatMapped);