Skip to content
Snippets Groups Projects
Commit 68c0c460 authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Davies Liu
Browse files

[SPARK-13742] [CORE] Add non-iterator interface to RandomSampler

JIRA: https://issues.apache.org/jira/browse/SPARK-13742

## What changes were proposed in this pull request?

`RandomSampler.sample` currently accepts an iterator as input and outputs another iterator. This makes it unsuitable for use in whole-stage codegen of the `Sampler` operator (#11517). This change adds a non-iterator interface to `RandomSampler`.

This change adds a new method `def sample(): Int` to the trait `RandomSampler`. Because we don't need to know the actual values of the items being sampled, this new method takes no arguments.

This method will decide whether to sample the next item or not. It returns how many times the next item will be sampled.

For `BernoulliSampler` and `BernoulliCellSampler`, the returned sampling times can only be 0 or 1. It simply means whether to sample the next item or not.

For `PoissonSampler`, the returned value can be more than 1, meaning the next item will be sampled multiple times.

## How was this patch tested?

Tests are added into `RandomSamplerSuite`.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11578 from viirya/random-sampler-no-iterator.
parent c8388297
No related branches found
No related tags found
No related merge requests found
...@@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi ...@@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable { trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {
/** take a random sample */ /** take a random sample */
def sample(items: Iterator[T]): Iterator[U] def sample(items: Iterator[T]): Iterator[U] =
items.filter(_ => sample > 0).asInstanceOf[Iterator[U]]
/**
* Whether to sample the next item or not.
* Return how many times the next item will be sampled. Return 0 if it is not sampled.
*/
def sample(): Int
/** return a copy of the RandomSampler object */ /** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] = override def clone: RandomSampler[T, U] =
...@@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals ...@@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def setSeed(seed: Long): Unit = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = { override def sample(): Int = {
if (ub - lb <= 0.0) { if (ub - lb <= 0.0) {
if (complement) items else Iterator.empty if (complement) 1 else 0
} else { } else {
if (complement) { val x = rng.nextDouble()
items.filter { item => { val n = if ((x >= lb) && (x < ub)) 1 else 0
val x = rng.nextDouble() if (complement) 1 - n else n
(x < lb) || (x >= ub)
}}
} else {
items.filter { item => {
val x = rng.nextDouble()
(x >= lb) && (x < ub)
}}
}
} }
} }
...@@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T ...@@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def setSeed(seed: Long): Unit = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = { private lazy val gapSampling: GapSampling =
new GapSampling(fraction, rng, RandomSampler.rngEpsilon)
override def sample(): Int = {
if (fraction <= 0.0) { if (fraction <= 0.0) {
Iterator.empty 0
} else if (fraction >= 1.0) { } else if (fraction >= 1.0) {
items 1
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon) gapSampling.sample()
} else { } else {
items.filter { _ => rng.nextDouble() <= fraction } if (rng.nextDouble() <= fraction) {
1
} else {
0
}
} }
} }
...@@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag]( ...@@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag](
rngGap.setSeed(seed) rngGap.setSeed(seed)
} }
override def sample(items: Iterator[T]): Iterator[T] = { private lazy val gapSamplingReplacement =
new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)
override def sample(): Int = {
if (fraction <= 0.0) { if (fraction <= 0.0) {
Iterator.empty 0
} else if (useGapSamplingIfPossible && } else if (useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction) { fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) gapSamplingReplacement.sample()
} else {
rng.sample()
}
}
override def sample(items: Iterator[T]): Iterator[T] = {
if (fraction <= 0.0) {
Iterator.empty
} else { } else {
val useGapSampling = useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction
items.flatMap { item => items.flatMap { item =>
val count = rng.sample() val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item) if (count == 0) Iterator.empty else Iterator.fill(count)(item)
} }
} }
...@@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag]( ...@@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag](
private[spark] private[spark]
class GapSamplingIterator[T: ClassTag]( class GapSampling(
var data: Iterator[T],
f: Double, f: Double,
rng: Random = RandomSampler.newDefaultRNG, rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)") require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
/** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */ private val lnq = math.log1p(-f)
private val iterDrop: Int => Unit = {
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}
override def hasNext: Boolean = data.hasNext
override def next(): T = { /** Return 1 if the next item should be sampled. Otherwise, return 0. */
val r = data.next() def sample(): Int = {
advance() if (countForDropping > 0) {
r countForDropping -= 1
0
} else {
advance()
1
}
} }
private val lnq = math.log1p(-f) private var countForDropping: Int = 0
/** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ /**
* Decide the number of elements that won't be sampled,
* according to geometric dist P(k) = (f)(1-f)^k.
*/
private def advance(): Unit = { private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon) val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / lnq).toInt countForDropping = (math.log(u) / lnq).toInt
iterDrop(k)
} }
/** advance to first sample as part of object construction. */ /** advance to first sample as part of object construction. */
...@@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag]( ...@@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag](
// work reliably. // work reliably.
} }
private[spark] private[spark]
class GapSamplingReplacementIterator[T: ClassTag]( class GapSamplingReplacement(
var data: Iterator[T], val f: Double,
f: Double, val rng: Random = RandomSampler.newDefaultRNG,
rng: Random = RandomSampler.newDefaultRNG, epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
require(f > 0.0, s"Sampling fraction ($f) must be > 0") require(f > 0.0, s"Sampling fraction ($f) must be > 0")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
/** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */ protected val q = math.exp(-f)
private val iterDrop: Int => Unit = {
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}
/** current sampling value, and its replication factor, as we are sampling with replacement. */
private var v: T = _
private var rep: Int = 0
override def hasNext: Boolean = data.hasNext || rep > 0
override def next(): T = {
val r = v
rep -= 1
if (rep <= 0) advance()
r
}
/**
* Skip elements with replication factor zero (i.e. elements that won't be sampled).
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
* q is the probability of Poisson(0; f)
*/
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / (-f)).toInt
iterDrop(k)
// set the value and replication factor for the next value
if (data.hasNext) {
v = data.next()
rep = poissonGE1
}
}
private val q = math.exp(-f)
/** /**
* Sample from Poisson distribution, conditioned such that the sampled value is >= 1. * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
* This is an adaptation from the algorithm for Generating Poisson distributed random variables: * This is an adaptation from the algorithm for Generating Poisson distributed random variables:
* http://en.wikipedia.org/wiki/Poisson_distribution * http://en.wikipedia.org/wiki/Poisson_distribution
*/ */
private def poissonGE1: Int = { protected def poissonGE1: Int = {
// simulate that the standard poisson sampling // simulate that the standard poisson sampling
// gave us at least one iteration, for a sample of >= 1 // gave us at least one iteration, for a sample of >= 1
var pp = q + ((1.0 - q) * rng.nextDouble()) var pp = q + ((1.0 - q) * rng.nextDouble())
...@@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag]( ...@@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag](
} }
r r
} }
private var countForDropping: Int = 0
/**
 * Decide how many times the next item is sampled.
 * Returns 0 while the current gap is still being consumed; once the gap is exhausted,
 * returns a replication count obtained from poissonGE1 and then chooses the next gap.
 */
def sample(): Int = {
  if (countForDropping <= 0) {
    // Draw the replication count first and the next gap second, so the
    // underlying random draws happen in that order.
    val replication = poissonGE1
    advance()
    replication
  } else {
    countForDropping -= 1
    0
  }
}
/**
 * Choose how many upcoming elements get replication factor zero (i.e. won't be sampled).
 * The gap length 'k' is drawn from the geometric distribution P(k) = (1-q)(q)^k,
 * where q = e^(-f), that is, q is the probability of Poisson(0; f).
 */
private def advance(): Unit = {
  val draw = rng.nextDouble()
  // Never let the value fed to the logarithm fall below epsilon.
  val clamped = math.max(draw, epsilon)
  val gap = math.log(clamped) / (-f)
  countForDropping = gap.toInt
}
/** advance to first sample as part of object construction. */ /** advance to first sample as part of object construction. */
advance() advance()
......
...@@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] { ...@@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] {
s = seed s = seed
} }
/** Non-iterator sampling hook declared by RandomSampler; this mock reports 1 on every call. */
override def sample(): Int = 1
override def sample(items: Iterator[Long]): Iterator[Long] = { override def sample(items: Iterator[Long]): Iterator[Long] = {
Iterator(s) Iterator(s)
} }
......
...@@ -129,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -129,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
t(m / 2) t(m / 2)
} }
/**
 * Drive a PoissonSampler through its non-iterator interface: every input element is
 * emitted as many times as `sampler.sample()` reports for it (zero times drops it).
 */
def replacementSampling(data: Iterator[Int], sampler: PoissonSampler[Int]): Iterator[Int] = {
  data.flatMap { element =>
    sampler.sample() match {
      case 0 => Iterator.empty
      case times => Iterator.fill(times)(element)
    }
  }
}
test("utilities") { test("utilities") {
val s1 = Array(0, 1, 1, 0, 2) val s1 = Array(0, 1, 1, 0, 2)
val s2 = Array(1, 0, 3, 2, 1) val s2 = Array(1, 0, 3, 2, 1)
...@@ -189,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -189,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D d should be > D
} }
test("bernoulli sampling without iterator") {
  // Tests expect maximum gap sampling fraction to be this value
  RandomSampler.defaultMaxGapSamplingFraction should be (0.4)

  val data = Iterator.from(0)

  // medianKSD between the gaps obtained by filtering `data` with a freshly seeded
  // BernoulliSampler(fraction).sample(), and the gaps of the reference sample at `reference`.
  def ksdFor(fraction: Double, reference: Double): Double = {
    val sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](fraction)
    sampler.setSeed(rngSeed.nextLong)
    medianKSD(
      gaps(data.filter(_ => sampler.sample() > 0)),
      gaps(sample(Iterator.from(0), reference)))
  }

  // matching fractions should be statistically indistinguishable
  Seq(0.5, 0.7, 0.9).foreach { fraction =>
    val d = ksdFor(fraction, fraction)
    d should be < D
  }

  // sampling at different frequencies should show up as statistically different:
  val mismatched = ksdFor(0.5, 0.6)
  mismatched should be > D
}
test("bernoulli sampling with gap sampling optimization") { test("bernoulli sampling with gap sampling optimization") {
// Tests expect maximum gap sampling fraction to be this value // Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4) RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
...@@ -217,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -217,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D d should be > D
} }
test("bernoulli sampling (without iterator) with gap sampling optimization") {
  // Tests expect maximum gap sampling fraction to be this value
  RandomSampler.defaultMaxGapSamplingFraction should be (0.4)

  val data = Iterator.from(0)

  // medianKSD between the gaps obtained by filtering `data` with a freshly seeded
  // BernoulliSampler(fraction).sample(), and the gaps of the reference sample at `reference`.
  def ksdFor(fraction: Double, reference: Double): Double = {
    val sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](fraction)
    sampler.setSeed(rngSeed.nextLong)
    medianKSD(
      gaps(data.filter(_ => sampler.sample() > 0)),
      gaps(sample(Iterator.from(0), reference)))
  }

  // matching fractions should be statistically indistinguishable
  Seq(0.01, 0.1, 0.3).foreach { fraction =>
    val d = ksdFor(fraction, fraction)
    d should be < D
  }

  // sampling at different frequencies should show up as statistically different:
  val mismatched = ksdFor(0.3, 0.4)
  mismatched should be > D
}
test("bernoulli boundary cases") { test("bernoulli boundary cases") {
val data = (1 to 100).toArray val data = (1 to 100).toArray
...@@ -233,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -233,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).toArray should be (data) sampler.sample(data.iterator).toArray should be (data)
} }
test("bernoulli (without iterator) boundary cases") {
  val data = (1 to 100).toArray
  val halfEpsilon = RandomSampler.roundingEpsilon / 2.0

  // Elements of `data` kept by a new BernoulliSampler(fraction) through sample().
  def kept(fraction: Double): Array[Int] = {
    val sampler = new BernoulliSampler[Int](fraction)
    data.filter(_ => sampler.sample() > 0)
  }

  kept(0.0) should be (Array.empty[Int])
  kept(1.0) should be (data)

  // fractions just outside [0, 1] behave like the nearest bound
  kept(0.0 - halfEpsilon) should be (Array.empty[Int])
  kept(1.0 + halfEpsilon) should be (data)
}
test("bernoulli data types") { test("bernoulli data types") {
// Tests expect maximum gap sampling fraction to be this value // Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4) RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
...@@ -341,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -341,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D d should be > D
} }
test("replacement sampling without iterator") {
  // Tests expect maximum gap sampling fraction to be this value
  RandomSampler.defaultMaxGapSamplingFraction should be (0.4)

  val data = Iterator.from(0)

  // medianKSD between the gaps from replacementSampling over `data` with a freshly seeded
  // PoissonSampler(fraction), and the gaps of the reference sampleWR at `reference`.
  def ksdFor(fraction: Double, reference: Double): Double = {
    val sampler = new PoissonSampler[Int](fraction)
    sampler.setSeed(rngSeed.nextLong)
    medianKSD(
      gaps(replacementSampling(data, sampler)),
      gaps(sampleWR(Iterator.from(0), reference)))
  }

  // matching fractions should be statistically indistinguishable
  Seq(0.5, 0.7, 0.9).foreach { fraction =>
    val d = ksdFor(fraction, fraction)
    d should be < D
  }

  // sampling at different frequencies should show up as statistically different:
  val mismatched = ksdFor(0.5, 0.6)
  mismatched should be > D
}
test("replacement sampling with gap sampling") { test("replacement sampling with gap sampling") {
// Tests expect maximum gap sampling fraction to be this value // Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4) RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
...@@ -369,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -369,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D d should be > D
} }
test("replacement sampling (without iterator) with gap sampling") {
  // Tests expect maximum gap sampling fraction to be this value
  RandomSampler.defaultMaxGapSamplingFraction should be (0.4)

  val data = Iterator.from(0)

  // medianKSD between the gaps from replacementSampling over `data` with a freshly seeded
  // PoissonSampler(fraction), and the gaps of the reference sampleWR at `reference`.
  def ksdFor(fraction: Double, reference: Double): Double = {
    val sampler = new PoissonSampler[Int](fraction)
    sampler.setSeed(rngSeed.nextLong)
    medianKSD(
      gaps(replacementSampling(data, sampler)),
      gaps(sampleWR(Iterator.from(0), reference)))
  }

  // matching fractions should be statistically indistinguishable
  Seq(0.01, 0.1, 0.3).foreach { fraction =>
    val d = ksdFor(fraction, fraction)
    d should be < D
  }

  // sampling at different frequencies should show up as statistically different:
  val mismatched = ksdFor(0.3, 0.4)
  mismatched should be > D
}
test("replacement boundary cases") { test("replacement boundary cases") {
val data = (1 to 100).toArray val data = (1 to 100).toArray
...@@ -383,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -383,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).length should be > (data.length) sampler.sample(data.iterator).length should be > (data.length)
} }
// Boundary behaviour of PoissonSampler when driven through the non-iterator sample() method.
// The test name now reads "(without iterator)", matching the sibling non-iterator tests.
test("replacement (without iterator) boundary cases") {
  val data = (1 to 100).toArray

  // a zero fraction never samples anything
  var sampler = new PoissonSampler[Int](0.0)
  replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int])

  // a fraction slightly below zero behaves like zero
  sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0))
  replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int])

  // sampling with replacement has no upper bound on sampling fraction
  sampler = new PoissonSampler[Int](2.0)
  replacementSampling(data.iterator, sampler).length should be > (data.length)
}
test("replacement data types") { test("replacement data types") {
// Tests expect maximum gap sampling fraction to be this value // Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4) RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
...@@ -477,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -477,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be < D d should be < D
} }
test("bernoulli partitioning sampling without iterator") {
  val data = Iterator.from(0)

  // Seeds the given cell sampler, then returns the medianKSD between the gaps obtained by
  // filtering `data` through sampler.sample() and the gaps of the reference sample.
  def ksdFor(sampler: BernoulliCellSampler[Int], reference: Double): Double = {
    sampler.setSeed(rngSeed.nextLong)
    medianKSD(
      gaps(data.filter(_ => sampler.sample() > 0)),
      gaps(sample(Iterator.from(0), reference)))
  }

  // the cell [0.1, 0.2) is compared against a 0.1 reference sample
  val inCell = ksdFor(new BernoulliCellSampler[Int](0.1, 0.2), 0.1)
  inCell should be < D

  // the complement of that cell is compared against a 0.9 reference sample
  val outsideCell = ksdFor(new BernoulliCellSampler[Int](0.1, 0.2, true), 0.9)
  outsideCell should be < D
}
test("bernoulli partitioning boundary cases") { test("bernoulli partitioning boundary cases") {
val data = (1 to 100).toArray val data = (1 to 100).toArray
val d = RandomSampler.roundingEpsilon / 2.0 val d = RandomSampler.roundingEpsilon / 2.0
...@@ -500,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { ...@@ -500,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).toArray should be (Array.empty[Int]) sampler.sample(data.iterator).toArray should be (Array.empty[Int])
} }
test("bernoulli partitioning (without iterator) boundary cases") {
  val data = (1 to 100).toArray
  val d = RandomSampler.roundingEpsilon / 2.0

  // Elements of `data` kept by a new BernoulliCellSampler(lb, ub) through sample().
  def kept(lb: Double, ub: Double): Array[Int] = {
    val sampler = new BernoulliCellSampler[Int](lb, ub)
    data.filter(_ => sampler.sample() > 0).toArray
  }

  // empty cells select nothing
  kept(0.0, 0.0) should be (Array.empty[Int])
  kept(0.5, 0.5) should be (Array.empty[Int])
  kept(1.0, 1.0) should be (Array.empty[Int])

  // cells covering the whole unit interval select everything
  kept(0.0, 1.0) should be (data)
  kept(0.0 - d, 1.0 + d) should be (data)

  // an inverted cell (ub below lb) selects nothing
  kept(0.5, 0.5 - d) should be (Array.empty[Int])
}
test("bernoulli partitioning data") { test("bernoulli partitioning data") {
val seed = rngSeed.nextLong val seed = rngSeed.nextLong
val data = (1 to 100).toArray val data = (1 to 100).toArray
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment