diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 37a6b04f5200f617d262f3d385c18a0785c19082..4dc8ada00a3e86cc79ac47eee68f62d47f97284b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -69,7 +69,12 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
     }
   }
 
-  override def clone = new BernoulliSampler[T](lb, ub)
+  /**
+   *  Return a sampler with is the complement of the range specified of the current sampler.
+   */
+  def cloneComplement():  BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+
+  override def clone = new BernoulliSampler[T](lb, ub, complement)
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index 7576c9a51f313da9450f6c8a50f20d1681946ac1..e166787f17544d8b9ecf3539951313fae6aaf35f 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -41,21 +41,31 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
       assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
     }
   }
 
+  test("BernoulliSamplerWithRangeInverse") {
+    expecting {
+      for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+        random.nextDouble().andReturn(x)
+      }
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
+    }
+  }
+
   test("BernoulliSamplerWithRatio") {
     expecting {
       for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.35)(random)
       assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
     }
@@ -67,8 +77,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
         random.nextDouble().andReturn(x)
       }
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
       assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
     }
@@ -78,8 +87,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     expecting {
       random.setSeed(10L)
     }
-    whenExecuting(random)
-    {
+    whenExecuting(random) {
       val sampler = new BernoulliSampler[Int](0.2)(random)
       sampler.setSeed(10L)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 901c3180eac4cf0b4ca16afbc21a8bdb53f1e13a..2f3ac1039751515bba07d48ea15994f51515c828 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -17,11 +17,16 @@
 
 package org.apache.spark.mllib.util
 
+import scala.reflect.ClassTag
+
 import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.PartitionwiseSampledRDD
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.random.BernoulliSampler
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.Vectors
 
@@ -157,6 +162,22 @@ object MLUtils {
     dataStr.saveAsTextFile(dir)
   }
 
+  /**
+   * Return a k element array of pairs of RDDs with the first element of each pair
+   * containing the training data, a complement of the validation data and the second
+   * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds.
+   */
+  def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
+    val numFoldsF = numFolds.toFloat
+    (1 to numFolds).map { fold =>
+      val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+        complement = false)
+      val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
+      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
+      (training, validation)
+    }.toArray
+  }
+
   /**
    * Returns the squared Euclidean distance between two vectors. The following formula will be used
    * if it does not introduce too much numerical error:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 812a8434784beed8d5ae9c492099065aa0e13113..674378a34ce34e2982495d8fbd117fbe9e2ee0c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.mllib.util
 
 import java.io.File
 
+import scala.math
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
@@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
       case t: Throwable =>
     }
   }
+
+  test("kFold") {
+    val data = sc.parallelize(1 to 100, 2)
+    val collectedData = data.collect().sorted
+    val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
+    assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
+    assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
+    for (folds <- 2 to 10) {
+      for (seed <- 1 to 5) {
+        val foldedRdds = MLUtils.kFold(data, folds, seed)
+        assert(foldedRdds.size === folds)
+        foldedRdds.map { case (training, validation) =>
+          val result = validation.union(training).collect().sorted
+          val validationSize = validation.collect().size.toFloat
+          assert(validationSize > 0, "empty validation data")
+          val p = 1 / folds.toFloat
+          // Within 3 standard deviations of the mean
+          val range = 3 * math.sqrt(100 * p * (1 - p))
+          val expected = 100 * p
+          val lowerBound = expected - range
+          val upperBound = expected + range
+          assert(validationSize > lowerBound,
+            s"Validation data ($validationSize) smaller than expected ($lowerBound)" )
+          assert(validationSize < upperBound,
+            s"Validation data ($validationSize) larger than expected ($upperBound)" )
+          assert(training.collect().size > 0, "empty training data")
+          assert(result ===  collectedData,
+            "Each training+validation set combined should contain all of the data.")
+        }
+        // K fold cross validation should only have each element in the validation set exactly once
+        assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted ===
+          data.collect().sorted)
+      }
+    }
+  }
+
 }