From 814a9cd7fabebf2a06f7e2e5d46b6a2b28b917c2 Mon Sep 17 00:00:00 2001
From: coderxiang <shuoxiangpub@gmail.com>
Date: Tue, 21 Oct 2014 15:45:47 -0700
Subject: [PATCH] SPARK-3568 [mllib] add ranking metrics

Add common metrics for ranking algorithms (http://www-nlp.stanford.edu/IR-book/), including:
 - Mean Average Precision
 - Precisionn: top-n precision
 - Discounted cumulative gain (DCG) and NDCG

The following methods and the corresponding tests are implemented:

```
class RankingMetrics[T](predictionAndLabels: RDD[(Array[T], Array[T])]) {
  /* Returns the precsionk for each query */
  lazy val precAtK: RDD[Array[Double]]

  /**
   * param k the position to compute the truncated precision
   * return the average precision at the first k ranking positions
   */
  def precision(k: Int): Double

  /* Returns the average precision for each query */
  lazy val avePrec: RDD[Double]

  /*Returns the mean average precision (MAP) of all the queries*/
  lazy val meanAvePrec: Double

  /*Returns the normalized discounted cumulative gain for each query */
  lazy val ndcgAtK: RDD[Array[Double]]

  /**
   * param k the position to compute the truncated ndcg
   * return the average ndcg at the first k ranking positions
   */
  def ndcg(k: Int): Double
}
```

Author: coderxiang <shuoxiangpub@gmail.com>

Closes #2667 from coderxiang/rankingmetrics and squashes the following commits:

d881097 [coderxiang] update doc
14d9cd9 [coderxiang] remove unexpected files
d7fb93f [coderxiang] style change and remove ignored files
f113ee1 [coderxiang] modify doc for displaying superscript and subscript
f626896 [coderxiang] improve doc and remove unnecessary computation while labSet is empty
be6645e [coderxiang] set the precision of empty labset to 0.0
d64c120 [coderxiang] add logWarning for empty ground truth set
dfae292 [coderxiang] handle empty labSet for map. add test
62047c4 [coderxiang] style change and add documentation
f66612d [coderxiang] add additional test of precisionAt
b794cb2 [coderxiang] move private members precAtK, ndcgAtK into public methods. style change
77c9e5d [coderxiang] set precAtK and ndcgAtK as private member. Improve documentation
5f87bce [coderxiang] add API to calculate precision and ndcg at each ranking position
b7851cc [coderxiang] Use generic type to represent IDs
e443fee [coderxiang] change style and use alternative builtin methods
3a5a6ff [coderxiang] add ranking metrics
---
 .../mllib/evaluation/RankingMetrics.scala     | 152 ++++++++++++++++++
 .../evaluation/RankingMetricsSuite.scala      |  54 +++++++
 2 files changed, 206 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000..93a7353e2c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+  extends Logging with Serializable {
+
+  /**
+   * Compute the average precision of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+   * ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as precision together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated precision, must be positive
+   * @return the average precision at the first k ranking positions
+   */
+  def precisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        val n = math.min(pred.length, k)
+        var i = 0
+        var cnt = 0
+        while (i < n) {
+          if (labSet.contains(pred(i))) {
+            cnt += 1
+          }
+          i += 1
+        }
+        cnt.toDouble / k
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+  /**
+   * Returns the mean average precision (MAP) of all the queries.
+   * If a query has an empty ground truth set, the average precision will be zero and a log
+   * warining is generated.
+   */
+  lazy val meanAveragePrecision: Double = {
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        var i = 0
+        var cnt = 0
+        var precSum = 0.0
+        val n = pred.length
+        while (i < n) {
+          if (labSet.contains(pred(i))) {
+            cnt += 1
+            precSum += cnt.toDouble / (i + 1)
+          }
+          i += 1
+        }
+        precSum / labSet.size
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+  /**
+   * Compute the average NDCG value of all the queries, truncated at ranking position k.
+   * The discounted cumulative gain at position k is computed as:
+   *    sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
+   * implementation, the relevance value is binary.
+
+   * If a query has an empty ground truth set, zero will be used as ndcg together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated ndcg, must be positive
+   * @return the average ndcg at the first k ranking positions
+   */
+  def ndcgAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      val labSet = lab.toSet
+
+      if (labSet.nonEmpty) {
+        val labSetSize = labSet.size
+        val n = math.min(math.max(pred.length, labSetSize), k)
+        var maxDcg = 0.0
+        var dcg = 0.0
+        var i = 0
+        while (i < n) {
+          val gain = 1.0 / math.log(i + 2)
+          if (labSet.contains(pred(i))) {
+            dcg += gain
+          }
+          if (i < labSetSize) {
+            maxDcg += gain
+          }
+          i += 1
+        }
+        dcg / maxDcg
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }.mean
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000..a2d4bb4148
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+  test("Ranking metrics: map, ndcg") {
+    val predictionAndLabels = sc.parallelize(
+      Seq(
+        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
+        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+      ), 2)
+    val eps: Double = 1E-5
+
+    val metrics = new RankingMetrics(predictionAndLabels)
+    val map = metrics.meanAveragePrecision
+
+    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
+    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
+    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
+    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
+
+    assert(map ~== 0.355026 absTol eps)
+
+    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
+    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
+    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
+    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+  }
+}
-- 
GitLab