Skip to content
Snippets Groups Projects
Commit e82d95bf authored by Rekha Joshi's avatar Rekha Joshi Committed by Cheng Lian
Browse files

[SPARK-14372][SQL] Dataset.randomSplit() needs a Java version

## What changes were proposed in this pull request?

1.Added method randomSplitAsList() in Dataset for java
for https://issues.apache.org/jira/browse/SPARK-14372

## How was this patch tested?

TestSuite

Author: Rekha Joshi <rekhajoshm@gmail.com>
Author: Joshi <rekhajoshm@gmail.com>

Closes #12184 from rekhajoshm/SPARK-14372.
parent 1a0cca1f
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,6 @@ package org.apache.spark.sql ...@@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.io.CharArrayWriter import java.io.CharArrayWriter
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal import scala.util.control.NonFatal
...@@ -1493,6 +1492,8 @@ class Dataset[T] private[sql]( ...@@ -1493,6 +1492,8 @@ class Dataset[T] private[sql](
* @param weights weights for splits, will be normalized if they don't sum to 1. * @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling. * @param seed Seed for sampling.
* *
* For Java API, use [[randomSplitAsList]].
*
* @group typedrel * @group typedrel
* @since 2.0.0 * @since 2.0.0
*/ */
...@@ -1510,6 +1511,20 @@ class Dataset[T] private[sql]( ...@@ -1510,6 +1511,20 @@ class Dataset[T] private[sql](
}.toArray }.toArray
} }
/**
* Returns a Java list that contains randomly split [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
* @group typedrel
* @since 2.0.0
*/
def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
val values = randomSplit(weights, seed)
java.util.Arrays.asList(values : _*)
}
/** /**
* Randomly splits this [[Dataset]] with the provided weights. * Randomly splits this [[Dataset]] with the provided weights.
* *
......
...@@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable { ...@@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList()); Assert.assertEquals(data, ds.collectAsList());
} }
@Test
public void testRandomSplit() {
List<String> data = Arrays.asList("hello", "world", "from", "spark");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
double[] arraySplit = {1, 2, 3};
List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1);
Assert.assertEquals("wrong number of splits", randomSplit.size(), 3);
}
/** /**
* For testing error messages when creating an encoder on a private class. This is done * For testing error messages when creating an encoder on a private class. This is done
* here since we cannot create truly private classes in Scala. * here since we cannot create truly private classes in Scala.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment