Skip to content
Snippets Groups Projects
Commit 2713bc65 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7528] [MLLIB] make RankingMetrics Java-friendly

`RankingMetrics` contains a ClassTag, which is hard to create in Java. This PR adds a factory method `of` for Java users. coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #6098 from mengxr/SPARK-7528 and squashes the following commits:

e5d57ae [Xiangrui Meng] make RankingMetrics Java-friendly
parent 00e7b09a
No related branches found
No related tags found
No related merge requests found
......@@ -17,11 +17,14 @@
package org.apache.spark.mllib.evaluation
import java.{lang => jl}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.rdd.RDD
/**
......@@ -71,7 +74,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
}.mean
}.mean()
}
/**
......@@ -100,7 +103,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
}.mean
}.mean()
}
/**
......@@ -146,7 +149,23 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
}.mean
}.mean()
}
}
@Experimental
object RankingMetrics {
/**
* Creates a [[RankingMetrics]] instance (for Java users).
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
*/
def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
implicit val tag = JavaSparkContext.fakeClassTag[E]
val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
(predictions.asScala.toArray, labels.asScala.toArray)
}
new RankingMetrics(rdd)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.evaluation;
import java.io.Serializable;
import java.util.ArrayList;
import scala.Tuple2;
import scala.Tuple2$;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
public class JavaRankingMetricsSuite implements Serializable {
private transient JavaSparkContext sc;
private transient JavaRDD<Tuple2<ArrayList<Integer>, ArrayList<Integer>>> predictionAndLabels;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
predictionAndLabels = sc.parallelize(Lists.newArrayList(
Tuple2$.MODULE$.apply(
Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply(
Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)),
Tuple2$.MODULE$.apply(
Lists.newArrayList(1, 2, 3, 4, 5), Lists.<Integer>newArrayList())), 2);
}
@After
public void tearDown() {
sc.stop();
sc = null;
}
@Test
public void rankingMetrics() {
@SuppressWarnings("unchecked")
RankingMetrics<?> metrics = RankingMetrics.of(predictionAndLabels);
Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment