From 20a61dbd9b57957fcc5b58ef8935533914172b07 Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@pigscanfly.ca>
Date: Mon, 21 Sep 2015 18:53:28 +0100
Subject: [PATCH] [SPARK-10626] [MLLIB] create java friendly method for random
 rdd

SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD.
---
 .../spark/mllib/random/RandomRDDs.scala       | 52 ++++++++++++++++++-
 .../mllib/random/JavaRandomRDDsSuite.java     | 30 +++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 4dd5ea214d..f8ff26b579 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD}
 import org.apache.spark.rdd.RDD
@@ -381,7 +382,7 @@ object RandomRDDs {
    * @param size Size of the RDD.
    * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
    * @param seed Random seed (default: a random long integer).
-   * @return RDD[Double] comprised of `i.i.d.` samples produced by generator.
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
    */
   @DeveloperApi
   @Since("1.1.0")
@@ -394,6 +395,55 @@ object RandomRDDs {
     new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed)
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator.
+   *
+   * @param jsc JavaSparkContext used to create the RDD.
+   * @param generator RandomDataGenerator used to populate the RDD.
+   * @param size Size of the RDD.
+   * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
+   * @param seed Random seed (default: a random long integer).
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[T],
+      size: Long,
+      numPartitions: Int,
+      seed: Long): JavaRDD[T] = {
+    implicit val ctag: ClassTag[T] = fakeClassTag
+    val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed)
+    JavaRDD.fromRDD(rdd)
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long,
+    numPartitions: Int): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong())
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, 0);
+  }
+
   // TODO Generate RDD[Vector] from multivariate distributions.
 
   /**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index 33d81b1e95..fce5f6712f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.random;
 
+import java.io.Serializable;
 import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
@@ -231,4 +232,33 @@ public class JavaRandomRDDsSuite {
     }
   }
 
+  @Test
+  public void testArbitrary() {
+    long size = 10;
+    long seed = 1L;
+    int numPartitions = 0;
+    StringGenerator gen = new StringGenerator();
+    JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
+    JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
+    JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
+    for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+      Assert.assertEquals(size, rdd.count());
+      Assert.assertEquals(2, rdd.first().length());
+    }
+  }
+}
+
+// This is just a test generator, it always returns a string of 42
+class StringGenerator implements RandomDataGenerator<String>, Serializable {
+  @Override
+  public String nextValue() {
+    return "42";
+  }
+  @Override
+  public StringGenerator copy() {
+    return new StringGenerator();
+  }
+  @Override
+  public void setSeed(long seed) {
+  }
 }
-- 
GitLab