From 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 Mon Sep 17 00:00:00 2001
From: Meihua Wu <meihuawu@umich.edu>
Date: Tue, 22 Sep 2015 11:05:24 +0100
Subject: [PATCH] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd

Add java wrapper for random vector rdd

holdenk srowen

Author: Meihua Wu <meihuawu@umich.edu>

Closes #8841 from rotationsymmetry/SPARK-10706.
---
 .../spark/mllib/random/RandomRDDs.scala       | 42 +++++++++++++++++++
 .../mllib/random/JavaRandomRDDsSuite.java     | 17 ++++++++
 2 files changed, 59 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index f8ff26b579..41d7c4d355 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -855,6 +855,48 @@ object RandomRDDs {
       sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed)
   }
 
+  /**
+   * Java-friendly version of [[RandomRDDs#randomVectorRDD]].
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaVectorRDD(
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[Double],
+      numRows: Long,
+      numCols: Int,
+      numPartitions: Int,
+      seed: Long): JavaRDD[Vector] = {
+    randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD()
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaVectorRDD]] with the default seed.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaVectorRDD(
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[Double],
+      numRows: Long,
+      numCols: Int,
+      numPartitions: Int): JavaRDD[Vector] = {
+    randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD()
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaVectorRDD(
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[Double],
+      numRows: Long,
+      numCols: Int): JavaRDD[Vector] = {
+    randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD()
+  }
+
   /**
    * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise.
    */
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index fce5f6712f..5728df5aee 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -246,6 +246,23 @@ public class JavaRandomRDDsSuite {
       Assert.assertEquals(2, rdd.first().length());
     }
   }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testRandomVectorRDD() {
+    UniformGenerator generator = new UniformGenerator();
+    long m = 100L;
+    int n = 10;
+    int p = 2;
+    long seed = 1L;
+    JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
+    JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
+    JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
+    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+      Assert.assertEquals(m, rdd.count());
+      Assert.assertEquals(n, rdd.first().size());
+    }
+  }
 }
 
 // This is just a test generator, it always returns a string of 42
-- 
GitLab