Commit c3527a33 authored by Holden Karau, committed by Patrick Wendell

SPARK-1310: Start adding k-fold cross validation to MLLib [adds kFold to MLUtils & fixes bug in BernoulliSampler]

Author: Holden Karau <holden@pigscanfly.ca>

Closes #18 from holdenk/addkfoldcrossvalidation and squashes the following commits:

208db9b [Holden Karau] Fix a bad space
e84f2fc [Holden Karau] Fix the test, we should be looking at the second element instead
6ddbf05 [Holden Karau] swap training and validation order
7157ae9 [Holden Karau] CR feedback
90896c7 [Holden Karau] New line
150889c [Holden Karau] Fix up error messages in the MLUtilsSuite
2cb90b3 [Holden Karau] Fix the names in kFold
c702a96 [Holden Karau] Fix imports in MLUtils
e187e35 [Holden Karau] Move { up to same line as whenExecuting(random) in RandomSamplerSuite.scala
c5b723f [Holden Karau] clean up
7ebe4d5 [Holden Karau] CR feedback, remove unecessary learners (came back during merge mistake) and insert an empty line
bb5fa56 [Holden Karau] extra line sadness
163c5b1 [Holden Karau] code review feedback 1.to -> 1 to and folds -> numFolds
5a33f1d [Holden Karau] Code review follow up.
e8741a7 [Holden Karau] CR feedback
b78804e [Holden Karau] Remove cross validation [TODO in another pull request]
91eae64 [Holden Karau] Consolidate things in mlutils
264502a [Holden Karau] Add a test for the bug that was found with BernoulliSampler not copying the complement param
dd0b737 [Holden Karau] Wrap long lines (oops)
c0b7fa4 [Holden Karau] Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD
08f8e4d [Holden Karau] Fix BernoulliSampler to respect complement
a751ec6 [Holden Karau] Add k-fold cross validation to MLLib
parent 9edd8878
@@ -69,7 +69,12 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
     }
   }
 
-  override def clone = new BernoulliSampler[T](lb, ub)
+  /**
+   * Return a sampler which is the complement of the range specified by the current sampler.
+   */
+  def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+
+  override def clone = new BernoulliSampler[T](lb, ub, complement)
 }
 
 /**
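The clone fix matters because kFold (added below in MLUtils) relies on a sampler and its complement, seeded identically, splitting the same RDD into two disjoint pieces. The following is a minimal standalone sketch of that idea, not the Spark implementation; every name in it is illustrative only.

object ComplementSketch {
  // Keep an element when its uniform draw falls in [lb, ub); the complement keeps the rest.
  def keep(x: Double, lb: Double, ub: Double, complement: Boolean): Boolean = {
    val inRange = x >= lb && x < ub
    if (complement) !inRange else inRange
  }

  def main(args: Array[String]): Unit = {
    val rng = new scala.util.Random(42)
    val data = (1 to 10).toList
    // One uniform draw per element, shared by both "samplers" (same seed in Spark).
    val draws = data.map(i => i -> rng.nextDouble()).toMap
    val validation = data.filter(i => keep(draws(i), 0.0, 0.5, complement = false))
    val training = data.filter(i => keep(draws(i), 0.0, 0.5, complement = true))
    // Every element lands in exactly one of the two sets.
    assert((validation ++ training).sorted == data.sorted)
    println(s"validation = $validation, training = $training")
  }
}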
@@ -41,21 +41,31 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
       assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
     }
   }
 
+  test("BernoulliSamplerWithRangeInverse") {
+    expecting {
+      for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+        random.nextDouble().andReturn(x)
+      }
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
+    }
+  }
+
   test("BernoulliSamplerWithRatio") {
     expecting {
       for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.35)(random)
       assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
     }
@@ -67,8 +77,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
       assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
     }
@@ -78,8 +87,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     expecting {
       random.setSeed(10L)
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.2)(random)
       sampler.setSeed(10L)
     }
@@ -17,11 +17,16 @@
 
 package org.apache.spark.mllib.util
 
+import scala.reflect.ClassTag
+
 import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.PartitionwiseSampledRDD
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.random.BernoulliSampler
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.Vectors
 
@@ -157,6 +162,22 @@ object MLUtils {
     dataStr.saveAsTextFile(dir)
   }
 
+  /**
+   * Return a k element array of pairs of RDDs, where k = numFolds. The first element of each
+   * pair contains the training data (the complement of the validation data) and the second
+   * element contains the validation data, a unique 1/kth of the input.
+   */
+  def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
+    val numFoldsF = numFolds.toFloat
+    (1 to numFolds).map { fold =>
+      val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+        complement = false)
+      val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
+      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
+      (training, validation)
+    }.toArray
+  }
+
   /**
    * Returns the squared Euclidean distance between two vectors. The following formula will be used
   * if it does not introduce too much numerical error:
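Cross validation itself was pulled out of this change ("Remove cross validation [TODO in another pull request]" above), but a caller of the new MLUtils.kFold could look roughly like the sketch below. It is only an illustration: KFoldUsageSketch, crossValidate, and the user-supplied trainAndScore function are assumptions, not part of this commit.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

object KFoldUsageSketch {
  // Average a user-supplied score over the k (training, validation) splits from kFold.
  def crossValidate(
      points: RDD[LabeledPoint],
      numFolds: Int,
      seed: Int)(trainAndScore: (RDD[LabeledPoint], RDD[LabeledPoint]) => Double): Double = {
    val folds = MLUtils.kFold(points, numFolds, seed)
    // Each validation RDD is a disjoint ~1/numFolds slice of points; its training RDD is the rest.
    val scores = folds.map { case (training, validation) => trainAndScore(training, validation) }
    scores.sum / numFolds
  }
}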
@@ -19,6 +19,9 @@ package org.apache.spark.mllib.util
 
 import java.io.File
 
+import scala.math
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
@@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
       case t: Throwable =>
     }
   }
+
+  test("kFold") {
+    val data = sc.parallelize(1 to 100, 2)
+    val collectedData = data.collect().sorted
+    val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
+    assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
+    assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
+    for (folds <- 2 to 10) {
+      for (seed <- 1 to 5) {
+        val foldedRdds = MLUtils.kFold(data, folds, seed)
+        assert(foldedRdds.size === folds)
+        foldedRdds.map { case (training, validation) =>
+          val result = validation.union(training).collect().sorted
+          val validationSize = validation.collect().size.toFloat
+          assert(validationSize > 0, "empty validation data")
+          val p = 1 / folds.toFloat
+          // Within 3 standard deviations of the mean
+          val range = 3 * math.sqrt(100 * p * (1 - p))
+          val expected = 100 * p
+          val lowerBound = expected - range
+          val upperBound = expected + range
+          assert(validationSize > lowerBound,
+            s"Validation data ($validationSize) smaller than expected ($lowerBound)")
+          assert(validationSize < upperBound,
+            s"Validation data ($validationSize) larger than expected ($upperBound)")
+          assert(training.collect().size > 0, "empty training data")
+          assert(result === collectedData,
+            "Each training+validation set combined should contain all of the data.")
+        }
+        // K fold cross validation should only have each element in the validation set exactly once
+        assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted ===
+          data.collect().sorted)
+      }
+    }
+  }
+
 }
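To make the three-standard-deviation bound in the kFold test concrete: with folds = 2, p = 0.5, the expected validation size is 100 * 0.5 = 50 with standard deviation sqrt(100 * 0.5 * 0.5) = 5, so the test accepts between 35 and 65 elements; with folds = 10, p = 0.1, the expectation is 10 with standard deviation 3, so it accepts between 1 and 19.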