From 2713bc65af1e0e81edd5fad0338e34fd127391f9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng <meng@databricks.com> Date: Tue, 12 May 2015 16:53:47 -0700 Subject: [PATCH] [SPARK-7528] [MLLIB] make RankingMetrics Java-friendly `RankingMetrics` contains a ClassTag, which is hard to create in Java. This PR adds a factory method `of` for Java users. coderxiang Author: Xiangrui Meng <meng@databricks.com> Closes #6098 from mengxr/SPARK-7528 and squashes the following commits: e5d57ae [Xiangrui Meng] make RankingMetrics Java-friendly --- .../mllib/evaluation/RankingMetrics.scala | 27 ++++++-- .../evaluation/JavaRankingMetricsSuite.java | 64 +++++++++++++++++++ 2 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 93a7353e2c..b9b54b93c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -17,11 +17,14 @@ package org.apache.spark.mllib.evaluation +import java.{lang => jl} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD /** @@ -71,7 +74,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } /** @@ -100,7 +103,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } /** @@ -146,7 +149,23 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] logWarning("Empty ground truth set, check input data") 0.0 } - }.mean + }.mean() } } + +@Experimental +object RankingMetrics { + + /** + * Creates a [[RankingMetrics]] instance (for Java users). + * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + */ + def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { + implicit val tag = JavaSparkContext.fakeClassTag[E] + val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => + (predictions.asScala.toArray, labels.asScala.toArray) + } + new RankingMetrics(rdd) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java new file mode 100644 index 0000000000..effc8a1a6d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation; + +import java.io.Serializable; +import java.util.ArrayList; + +import scala.Tuple2; +import scala.Tuple2$; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaRankingMetricsSuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaRDD<Tuple2<ArrayList<Integer>, ArrayList<Integer>>> predictionAndLabels; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); + predictionAndLabels = sc.parallelize(Lists.newArrayList( + Tuple2$.MODULE$.apply( + Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)), + Tuple2$.MODULE$.apply( + Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)), + Tuple2$.MODULE$.apply( + Lists.newArrayList(1, 2, 3, 4, 5), Lists.<Integer>newArrayList())), 2); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void rankingMetrics() { + @SuppressWarnings("unchecked") + RankingMetrics<?> metrics = RankingMetrics.of(predictionAndLabels); + Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5); + Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5); + } +} -- GitLab