diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5c2aa3c06b3e7138438ad273313c8f6e1181369f..d9009e3848e583206aaa7ae72898484c81d791fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,8 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Corr]("corr"),
     expression[Count]("count"),
+    expression[CovPopulation]("covar_pop"),
+    expression[CovSample]("covar_samp"),
     expression[First]("first"),
     expression[First]("first_value"),
     expression[Last]("last"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f53b01be2a0d5c02090b7a4605e497b0db44e299
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * Base class for aggregates that compute the covariance between two expressions.
+ *
+ * The aggregation buffer keeps the running means of both inputs (`xAvg`, `yAvg`), the
+ * co-moment `Ck` (running sum of products of deviations from the means) and the number of
+ * rows aggregated so far (`count`). Rows where either input evaluates to null are skipped.
+ *
+ * Subclasses derive the final value from this buffer in `eval`.
+ */
+abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate
+    with Serializable {
+  override def children: Seq[Expression] = Seq(left, right)
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = DoubleType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+  /** Accepts the inputs only when both child expressions are of double type. */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        "covariance requires that both arguments are double type, " +
+          s"not (${left.dataType}, ${right.dataType}).")
+    }
+  }
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] = {
+    aggBufferAttributes.map(_.newInstance())
+  }
+
+  // Buffer layout: running mean of x, running mean of y, co-moment, row count.
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
+    AttributeReference("xAvg", DoubleType)(),
+    AttributeReference("yAvg", DoubleType)(),
+    AttributeReference("Ck", DoubleType)(),
+    AttributeReference("count", LongType)())
+
+  // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
+  val xAvgOffset = mutableAggBufferOffset
+  val yAvgOffset = mutableAggBufferOffset + 1
+  val CkOffset = mutableAggBufferOffset + 2
+  val countOffset = mutableAggBufferOffset + 3
+
+  // Local cache of inputAggBufferOffset(s) that will be used in update and merge
+  val inputXAvgOffset = inputAggBufferOffset
+  val inputYAvgOffset = inputAggBufferOffset + 1
+  val inputCkOffset = inputAggBufferOffset + 2
+  val inputCountOffset = inputAggBufferOffset + 3
+
+  /** Resets the buffer to the empty state: zero means, zero co-moment, zero rows. */
+  override def initialize(buffer: MutableRow): Unit = {
+    buffer.setDouble(xAvgOffset, 0.0)
+    buffer.setDouble(yAvgOffset, 0.0)
+    buffer.setDouble(CkOffset, 0.0)
+    buffer.setLong(countOffset, 0L)
+  }
+
+  /**
+   * Folds one input row into the buffer using a one-pass (online) update.
+   * Rows where either expression evaluates to null leave the buffer untouched.
+   */
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val leftEval = left.eval(input)
+    val rightEval = right.eval(input)
+
+    if (leftEval != null && rightEval != null) {
+      val x = leftEval.asInstanceOf[Double]
+      val y = rightEval.asInstanceOf[Double]
+
+      var xAvg = buffer.getDouble(xAvgOffset)
+      var yAvg = buffer.getDouble(yAvgOffset)
+      var Ck = buffer.getDouble(CkOffset)
+      var count = buffer.getLong(countOffset)
+
+      val deltaX = x - xAvg
+      val deltaY = y - yAvg
+      count += 1
+      xAvg += deltaX / count
+      yAvg += deltaY / count
+      // deltaX uses the old x mean while (y - yAvg) uses the already-updated y mean.
+      Ck += deltaX * (y - yAvg)
+
+      buffer.setDouble(xAvgOffset, xAvg)
+      buffer.setDouble(yAvgOffset, yAvg)
+      buffer.setDouble(CkOffset, Ck)
+      buffer.setLong(countOffset, count)
+    }
+  }
+
+  // Merge counters from other partitions. Formula can be found at:
+  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    val count2 = buffer2.getLong(inputCountOffset)
+
+    // We only go to merge two buffers if there is at least one record aggregated in buffer2.
+    // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
+    // is more than zero too, then we won't get a divide by zero exception.
+    if (count2 > 0) {
+      var xAvg = buffer1.getDouble(xAvgOffset)
+      var yAvg = buffer1.getDouble(yAvgOffset)
+      var Ck = buffer1.getDouble(CkOffset)
+      var count = buffer1.getLong(countOffset)
+
+      val xAvg2 = buffer2.getDouble(inputXAvgOffset)
+      val yAvg2 = buffer2.getDouble(inputYAvgOffset)
+      val Ck2 = buffer2.getDouble(inputCkOffset)
+
+      val totalCount = count + count2
+      val deltaX = xAvg - xAvg2
+      val deltaY = yAvg - yAvg2
+      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
+      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
+      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
+      count = totalCount
+
+      buffer1.setDouble(xAvgOffset, xAvg)
+      buffer1.setDouble(yAvgOffset, yAvg)
+      buffer1.setDouble(CkOffset, Ck)
+      buffer1.setLong(countOffset, count)
+    }
+  }
+}
+
+/**
+ * Sample covariance: the co-moment divided by (count - 1).
+ * Returns NULL when fewer than two rows were aggregated or when the result is NaN.
+ */
+case class CovSample(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends Covariance(left, right) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def eval(buffer: InternalRow): Any = {
+    val count = buffer.getLong(countOffset)
+    if (count > 1) {
+      val Ck = buffer.getDouble(CkOffset)
+      val cov = Ck / (count - 1)
+      if (cov.isNaN) {
+        null
+      } else {
+        cov
+      }
+    } else {
+      null
+    }
+  }
+}
+
+/**
+ * Population covariance: the co-moment divided by count.
+ * Returns NULL when no rows were aggregated or when the co-moment is NaN.
+ */
+case class CovPopulation(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends Covariance(left, right) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def eval(buffer: InternalRow): Any = {
+    val count = buffer.getLong(countOffset)
+    if (count > 0) {
+      val Ck = buffer.getDouble(CkOffset)
+      if (Ck.isNaN) {
+        null
+      } else {
+        Ck / count
+      }
+    } else {
+      null
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 592d79df3109ae6a9dd78ed1be6e82e02dbe19e5..71fea2716bd9f9ccc2a1f5840cc6f79c473658ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -308,6 +308,46 @@ object functions extends LegacyFunctions {
   def countDistinct(columnName: String, columnNames: String*): Column =
     countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
 
+  /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction {
+    CovPopulation(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(columnName1: String, columnName2: String): Column = {
+    covar_pop(Column(columnName1), Column(columnName2))
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction {
+    CovSample(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(columnName1: String, columnName2: String): Column = {
+    covar_samp(Column(columnName1), Column(columnName2))
+  }
+
   /**
    * Aggregate function: returns the first value in a group.
    *
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 5550198c02fbf89032d4888eb0ce573d76062d2f..76b36aa89182e4daef1bdc7fb894477bc50529f9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -807,6 +807,40 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
   }
 
+  test("covariance: covar_pop and covar_samp") {
+    // non-trivial example. To reproduce in python, use:
+    // >>> import numpy as np
+    // >>> a = np.array(range(20))
+    // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+    // >>> np.cov(a, b, bias = 0)[0][1]
+    // 595.0
+    // >>> np.cov(a, b, bias = 1)[0][1]
+    // 565.25
+    val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
+    val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp - 595.0) < 1e-12)
+
+    val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop - 565.25) < 1e-12)
+
+    // Results near 1e4 have an ulp of about 1.8e-12, so an absolute tolerance of 1e-12 would
+    // demand a bit-exact match; compare with a looser bound instead.
+    val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp2 - 11564.0) < 1e-9)
+
+    val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop2 - 10985.8) < 1e-9)
+
+    // one row test
+    val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
+    assert(cov_samp3 == null)
+
+    val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
+    assert(cov_pop3 == 0.0)
+  }
+
   test("no aggregation function (SPARK-11486)") {
     val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
       .groupBy("s").count()