From 0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Thu, 5 Mar 2015 11:50:09 -0800
Subject: [PATCH] [SPARK-6090][MLLIB] add a basic BinaryClassificationMetrics
 to PySpark/MLlib

A simple wrapper around the Scala implementation. `DataFrame` is used for serialization/deserialization. Methods that return `RDD`s are not supported in this PR.

davies If we recognize Scala's `Product`s in Py4J, we can easily add wrappers for Scala methods that returns `RDD[(Double, Double)]`. Is it easy to register serializer for `Product` in PySpark?

Author: Xiangrui Meng <meng@databricks.com>

Closes #4863 from mengxr/SPARK-6090 and squashes the following commits:

009a3a3 [Xiangrui Meng] provide schema
dcddab5 [Xiangrui Meng] add a basic BinaryClassificationMetrics to PySpark/MLlib
---
 .../BinaryClassificationMetrics.scala         |  8 ++
 python/docs/pyspark.mllib.rst                 |  7 ++
 python/pyspark/mllib/evaluation.py            | 83 +++++++++++++++++++
 python/run-tests                              |  1 +
 4 files changed, 99 insertions(+)
 create mode 100644 python/pyspark/mllib/evaluation.py

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index ced042e2f9..c1d1a22481 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -22,6 +22,7 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.binary._
 import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.sql.DataFrame
 
 /**
  * :: Experimental ::
@@ -53,6 +54,13 @@ class BinaryClassificationMetrics(
    */
   def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
 
+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param scoreAndLabels a DataFrame with two double columns: score and label
+   */
+  private[mllib] def this(scoreAndLabels: DataFrame) =
+    this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
   /** Unpersist intermediate RDDs used in the computation. */
   def unpersist() {
     cumulativeCounts.unpersist()
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index b706c5e376..15101470af 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -16,6 +16,13 @@ pyspark.mllib.clustering module
     :members:
     :undoc-members:
 
+pyspark.mllib.evaluation module
+-------------------------------
+
+.. automodule:: pyspark.mllib.evaluation
+      :members:
+      :undoc-members:
+
 pyspark.mllib.feature module
 -------------------------------
 
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
new file mode 100644
index 0000000000..16cb49cc0c
--- /dev/null
+++ b/python/pyspark/mllib/evaluation.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.mllib.common import JavaModelWrapper
+from pyspark.sql import SQLContext
+from pyspark.sql.types import StructField, StructType, DoubleType
+
+
+class BinaryClassificationMetrics(JavaModelWrapper):
+    """
+    Evaluator for binary classification.
+
+    >>> scoreAndLabels = sc.parallelize([
+    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
+    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
+    >>> metrics.areaUnderROC()
+    0.70...
+    >>> metrics.areaUnderPR()
+    0.83...
+    >>> metrics.unpersist()
+    """
+
+    def __init__(self, scoreAndLabels):
+        """
+        :param scoreAndLabels: an RDD of (score, label) pairs
+        """
+        sc = scoreAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
+            StructField("score", DoubleType(), nullable=False),
+            StructField("label", DoubleType(), nullable=False)]))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+        java_model = java_class(df._jdf)
+        super(BinaryClassificationMetrics, self).__init__(java_model)
+
+    def areaUnderROC(self):
+        """
+        Computes the area under the receiver operating characteristic
+        (ROC) curve.
+        """
+        return self.call("areaUnderROC")
+
+    def areaUnderPR(self):
+        """
+        Computes the area under the precision-recall curve.
+        """
+        return self.call("areaUnderPR")
+
+    def unpersist(self):
+        """
+        Unpersists intermediate RDDs used in the computation.
+        """
+        self.call("unpersist")
+
+
+def _test():
+    import doctest
+    from pyspark import SparkContext
+    import pyspark.mllib.evaluation
+    globs = pyspark.mllib.evaluation.__dict__.copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/run-tests b/python/run-tests
index a2c2f37a54..b7630c356c 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -75,6 +75,7 @@ function run_mllib_tests() {
     echo "Run mllib tests ..."
     run_test "pyspark/mllib/classification.py"
     run_test "pyspark/mllib/clustering.py"
+    run_test "pyspark/mllib/evaluation.py"
     run_test "pyspark/mllib/feature.py"
     run_test "pyspark/mllib/linalg.py"
     run_test "pyspark/mllib/rand.py"
-- 
GitLab