From f5ace8da34c58d1005c7c377cfe3df21102c1dd6 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Fri, 11 Apr 2014 12:06:13 -0700
Subject: [PATCH] [SPARK-1225, 1241] [MLLIB] Add AreaUnderCurve and
 BinaryClassificationMetrics

This PR implements a generic version of `AreaUnderCurve` using the `RDD.sliding` implementation from https://github.com/apache/spark/pull/136 . It also contains refactoring of https://github.com/apache/spark/pull/160 for binary classification evaluation.

Author: Xiangrui Meng <meng@databricks.com>

Closes #364 from mengxr/auc and squashes the following commits:

a05941d [Xiangrui Meng] replace TP/FP/TN/FN by their full names
3f42e98 [Xiangrui Meng] add (0, 0), (1, 1) to roc, and (0, 1) to pr
fb4b6d2 [Xiangrui Meng] rename Evaluator to Metrics and add more metrics
b1b7dab [Xiangrui Meng] fix code styles
9dc3518 [Xiangrui Meng] add tests for BinaryClassificationEvaluator
ca31da5 [Xiangrui Meng] remove PredictionAndResponse
3d71525 [Xiangrui Meng] move binary evalution classes to evaluation.binary
8f78958 [Xiangrui Meng] add PredictionAndResponse
dda82d5 [Xiangrui Meng] add confusion matrix
aa7e278 [Xiangrui Meng] add initial version of binary classification evaluator
221ebce [Xiangrui Meng] add a new test to sliding
a920865 [Xiangrui Meng] Merge branch 'sliding' into auc
a9b250a [Xiangrui Meng] move sliding to mllib
cab9a52 [Xiangrui Meng] use last for the last element
db6cb30 [Xiangrui Meng] remove unnecessary toSeq
9916202 [Xiangrui Meng] change RDD.sliding return type to RDD[Seq[T]]
284d991 [Xiangrui Meng] change SlidedRDD to SlidingRDD
c1c6c22 [Xiangrui Meng] add AreaUnderCurve
65461b2 [Xiangrui Meng] Merge branch 'sliding' into auc
5ee6001 [Xiangrui Meng] add TODO
d2a600d [Xiangrui Meng] add sliding to rdd
---
 .../mllib/evaluation/AreaUnderCurve.scala     |  62 ++++++
 .../BinaryClassificationMetricComputers.scala |  57 +++++
 .../binary/BinaryClassificationMetrics.scala  | 204 ++++++++++++++++++
 .../binary/BinaryConfusionMatrix.scala        |  41 ++++
 .../apache/spark/mllib/rdd/RDDFunctions.scala |  53 +++++
 .../apache/spark/mllib/rdd/SlidingRDD.scala   | 104 +++++++++
 .../evaluation/AreaUnderCurveSuite.scala      |  46 ++++
 .../BinaryClassificationMetricsSuite.scala    |  55 +++++
 .../spark/mllib/rdd/RDDFunctionsSuite.scala   |  49 +++++
 9 files changed, 671 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
new file mode 100644
index 0000000000..7858ec6024
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.rdd.RDDFunctions._
+
+/**
+ * Computes the area under the curve (AUC) using the trapezoidal rule.
+ */
+private[evaluation] object AreaUnderCurve {
+
+  /**
+   * Uses the trapezoidal rule to compute the area under the line connecting the two input points.
+   * @param points two 2D points stored in Seq
+   */
+  private def trapezoid(points: Seq[(Double, Double)]): Double = {
+    require(points.length == 2)
+    val x = points.head
+    val y = points.last
+    (y._1 - x._1) * (y._2 + x._2) / 2.0
+  }
+
+  /**
+   * Returns the area under the given curve.
+   *
+   * @param curve a RDD of ordered 2D points stored in pairs representing a curve
+   */
+  def of(curve: RDD[(Double, Double)]): Double = {
+    curve.sliding(2).aggregate(0.0)(
+      seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+      combOp = _ + _
+    )
+  }
+
+  /**
+   * Returns the area under the given curve.
+   *
+   * @param curve an iterator over ordered 2D points stored in pairs representing a curve
+   */
+  def of(curve: Iterable[(Double, Double)]): Double = {
+    curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
+      seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+      combop = _ + _
+    )
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
new file mode 100644
index 0000000000..562663ad36
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation.binary
+
+/**
+ * Trait for a binary classification evaluation metric computer.
+ */
+private[evaluation] trait BinaryClassificationMetricComputer extends Serializable {
+  def apply(c: BinaryConfusionMatrix): Double
+}
+
+/** Precision. */
+private[evaluation] object Precision extends BinaryClassificationMetricComputer {
+  override def apply(c: BinaryConfusionMatrix): Double =
+    c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
+}
+
+/** False positive rate. */
+private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
+  override def apply(c: BinaryConfusionMatrix): Double =
+    c.numFalsePositives.toDouble / c.numNegatives
+}
+
+/** Recall. */
+private[evaluation] object Recall extends BinaryClassificationMetricComputer {
+  override def apply(c: BinaryConfusionMatrix): Double =
+    c.numTruePositives.toDouble / c.numPositives
+}
+
+/**
+ * F-Measure.
+ * @param beta the beta constant in F-Measure
+ * @see http://en.wikipedia.org/wiki/F1_score
+ */
+private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer {
+  private val beta2 = beta * beta
+  override def apply(c: BinaryConfusionMatrix): Double = {
+    val precision = Precision(c)
+    val recall = Recall(c)
+    (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
new file mode 100644
index 0000000000..ed7b0fc943
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation.binary
+
+import org.apache.spark.rdd.{UnionRDD, RDD}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.evaluation.AreaUnderCurve
+import org.apache.spark.Logging
+
+/**
+ * Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
+ *
+ * @param count label counter for labels with scores greater than or equal to the current score
+ * @param totalCount label counter for all labels
+ */
+private case class BinaryConfusionMatrixImpl(
+    count: LabelCounter,
+    totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
+
+  /** number of true positives */
+  override def numTruePositives: Long = count.numPositives
+
+  /** number of false positives */
+  override def numFalsePositives: Long = count.numNegatives
+
+  /** number of false negatives */
+  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
+
+  /** number of true negatives */
+  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
+
+  /** number of positives */
+  override def numPositives: Long = totalCount.numPositives
+
+  /** number of negatives */
+  override def numNegatives: Long = totalCount.numNegatives
+}
+
+/**
+ * Evaluator for binary classification.
+ *
+ * @param scoreAndLabels an RDD of (score, label) pairs.
+ */
+class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
+  extends Serializable with Logging {
+
+  private lazy val (
+      cumulativeCounts: RDD[(Double, LabelCounter)],
+      confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
+    // Create a bin for each distinct score value, count positives and negatives within each bin,
+    // and then sort by score values in descending order.
+    val counts = scoreAndLabels.combineByKey(
+      createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
+      mergeValue = (c: LabelCounter, label: Double) => c += label,
+      mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
+    ).sortByKey(ascending = false)
+    val agg = counts.values.mapPartitions({ iter =>
+      val agg = new LabelCounter()
+      iter.foreach(agg += _)
+      Iterator(agg)
+    }, preservesPartitioning = true).collect()
+    val partitionwiseCumulativeCounts =
+      agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
+    val totalCount = partitionwiseCumulativeCounts.last
+    logInfo(s"Total counts: $totalCount")
+    val cumulativeCounts = counts.mapPartitionsWithIndex(
+      (index: Int, iter: Iterator[(Double, LabelCounter)]) => {
+        val cumCount = partitionwiseCumulativeCounts(index)
+        iter.map { case (score, c) =>
+          cumCount += c
+          (score, cumCount.clone())
+        }
+      }, preservesPartitioning = true)
+    cumulativeCounts.persist()
+    val confusions = cumulativeCounts.map { case (score, cumCount) =>
+      (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
+    }
+    (cumulativeCounts, confusions)
+  }
+
+  /** Unpersist intermediate RDDs used in the computation. */
+  def unpersist() {
+    cumulativeCounts.unpersist()
+  }
+
+  /** Returns thresholds in descending order. */
+  def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)
+
+  /**
+   * Returns the receiver operating characteristic (ROC) curve,
+   * which is an RDD of (false positive rate, true positive rate)
+   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+   */
+  def roc(): RDD[(Double, Double)] = {
+    val rocCurve = createCurve(FalsePositiveRate, Recall)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
+    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
+    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+  }
+
+  /**
+   * Computes the area under the receiver operating characteristic (ROC) curve.
+   */
+  def areaUnderROC(): Double = AreaUnderCurve.of(roc())
+
+  /**
+   * Returns the precision-recall curve, which is an RDD of (recall, precision),
+   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
+   * @see http://en.wikipedia.org/wiki/Precision_and_recall
+   */
+  def pr(): RDD[(Double, Double)] = {
+    val prCurve = createCurve(Recall, Precision)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
+    first.union(prCurve)
+  }
+
+  /**
+   * Computes the area under the precision-recall curve.
+   */
+  def areaUnderPR(): Double = AreaUnderCurve.of(pr())
+
+  /**
+   * Returns the (threshold, F-Measure) curve.
+   * @param beta the beta factor in F-Measure computation.
+   * @return an RDD of (threshold, F-Measure) pairs.
+   * @see http://en.wikipedia.org/wiki/F1_score
+   */
+  def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
+
+  /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
+  def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)
+
+  /** Returns the (threshold, precision) curve. */
+  def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)
+
+  /** Returns the (threshold, recall) curve. */
+  def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)
+
+  /** Creates a curve of (threshold, metric). */
+  private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
+    confusions.map { case (s, c) =>
+      (s, y(c))
+    }
+  }
+
+  /** Creates a curve of (metricX, metricY). */
+  private def createCurve(
+      x: BinaryClassificationMetricComputer,
+      y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
+    confusions.map { case (_, c) =>
+      (x(c), y(c))
+    }
+  }
+}
+
+/**
+ * A counter for positives and negatives.
+ *
+ * @param numPositives number of positive labels
+ * @param numNegatives number of negative labels
+ */
+private class LabelCounter(
+    var numPositives: Long = 0L,
+    var numNegatives: Long = 0L) extends Serializable {
+
+  /** Processes a label. */
+  def +=(label: Double): LabelCounter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label > 0.5) numPositives += 1L else numNegatives += 1L
+    this
+  }
+
+  /** Merges another counter. */
+  def +=(other: LabelCounter): LabelCounter = {
+    numPositives += other.numPositives
+    numNegatives += other.numNegatives
+    this
+  }
+
+  override def clone: LabelCounter = {
+    new LabelCounter(numPositives, numNegatives)
+  }
+
+  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
new file mode 100644
index 0000000000..75a75b2160
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation.binary
+
+/**
+ * Trait for a binary confusion matrix.
+ */
+private[evaluation] trait BinaryConfusionMatrix {
+  /** number of true positives */
+  def numTruePositives: Long
+
+  /** number of false positives */
+  def numFalsePositives: Long
+
+  /** number of false negatives */
+  def numFalseNegatives: Long
+
+  /** number of true negatives */
+  def numTrueNegatives: Long
+
+  /** number of positives */
+  def numPositives: Long = numTruePositives + numFalseNegatives
+
+  /** number of negatives */
+  def numNegatives: Long = numFalsePositives + numTrueNegatives
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
new file mode 100644
index 0000000000..873de871fd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.RDD
+
+/**
+ * Machine learning specific RDD functions.
+ */
+private[mllib]
+class RDDFunctions[T: ClassTag](self: RDD[T]) {
+
+  /**
+   * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
+   * window over them. The ordering is first based on the partition index and then the ordering of
+   * items within each partition. This is similar to sliding in Scala collections, except that it
+   * becomes an empty RDD if the window size is greater than the total number of items. It needs to
+   * trigger a Spark job if the parent RDD has more than one partitions and the window size is
+   * greater than 1.
+   */
+  def sliding(windowSize: Int): RDD[Seq[T]] = {
+    require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
+    if (windowSize == 1) {
+      self.map(Seq(_))
+    } else {
+      new SlidingRDD[T](self, windowSize)
+    }
+  }
+}
+
+private[mllib]
+object RDDFunctions {
+
+  /** Implicit conversion from an RDD to RDDFunctions. */
+  implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
new file mode 100644
index 0000000000..dd80782c0f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.{TaskContext, Partition}
+import org.apache.spark.rdd.RDD
+
+private[mllib]
+class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T])
+  extends Partition with Serializable {
+  override val index: Int = idx
+}
+
+/**
+ * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
+ * window over them. The ordering is first based on the partition index and then the ordering of
+ * items within each partition. This is similar to sliding in Scala collections, except that it
+ * becomes an empty RDD if the window size is greater than the total number of items. It needs to
+ * trigger a Spark job if the parent RDD has more than one partitions. To make this operation
+ * efficient, the number of items per partition should be larger than the window size and the
+ * window size should be small, e.g., 2.
+ *
+ * @param parent the parent RDD
+ * @param windowSize the window size, must be greater than 1
+ *
+ * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]]
+ */
+private[mllib]
+class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
+  extends RDD[Seq[T]](parent) {
+
+  require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
+    val part = split.asInstanceOf[SlidingRDDPartition[T]]
+    (firstParent[T].iterator(part.prev, context) ++ part.tail)
+      .sliding(windowSize)
+      .withPartial(false)
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] =
+    firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev)
+
+  override def getPartitions: Array[Partition] = {
+    val parentPartitions = parent.partitions
+    val n = parentPartitions.size
+    if (n == 0) {
+      Array.empty
+    } else if (n == 1) {
+      Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty))
+    } else {
+      val n1 = n - 1
+      val w1 = windowSize - 1
+      // Get the first w1 items of each partition, starting from the second partition.
+      val nextHeads =
+        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
+      val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
+      var i = 0
+      var partitionIndex = 0
+      while (i < n1) {
+        var j = i
+        val tail = mutable.ListBuffer[T]()
+        // Keep appending to the current tail until appended a head of size w1.
+        while (j < n1 && nextHeads(j).size < w1) {
+          tail ++= nextHeads(j)
+          j += 1
+        }
+        if (j < n1) {
+          tail ++= nextHeads(j)
+          j += 1
+        }
+        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail)
+        partitionIndex += 1
+        // Skip appended heads.
+        i = j
+      }
+      // If the head of last partition has size w1, we also need to add this partition.
+      if (nextHeads.last.size == w1) {
+        partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty)
+      }
+      partitions.toArray
+    }
+  }
+
+  // TODO: Override methods such as aggregate, which only requires one Spark job.
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
new file mode 100644
index 0000000000..1c9844f289
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
+  test("auc computation") {
+    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
+    val auc = 4.0
+    assert(AreaUnderCurve.of(curve) === auc)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) == auc)
+  }
+
+  test("auc of an empty curve") {
+    val curve = Seq.empty[(Double, Double)]
+    assert(AreaUnderCurve.of(curve) === 0.0)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) === 0.0)
+  }
+
+  test("auc of a curve with a single point") {
+    val curve = Seq((1.0, 1.0))
+    assert(AreaUnderCurve.of(curve) === 0.0)
+    val rddCurve = sc.parallelize(curve, 2)
+    assert(AreaUnderCurve.of(rddCurve) === 0.0)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
new file mode 100644
index 0000000000..173fdaefab
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation.binary
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.evaluation.AreaUnderCurve
+
+class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
+  test("binary evaluation metrics") {
+    val scoreAndLabels = sc.parallelize(
+      Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
+    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+    val threshold = Seq(0.8, 0.6, 0.4, 0.1)
+    val numTruePositives = Seq(1, 3, 3, 4)
+    val numFalsePositives = Seq(0, 1, 2, 3)
+    val numPositives = 4
+    val numNegatives = 3
+    val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
+      t.toDouble / (t + f)
+    }
+    val recall = numTruePositives.map(t => t.toDouble / numPositives)
+    val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
+    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
+    val pr = recall.zip(precision)
+    val prCurve = Seq((0.0, 1.0)) ++ pr
+    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
+    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
+    assert(metrics.thresholds().collect().toSeq === threshold)
+    assert(metrics.roc().collect().toSeq === rocCurve)
+    assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
+    assert(metrics.pr().collect().toSeq === prCurve)
+    assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
+    assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
+    assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
+    assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
+    assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
new file mode 100644
index 0000000000..3f3b10dfff
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.rdd.RDDFunctions._
+
+class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
+
+  test("sliding") {
+    val data = 0 until 6
+    for (numPartitions <- 1 to 8) {
+      val rdd = sc.parallelize(data, numPartitions)
+      for (windowSize <- 1 to 6) {
+        val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList
+        val expected = data.sliding(windowSize).map(_.toList).toList
+        assert(sliding === expected)
+      }
+      assert(rdd.sliding(7).collect().isEmpty,
+        "Should return an empty RDD if the window size is greater than the number of items.")
+    }
+  }
+
+  test("sliding with empty partitions") {
+    val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
+    val rdd = sc.parallelize(data, data.length).flatMap(s => s)
+    assert(rdd.partitions.size === data.length)
+    val sliding = rdd.sliding(3)
+    val expected = data.flatMap(x => x).sliding(3).toList
+    assert(sliding.collect().toList === expected)
+  }
+}
-- 
GitLab