diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 4dd5ea214d6784f99051a4ee316efce8e9ab1c7e..f8ff26b5795be16807a80bf796bb9291708cbf01 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD}
 import org.apache.spark.rdd.RDD
@@ -381,7 +382,7 @@ object RandomRDDs {
    * @param size Size of the RDD.
    * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
    * @param seed Random seed (default: a random long integer).
-   * @return RDD[Double] comprised of `i.i.d.` samples produced by generator.
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
    */
   @DeveloperApi
   @Since("1.1.0")
@@ -394,6 +395,55 @@ object RandomRDDs {
     new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed)
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator.
+   *
+   * @param jsc JavaSparkContext used to create the RDD.
+   * @param generator RandomDataGenerator used to populate the RDD.
+   * @param size Size of the RDD.
+   * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
+   * @param seed Random seed (default: a random long integer).
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[T],
+      size: Long,
+      numPartitions: Int,
+      seed: Long): JavaRDD[T] = {
+    implicit val ctag: ClassTag[T] = fakeClassTag
+    val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed)
+    JavaRDD.fromRDD(rdd)
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long,
+    numPartitions: Int): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong())
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, 0);
+  }
+
   // TODO Generate RDD[Vector] from multivariate distributions.
 
   /**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index 33d81b1e9592b233b93a530e57266e07dd358164..fce5f6712f4628396c03319047b6504d15d0130a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.random;
 
+import java.io.Serializable;
 import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
@@ -231,4 +232,33 @@ public class JavaRandomRDDsSuite {
     }
   }
 
+  @Test
+  public void testArbitrary() {
+    long size = 10;
+    long seed = 1L;
+    int numPartitions = 0;
+    StringGenerator gen = new StringGenerator();
+    JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
+    JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
+    JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
+    for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+      Assert.assertEquals(size, rdd.count());
+      Assert.assertEquals(2, rdd.first().length());
+    }
+  }
+}
+
+// This is just a test generator, it always returns a string of 42
+class StringGenerator implements RandomDataGenerator<String>, Serializable {
+  @Override
+  public String nextValue() {
+    return "42";
+  }
+  @Override
+  public StringGenerator copy() {
+    return new StringGenerator();
+  }
+  @Override
+  public void setSeed(long seed) {
+  }
 }