diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index 78172843be56e7fac61505cbe53018b8d4210abc..19a047ded257cef0b1e80cbf94fcb7f70950dc1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -37,15 +37,24 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
-   * trigger a Spark job if the parent RDD has more than one partitions and the window size is
-   * greater than 1.
+   * trigger a Spark job if the parent RDD has more than one partition and either the window
+   * size or the step is greater than 1.
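+   *
+   * For example, `sc.parallelize(1 to 5, 2).sliding(3, 2).collect()` (via the implicit
+   * conversion in [[org.apache.spark.mllib.rdd.RDDFunctions]]) returns
+   * `Array(Array(1, 2, 3), Array(3, 4, 5))`; the partial window starting at index 4 is dropped.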
    */
-  def sliding(windowSize: Int): RDD[Array[T]] = {
+  def sliding(windowSize: Int, step: Int): RDD[Array[T]] = {
     require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
-    if (windowSize == 1) {
+    if (windowSize == 1 && step == 1) {
       self.map(Array(_))
     } else {
-      new SlidingRDD[T](self, windowSize)
+      new SlidingRDD[T](self, windowSize, step)
     }
   }
 
+  /**
+   * [[sliding(Int, Int)*]] with step = 1.
+   */
+  def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1)
+
   /**
    * Reduces the elements of this RDD in a multi-level tree pattern.
    *
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
index 1facf83d806d072fbd043d10a8c0c9aabf39eb0d..ead8db6344998b85fbe850f26e9626bdb533d96c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -24,13 +24,13 @@ import org.apache.spark.{TaskContext, Partition}
 import org.apache.spark.rdd.RDD
 
 private[mllib]
-class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T])
+class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int)
   extends Partition with Serializable {
   override val index: Int = idx
 }
 
 /**
- * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
+ * Represents an RDD from grouping items of its parent RDD in fixed-size blocks by passing a sliding
  * window over them. The ordering is first based on the partition index and then the ordering of
  * items within each partition. This is similar to sliding in Scala collections, except that it
  * becomes an empty RDD if the window size is greater than the total number of items. It needs to
@@ -40,19 +40,28 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]
  *
  * @param parent the parent RDD
- * @param windowSize the window size, must be greater than 1
+ * @param windowSize the window size, must be positive
+ * @param step the step size between the start indices of consecutive windows, must be positive
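+ *             (windowSize and step cannot both be 1); windows start at global element indices
+ *             that are multiples of step, and a trailing partial window is dropped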
  *
- * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]]
+ * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]]
+ * @see [[scala.collection.IterableLike.sliding(Int, Int)*]]
  */
 private[mllib]
-class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
+class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int)
   extends RDD[Array[T]](parent) {
 
-  require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
+  require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1),
+    "Window size and step must be greater than 0, " +
+      s"and they cannot both be 1, but got windowSize = $windowSize and step = $step.")
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
     val part = split.asInstanceOf[SlidingRDDPartition[T]]
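+    // Skip part.offset elements so the first window starts at a global index that is a
+    // multiple of step; withPartial(false) below drops a trailing partial window.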
     (firstParent[T].iterator(part.prev, context) ++ part.tail)
-      .sliding(windowSize)
+      .drop(part.offset)
+      .sliding(windowSize, step)
       .withPartial(false)
       .map(_.toArray)
   }
@@ -62,40 +71,45 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int
 
   override def getPartitions: Array[Partition] = {
     val parentPartitions = parent.partitions
-    val n = parentPartitions.size
+    val n = parentPartitions.length
     if (n == 0) {
       Array.empty
     } else if (n == 1) {
-      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty))
+      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0))
     } else {
-      val n1 = n - 1
       val w1 = windowSize - 1
-      // Get the first w1 items of each partition, starting from the second partition.
-      val nextHeads =
-        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
-      val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
+      // Collect the size of each parent partition and its first w1 elements in a single job.
+      val (sizes, heads) = parent.mapPartitions { iter =>
+        val w1Array = iter.take(w1).toArray
+        Iterator.single((w1Array.length + iter.length, w1Array))
+      }.collect().unzip
+      val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]]
       var i = 0
+      var cumSize = 0
       var partitionIndex = 0
-      while (i < n1) {
-        var j = i
-        val tail = mutable.ListBuffer[T]()
-        // Keep appending to the current tail until appended a head of size w1.
-        while (j < n1 && nextHeads(j).size < w1) {
-          tail ++= nextHeads(j)
-          j += 1
+      while (i < n) {
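+        // cumSize counts the elements in partitions before partition i; skip offset elements
+        // so the first window in this partition starts at a global multiple of step.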
+        val mod = cumSize % step
+        val offset = if (mod == 0) 0 else step - mod
+        val size = sizes(i)
+        if (offset < size) {
+          val tail = mutable.ListBuffer.empty[T]
+          // Keep appending to the current tail until it has w1 elements.
+          var j = i + 1
+          while (j < n && tail.length < w1) {
+            tail ++= heads(j).take(w1 - tail.length)
+            j += 1
+          }
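+          // Keep this partition only if it can emit at least one full window after the skip.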
+          if (size + tail.length >= offset + windowSize) {
+            partitions +=
+              new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset)
+            partitionIndex += 1
+          }
         }
-        if (j < n1) {
-          tail ++= nextHeads(j)
-          j += 1
-        }
-        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail)
-        partitionIndex += 1
-        // Skip appended heads.
-        i = j
-      }
-      // If the head of last partition has size w1, we also need to add this partition.
-      if (nextHeads.last.size == w1) {
-        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty)
+        cumSize += size
+        i += 1
       }
       partitions.toArray
     }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index bc64172614830ea3abe853e85daf86f1bd8d6e21..ac93733bab5f56ec3901a058e2437aac2f7e415a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -28,9 +28,14 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
     for (numPartitions <- 1 to 8) {
       val rdd = sc.parallelize(data, numPartitions)
       for (windowSize <- 1 to 6) {
-        val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList
-        val expected = data.sliding(windowSize).map(_.toList).toList
-        assert(sliding === expected)
+        for (step <- 1 to 3) {
+          val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList
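+          // Scala's sliding keeps a trailing partial group when step > 1, but the RDD
+          // version drops partial windows, so keep only full windows in the expected value.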
+          val expected = data.sliding(windowSize, step)
+            .map(_.toList).toList.filter(l => l.size == windowSize)
+          assert(sliding === expected)
+        }
       }
       assert(rdd.sliding(7).collect().isEmpty,
         "Should return an empty RDD if the window size is greater than the number of items.")
@@ -40,7 +45,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("sliding with empty partitions") {
     val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
     val rdd = sc.parallelize(data, data.length).flatMap(s => s)
-    assert(rdd.partitions.size === data.length)
+    assert(rdd.partitions.length === data.length)
     val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq)
     val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
     assert(sliding === expected)