From d27daa54bd341b29737a6352d9a1055151248ae7 Mon Sep 17 00:00:00 2001 From: Timothy Hunter <timhunter@databricks.com> Date: Thu, 23 Mar 2017 18:42:13 -0700 Subject: [PATCH] [SPARK-19636][ML] Feature parity for correlation statistics in MLlib ## What changes were proposed in this pull request? This patch adds DataFrame-based support for the correlation statistics found in `org.apache.spark.mllib.stat.Statistics`, following the design doc discussed in the JIRA ticket. The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage. ## How was this patch tested? ``` build/sbt "testOnly org.apache.spark.ml.stat.CorrelationSuite" ``` Author: Timothy Hunter <timhunter@databricks.com> Closes #17108 from thunterdb/19636. --- .../apache/spark/ml/util/TestingUtils.scala | 8 ++ .../apache/spark/ml/stat/Correlation.scala | 86 +++++++++++++++++++ .../spark/ml/stat/CorrelationSuite.scala | 77 +++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 2327917e2c..30edd00fb5 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -32,6 +32,10 @@ object TestingUtils { * the relative tolerance is meaningless, so the exception will be raised to warn users.
*/ private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } val absX = math.abs(x) val absY = math.abs(y) val diff = math.abs(x - y) @@ -49,6 +53,10 @@ object TestingUtils { * Private helper function for comparing two values using absolute tolerance. */ private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } math.abs(x - y) < eps } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala new file mode 100644 index 0000000000..a7243ccbf2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * API for correlation functions in MLlib, compatible with DataFrames and Datasets. + * + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] + * to spark.ml's Vector types. + */ +@Since("2.2.0") +@Experimental +object Correlation { + + /** + * :: Experimental :: + * Compute the correlation matrix for the input Dataset of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman`. + * + * @param dataset A dataset or a dataframe + * @param column The name of the column of vectors for which the correlation coefficient needs + * to be computed. This must be a column of the dataset, and it must contain + * Vector objects. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return A dataframe that contains the correlation matrix of the column of vectors. This + * dataframe contains a single row and a single column of name + * '$METHODNAME($COLUMN)'. + * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if + * the content of this column is not of type Vector. + * + * Here is how to access the correlation coefficient: + * {{{ + * val data: Dataset[Vector] = ... + * val Row(coeff: Matrix) = Correlation.corr(data, "value").head + * // coeff now contains the Pearson correlation matrix. 
+ * }}} + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` + * to avoid recomputing the common lineage. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { + val rdd = dataset.select(column).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } + val oldM = OldStatistics.corr(rdd, method) + val name = s"$method($column)" + val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false))) + dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema) + } + + /** + * Compute the Pearson correlation matrix for the input Dataset of Vectors. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String): DataFrame = { + corr(dataset, column, "pearson") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala new file mode 100644 index 0000000000..7d935e651f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + + +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + private def extract(df: DataFrame): BDM[Double] = { + val Array(Row(mat: Matrix)) = df.collect() + mat.asBreeze.toDenseMatrix + } + + + test("corr(X) default, pearson") { + val defaultMat = Correlation.corr(X, "features") + val pearsonMat = Correlation.corr(X, "features", "pearson") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( // NaNs: off-diagonals of the all-zero 3rd feature + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN, 1.0000000))) + // scalastyle:on + + assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4) + assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4) + } + + 
test("corr(X) spearman") { + val spearmanMat = Correlation.corr(X, "features", "spearman") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( // NaNs: off-diagonals of the all-zero 3rd feature + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000))) + // scalastyle:on + assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4) + } + +} -- GitLab