Skip to content
Snippets Groups Projects
Commit 870b8a2e authored by Meihua Wu's avatar Meihua Wu Committed by Sean Owen
Browse files

[SPARK-10706] [MLLIB] Add java wrapper for random vector rdd

Add Java-friendly wrappers (`randomJavaVectorRDD`) for `RandomRDDs.randomVectorRDD`.

holdenk srowen

Author: Meihua Wu <meihuawu@umich.edu>

Closes #8841 from rotationsymmetry/SPARK-10706.
parent 7278f792
No related branches found
No related tags found
No related merge requests found
......@@ -855,6 +855,48 @@ object RandomRDDs {
sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed)
}
/**
 * Java-friendly version of [[RandomRDDs#randomVectorRDD]].
 *
 * Unwraps the underlying context from `jsc`, forwards every other argument unchanged to
 * [[RandomRDDs#randomVectorRDD]], and exposes the result as a `JavaRDD`.
 *
 * @param jsc Java Spark context whose wrapped `sc` is handed to `randomVectorRDD`.
 * @param generator Generator of `Double` values, passed through as-is.
 * @param numRows Passed through as the `numRows` argument.
 * @param numCols Passed through as the `numCols` argument.
 * @param numPartitions Passed through as the `numPartitions` argument.
 * @param seed Passed through as the `seed` argument.
 * @return The RDD produced by `randomVectorRDD`, converted with `toJavaRDD()`.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int,
    numPartitions: Int,
    seed: Long): JavaRDD[Vector] = {
  // Build the Scala RDD first, then wrap it for Java callers.
  val vectors = randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed)
  vectors.toJavaRDD()
}
/**
 * [[RandomRDDs#randomJavaVectorRDD]] with the default seed.
 *
 * No seed is supplied here, so [[RandomRDDs#randomVectorRDD]] falls back to its own default.
 *
 * @param jsc Java Spark context whose wrapped `sc` is handed to `randomVectorRDD`.
 * @param generator Generator of `Double` values, passed through as-is.
 * @param numRows Passed through as the `numRows` argument.
 * @param numCols Passed through as the `numCols` argument.
 * @param numPartitions Passed through as the `numPartitions` argument.
 * @return The RDD produced by `randomVectorRDD`, converted with `toJavaRDD()`.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int,
    numPartitions: Int): JavaRDD[Vector] = {
  // Build the Scala RDD first (seed left to the callee), then wrap it for Java callers.
  val vectors = randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions)
  vectors.toJavaRDD()
}
/**
 * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed.
 *
 * Neither a partition count nor a seed is supplied here, so [[RandomRDDs#randomVectorRDD]]
 * falls back to its own defaults for both.
 *
 * @param jsc Java Spark context whose wrapped `sc` is handed to `randomVectorRDD`.
 * @param generator Generator of `Double` values, passed through as-is.
 * @param numRows Passed through as the `numRows` argument.
 * @param numCols Passed through as the `numCols` argument.
 * @return The RDD produced by `randomVectorRDD`, converted with `toJavaRDD()`.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int): JavaRDD[Vector] = {
  // Build the Scala RDD first (partitions and seed left to the callee), then wrap it.
  val vectors = randomVectorRDD(jsc.sc, generator, numRows, numCols)
  vectors.toJavaRDD()
}
/**
* Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise.
*/
......
......@@ -246,6 +246,23 @@ public class JavaRandomRDDsSuite {
Assert.assertEquals(2, rdd.first().length());
}
}
/**
 * Exercises all three randomJavaVectorRDD overloads and checks that each resulting RDD
 * has the requested number of rows and that its first vector has the requested size.
 */
@Test
@SuppressWarnings("unchecked")
public void testRandomVectorRDD() {
  UniformGenerator uniform = new UniformGenerator();
  long numRows = 100L;
  int numCols = 10;
  int numPartitions = 2;
  long fixedSeed = 1L;

  // One RDD per overload: defaults only, explicit partitions, explicit partitions + seed.
  JavaRDD<Vector> withDefaults = randomJavaVectorRDD(sc, uniform, numRows, numCols);
  JavaRDD<Vector> withPartitions =
    randomJavaVectorRDD(sc, uniform, numRows, numCols, numPartitions);
  JavaRDD<Vector> withSeed =
    randomJavaVectorRDD(sc, uniform, numRows, numCols, numPartitions, fixedSeed);

  for (JavaRDD<Vector> vectors: Arrays.asList(withDefaults, withPartitions, withSeed)) {
    Assert.assertEquals(numRows, vectors.count());
    Assert.assertEquals(numCols, vectors.first().size());
  }
}
}
// This is just a test generator; it always returns a string of 42
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment