Skip to content
Snippets Groups Projects
Commit fdfac22d authored by Narine Kokhlikyan's avatar Narine Kokhlikyan Committed by Reynold Xin
Browse files

[SPARK-12509][SQL] Fixed error messages for DataFrame correlation and covariance

Currently, when we call corr or cov on dataframe with invalid input we see these error messages for both corr and cov:
   -  "Currently cov supports calculating the covariance between two columns"
   -  "Covariance calculation for columns with dataType "[DataType Name]" not supported."

I've fixed this issue by passing the function name as an argument. We could also do the input checks separately for each function. I avoided doing that because of code duplication.

Thanks!

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>

Closes #10458 from NarineK/sparksqlstatsmessages.
parent 34de24ab
No related branches found
No related tags found
No related merge requests found
...@@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging { ...@@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging {
/** Calculate the Pearson Correlation Coefficient for the given columns */ /** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols) val counts = collectStatisticalData(df, cols, "correlation")
counts.Ck / math.sqrt(counts.MkX * counts.MkY) counts.Ck / math.sqrt(counts.MkX * counts.MkY)
} }
...@@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging { ...@@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging {
def cov: Double = Ck / (count - 1) def cov: Double = Ck / (count - 1)
} }
private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { private def collectStatisticalData(df: DataFrame, cols: Seq[String],
require(cols.length == 2, "Currently cov supports calculating the covariance " + functionName: String): CovarianceCounter = {
require(cols.length == 2, s"Currently $functionName calculation is supported " +
"between two columns.") "between two columns.")
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
require(data.nonEmpty, s"Couldn't find column with name $name") require(data.nonEmpty, s"Couldn't find column with name $name")
require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " +
s"with dataType ${data.get.dataType} not supported.") s"for columns with dataType ${data.get.dataType} not supported.")
} }
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)(
...@@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging { ...@@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging {
* @return the covariance of the two columns. * @return the covariance of the two columns.
*/ */
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols) val counts = collectStatisticalData(df, cols, "covariance")
counts.cov counts.cov
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment