diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 37a6b04f5200f617d262f3d385c18a0785c19082..4dc8ada00a3e86cc79ac47eee68f62d47f97284b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -69,7 +69,12 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } } - override def clone = new BernoulliSampler[T](lb, ub) + /** + * Return a sampler with is the complement of the range specified of the current sampler. + */ + def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) + + override def clone = new BernoulliSampler[T](lb, ub, complement) } /** diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 7576c9a51f313da9450f6c8a50f20d1681946ac1..e166787f17544d8b9ecf3539951313fae6aaf35f 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -41,21 +41,31 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar random.nextDouble().andReturn(x) } } - whenExecuting(random) - { + whenExecuting(random) { val sampler = new BernoulliSampler[Int](0.25, 0.55)(random) assert(sampler.sample(a.iterator).toList == List(3, 4, 5)) } } + test("BernoulliSamplerWithRangeInverse") { + expecting { + for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { + random.nextDouble().andReturn(x) + } + } + whenExecuting(random) { + val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) + assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9)) + } + } + test("BernoulliSamplerWithRatio") { expecting { for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { random.nextDouble().andReturn(x) } } - whenExecuting(random) - { + whenExecuting(random) { val sampler = new BernoulliSampler[Int](0.35)(random) assert(sampler.sample(a.iterator).toList == List(1, 2, 3)) } @@ -67,8 +77,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar random.nextDouble().andReturn(x) } } - whenExecuting(random) - { + whenExecuting(random) { val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random) assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9)) } @@ -78,8 +87,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar expecting { random.setSeed(10L) } - whenExecuting(random) - { + whenExecuting(random) { val sampler = new BernoulliSampler[Int](0.2)(random) sampler.setSeed(10L) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 901c3180eac4cf0b4ca16afbc21a8bdb53f1e13a..2f3ac1039751515bba07d48ea15994f51515c828 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,11 +17,16 @@ package org.apache.spark.mllib.util +import scala.reflect.ClassTag + import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.PartitionwiseSampledRDD +import org.apache.spark.SparkContext._ +import org.apache.spark.util.random.BernoulliSampler import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors @@ -157,6 +162,22 @@ object MLUtils { dataStr.saveAsTextFile(dir) } + /** + * Return a k element array of pairs of RDDs with the first element of each pair + * containing the training data, a complement of the validation data and the second + * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + */ + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + val numFoldsF = numFolds.toFloat + (1 to numFolds).map { fold => + val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, + complement = false) + val validation = new PartitionwiseSampledRDD(rdd, sampler, seed) + val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed) + (training, validation) + }.toArray + } + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 812a8434784beed8d5ae9c492099065aa0e13113..674378a34ce34e2982495d8fbd117fbe9e2ee0c9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.mllib.util import java.io.File +import scala.math +import scala.util.Random + import org.scalatest.FunSuite import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, @@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { case t: Throwable => } } + + test("kFold") { + val data = sc.parallelize(1 to 100, 2) + val collectedData = data.collect().sorted + val twoFoldedRdd = MLUtils.kFold(data, 2, 1) + assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted) + assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted) + for (folds <- 2 to 10) { + for (seed <- 1 to 5) { + val foldedRdds = MLUtils.kFold(data, folds, seed) + assert(foldedRdds.size === folds) + foldedRdds.map { case (training, validation) => + val result = validation.union(training).collect().sorted + val validationSize = validation.collect().size.toFloat + assert(validationSize > 0, "empty validation data") + val p = 1 / folds.toFloat + // Within 3 standard deviations of the mean + val range = 3 * math.sqrt(100 * p * (1 - p)) + val expected = 100 * p + val lowerBound = expected - range + val upperBound = expected + range + assert(validationSize > lowerBound, + s"Validation data ($validationSize) smaller than expected ($lowerBound)" ) + assert(validationSize < upperBound, + s"Validation data ($validationSize) larger than expected ($upperBound)" ) + assert(training.collect().size > 0, "empty training data") + assert(result === collectedData, + "Each training+validation set combined should contain all of the data.") + } + // K fold cross validation should only have each element in the validation set exactly once + assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + data.collect().sorted) + } + } + } + }