From d27daa54bd341b29737a6352d9a1055151248ae7 Mon Sep 17 00:00:00 2001
From: Timothy Hunter <timhunter@databricks.com>
Date: Thu, 23 Mar 2017 18:42:13 -0700
Subject: [PATCH] [SPARK-19636][ML] Feature parity for correlation statistics
 in MLlib

## What changes were proposed in this pull request?

This patch adds the Dataframes-based support for the correlation statistics found in the `org.apache.spark.mllib.stat.correlation.Statistics`, following the design doc discussed in the JIRA ticket.

The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage.

## How was this patch tested?

```
build/sbt "testOnly org.apache.spark.ml.stat.StatisticsSuite"
```

Author: Timothy Hunter <timhunter@databricks.com>

Closes #17108 from thunterdb/19636.
---
 .../apache/spark/ml/util/TestingUtils.scala   |  8 ++
 .../apache/spark/ml/stat/Correlation.scala    | 86 +++++++++++++++++++
 .../spark/ml/stat/CorrelationSuite.scala      | 77 +++++++++++++++++
 3 files changed, 171 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala

diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2c..30edd00fb5 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -32,6 +32,10 @@ object TestingUtils {
    * the relative tolerance is meaningless, so the exception will be raised to warn users.
    */
   private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+    // Special case for NaNs
+    if (x.isNaN && y.isNaN) {
+      return true
+    }
     val absX = math.abs(x)
     val absY = math.abs(y)
     val diff = math.abs(x - y)
@@ -49,6 +53,10 @@ object TestingUtils {
    * Private helper function for comparing two values using absolute tolerance.
    */
   private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+    // Special case for NaNs
+    if (x.isNaN && y.isNaN) {
+      return true
+    }
     math.abs(x - y) < eps
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
new file mode 100644
index 0000000000..a7243ccbf2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
+import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * API for correlation functions in MLlib, compatible with Dataframes and Datasets.
+ *
+ * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
+ * to spark.ml's Vector types.
+ */
+@Since("2.2.0")
+@Experimental
+object Correlation {
+
+  /**
+   * :: Experimental ::
+   * Compute the correlation matrix for the input RDD of Vectors using the specified method.
+   * Methods currently supported: `pearson` (default), `spearman`.
+   *
+   * @param dataset A dataset or a dataframe
+   * @param column The name of the column of vectors for which the correlation coefficient needs
+   *               to be computed. This must be a column of the dataset, and it must contain
+   *               Vector objects.
+   * @param method String specifying the method to use for computing correlation.
+   *               Supported: `pearson` (default), `spearman`
+   * @return A dataframe that contains the correlation matrix of the column of vectors. This
+   *         dataframe contains a single row and a single column of name
+   *         '$METHODNAME($COLUMN)'.
+   * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if
+   *                                  the content of this column is not of type Vector.
+   *
+   *  Here is how to access the correlation coefficient:
+   *  {{{
+   *    val data: Dataset[Vector] = ...
+   *    val Row(coeff: Matrix) = Statistics.corr(data, "value").head
+   *    // coeff now contains the Pearson correlation matrix.
+   *  }}}
+   *
+   * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
+   * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
+   * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
+   * avoid recomputing the common lineage.
+   */
+  @Since("2.2.0")
+  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
+    val rdd = dataset.select(column).rdd.map {
+      case Row(v: Vector) => OldVectors.fromML(v)
+    }
+    val oldM = OldStatistics.corr(rdd, method)
+    val name = s"$method($column)"
+    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
+    dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
+  }
+
+  /**
+   * Compute the Pearson correlation matrix for the input Dataset of Vectors.
+   */
+  @Since("2.2.0")
+  def corr(dataset: Dataset[_], column: String): DataFrame = {
+    corr(dataset, column, "pearson")
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
new file mode 100644
index 0000000000..7d935e651f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+
+  val xData = Array(1.0, 0.0, -2.0)
+  val yData = Array(4.0, 5.0, 3.0)
+  val zeros = new Array[Double](3)
+  val data = Seq(
+    Vectors.dense(1.0, 0.0, 0.0, -2.0),
+    Vectors.dense(4.0, 5.0, 0.0, 3.0),
+    Vectors.dense(6.0, 7.0, 0.0, 8.0),
+    Vectors.dense(9.0, 0.0, 0.0, 1.0)
+  )
+
+  private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")
+
+  private def extract(df: DataFrame): BDM[Double] = {
+    val Array(Row(mat: Matrix)) = df.collect()
+    mat.asBreeze.toDenseMatrix
+  }
+
+
+  test("corr(X) default, pearson") {
+    val defaultMat = Correlation.corr(X, "features")
+    val pearsonMat = Correlation.corr(X, "features", "pearson")
+    // scalastyle:off
+    val expected = Matrices.fromBreeze(BDM(
+      (1.00000000, 0.05564149, Double.NaN, 0.4004714),
+      (0.05564149, 1.00000000, Double.NaN, 0.9135959),
+      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
+      (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
+    // scalastyle:on
+
+    assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4)
+    assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4)
+  }
+
+  test("corr(X) spearman") {
+    val spearmanMat = Correlation.corr(X, "features", "spearman")
+    // scalastyle:off
+    val expected = Matrices.fromBreeze(BDM(
+      (1.0000000,  0.1054093,  Double.NaN, 0.4000000),
+      (0.1054093,  1.0000000,  Double.NaN, 0.9486833),
+      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
+      (0.4000000,  0.9486833,  Double.NaN, 1.0000000)))
+    // scalastyle:on
+    assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4)
+  }
+
+}
-- 
GitLab