Skip to content
Snippets Groups Projects
Commit d1cf3201 authored by Sean Owen, committed by Nick Pentreath
Browse files

[SPARK-14886][MLLIB] RankingMetrics.ndcgAt throw java.lang.ArrayIndexOutOfBoundsException

## What changes were proposed in this pull request?

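Handle the case where the number of predictions is smaller than the size of the label set or than k in the nDCG computation, so that `ndcgAt` no longer indexes past the end of the prediction array.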

## How was this patch tested?

New unit test; existing tests

Author: Sean Owen <sowen@cloudera.com>

Closes #12756 from srowen/SPARK-14886.
parent 24d07e45
No related branches found
No related tags found
No related merge requests found
......@@ -140,7 +140,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
var i = 0
while (i < n) {
val gain = 1.0 / math.log(i + 2)
if (labSet.contains(pred(i))) {
if (i < pred.length && labSet.contains(pred(i))) {
dcg += gain
}
if (i < labSetSize) {
......
......@@ -22,14 +22,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Ranking metrics: map, ndcg") {
test("Ranking metrics: MAP, NDCG") {
val predictionAndLabels = sc.parallelize(
Seq(
(Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
(Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
(Array[Int](1, 2, 3, 4, 5), Array[Int]())
(Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
(Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),
(Array(1, 2, 3, 4, 5), Array[Int]())
), 2)
val eps: Double = 1E-5
val eps = 1.0E-5
val metrics = new RankingMetrics(predictionAndLabels)
val map = metrics.meanAveragePrecision
......@@ -48,6 +49,21 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
}
// Regression test for SPARK-14886: ranking metrics must be computable when a
// query has fewer predictions than labels, including no predictions at all.
// NOTE(review): the test name mentions MAP, but only precisionAt and ndcgAt are
// asserted below; meanAveragePrecision is not checked in this test.
test("MAP, NDCG with few predictions (SPARK-14886)") {
val predictionAndLabels = sc.parallelize(
Seq(
// 3 predictions against 5 labels; of the predictions, 1 and 2 are in the label array.
(Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
// Empty prediction array against 3 labels.
(Array[Int](), Array(1, 2, 3))
// The trailing 2 is the second argument to sc.parallelize
// (presumably the number of partitions -- confirm against the SparkContext API).
), 2)
// Absolute tolerance shared by every assertion in this test.
val eps = 1.0E-5
val metrics = new RankingMetrics(predictionAndLabels)
assert(metrics.precisionAt(1) ~== 0.5 absTol eps)
assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
// With k = 1 and k = 2 the second query has no prediction at any index < k,
// which is the situation the `i < pred.length` guard in ndcgAt protects against.
assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment