Commit d19f4fda authored by Reynold Xin's avatar Reynold Xin
Browse files

[SPARK-11505][SQL] Break aggregate functions into multiple files

functions.scala was getting pretty long. I broke it into multiple files.

I also added explicit data types for some public vals, and renamed aggregate function pretty names to lower case, which is more consistent with the rest of the functions.

Author: Reynold Xin <rxin@databricks.com>

Closes #9471 from rxin/SPARK-11505.
parent abf5e428
Showing with 1223 additions and 950 deletions
@@ -157,11 +157,14 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
   */
  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
-    assert(inMemSorter != null);
    if (trigger != this) {
      if (readingIterator != null) {
        return readingIterator.spill();
-      } else {
      }
-      return 0L;
+      return 0L; // this should throw exception
    }
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Average(child: Expression) extends DeclarativeAggregate {
override def prettyName: String = "avg"
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = resultType
// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that expressions like `select avg(null)` resolve.
// We can use the analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions; then we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}
private val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}
private val sum = AttributeReference("sum", sumDataType)()
private val count = AttributeReference("count", LongType)()
override val aggBufferAttributes = sum :: count :: Nil
override val initialValues = Seq(
/* sum = */ Cast(Literal(0), sumDataType),
/* count = */ Literal(0L)
)
override val updateExpressions = Seq(
/* sum = */
Add(
sum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
/* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
// If all inputs are null, count will be 0 and we will get null after the division.
override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
}
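
For readers skimming the new layout, the declarative buffer above reduces to a (sum, count) pair: nulls leave both slots untouched, and the final division yields null when count is zero. A minimal plain-Scala sketch of those semantics (the object name AverageSketch is illustrative, not part of this patch):

object AverageSketch {
  final case class Buffer(sum: Double, count: Long)
  val init = Buffer(0.0, 0L)
  // A null input (None) leaves the buffer unchanged, mirroring the If/Coalesce above.
  def update(b: Buffer, v: Option[Double]): Buffer =
    v.fold(b)(x => Buffer(b.sum + x, b.count + 1))
  def merge(l: Buffer, r: Buffer): Buffer = Buffer(l.sum + r.sum, l.count + r.count)
  // count == 0 means every input was null, so the average itself is null.
  def eval(b: Buffer): Option[Double] =
    if (b.count == 0) None else Some(b.sum / b.count)

  def main(args: Array[String]): Unit = {
    val rows = Seq(Some(1.0), None, Some(3.0))
    println(eval(rows.foldLeft(init)(update))) // Some(2.0); an all-null input would print None
  }
}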
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
* A central moment is the expected value of a specified power of the deviation of a random
 * variable from the mean. Central moments are often used to characterize the shape of
 * a distribution.
*
* This class implements online, one-pass algorithms for computing the central moments of a set of
* points.
*
* Behavior:
* - null values are ignored
* - returns `Double.NaN` when the column contains `Double.NaN` values
*
* References:
* - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments."
* 2015. http://arxiv.org/abs/1510.04923
*
* @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
* Algorithms for calculating variance (Wikipedia)]]
*
 * @param child the expression to compute central moments of.
*/
abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
/**
* The central moment order to be computed.
*/
protected def momentOrder: Int
override def children: Seq[Expression] = Seq(child)
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that expressions like `select avg(null)` resolve.
// We can use the analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions; then we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
/**
* Size of aggregation buffer.
*/
private[this] val bufferSize = 5
override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
AttributeReference(s"M$i", DoubleType)()
}
// Note: although this simply copies aggBufferAttributes, this common code cannot be placed
// in the superclass because that would lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())
// buffer offsets
private[this] val nOffset = mutableAggBufferOffset
private[this] val meanOffset = mutableAggBufferOffset + 1
private[this] val secondMomentOffset = mutableAggBufferOffset + 2
private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
// frequently used values for online updates
private[this] var delta = 0.0
private[this] var deltaN = 0.0
private[this] var delta2 = 0.0
private[this] var deltaN2 = 0.0
private[this] var n = 0.0
private[this] var mean = 0.0
private[this] var m2 = 0.0
private[this] var m3 = 0.0
private[this] var m4 = 0.0
/**
* Initialize all moments to zero.
*/
override def initialize(buffer: MutableRow): Unit = {
for (aggIndex <- 0 until bufferSize) {
buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
}
}
/**
* Update the central moments buffer.
*/
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = Cast(child, DoubleType).eval(input)
if (v != null) {
// The Cast to DoubleType guarantees that any non-null result is a Double,
// so this match is exhaustive in practice.
val updateValue = v match {
case d: Double => d
}
n = buffer.getDouble(nOffset)
mean = buffer.getDouble(meanOffset)
n += 1.0
buffer.setDouble(nOffset, n)
delta = updateValue - mean
deltaN = delta / n
mean += deltaN
buffer.setDouble(meanOffset, mean)
if (momentOrder >= 2) {
m2 = buffer.getDouble(secondMomentOffset)
m2 += delta * (delta - deltaN)
buffer.setDouble(secondMomentOffset, m2)
}
if (momentOrder >= 3) {
delta2 = delta * delta
deltaN2 = deltaN * deltaN
m3 = buffer.getDouble(thirdMomentOffset)
m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
buffer.setDouble(thirdMomentOffset, m3)
}
if (momentOrder >= 4) {
m4 = buffer.getDouble(fourthMomentOffset)
m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
delta * (delta * delta2 - deltaN * deltaN2)
buffer.setDouble(fourthMomentOffset, m4)
}
}
}
/**
* Merge two central moment buffers.
*/
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val n1 = buffer1.getDouble(nOffset)
val n2 = buffer2.getDouble(inputAggBufferOffset)
val mean1 = buffer1.getDouble(meanOffset)
val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
var secondMoment1 = 0.0
var secondMoment2 = 0.0
var thirdMoment1 = 0.0
var thirdMoment2 = 0.0
var fourthMoment1 = 0.0
var fourthMoment2 = 0.0
n = n1 + n2
buffer1.setDouble(nOffset, n)
delta = mean2 - mean1
deltaN = if (n == 0.0) 0.0 else delta / n
mean = mean1 + deltaN * n2
buffer1.setDouble(meanOffset, mean)
// higher order moments computed according to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
if (momentOrder >= 2) {
secondMoment1 = buffer1.getDouble(secondMomentOffset)
secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
buffer1.setDouble(secondMomentOffset, m2)
}
if (momentOrder >= 3) {
thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
(n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
buffer1.setDouble(thirdMomentOffset, m3)
}
if (momentOrder >= 4) {
fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
(n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
buffer1.setDouble(fourthMomentOffset, m4)
}
}
/**
* Compute aggregate statistic from sufficient moments.
* @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
* needed to compute the aggregate stat.
*/
def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double
override final def eval(buffer: InternalRow): Any = {
val n = buffer.getDouble(nOffset)
val mean = buffer.getDouble(meanOffset)
val moments = Array.ofDim[Double](momentOrder + 1)
moments(0) = 1.0
moments(1) = 0.0
if (momentOrder >= 2) {
moments(2) = buffer.getDouble(secondMomentOffset)
}
if (momentOrder >= 3) {
moments(3) = buffer.getDouble(thirdMomentOffset)
}
if (momentOrder >= 4) {
moments(4) = buffer.getDouble(fourthMomentOffset)
}
getStatistic(n, mean, moments)
}
}
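
The update() above follows the one-pass recurrence delta = x - mean; mean += delta / n; m2 += delta * (delta - deltaN) from the references. A standalone sketch, assuming nothing from Spark, that checks the online second moment against a two-pass computation (the names OnlineMoments, m2Online, m2TwoPass are illustrative):

object OnlineMoments {
  // One-pass update, same algebraic form as CentralMomentAgg.update for momentOrder >= 2.
  def m2Online(xs: Seq[Double]): Double = {
    var n = 0.0; var mean = 0.0; var m2 = 0.0
    for (x <- xs) {
      n += 1.0
      val delta = x - mean
      val deltaN = delta / n
      mean += deltaN
      m2 += delta * (delta - deltaN)
    }
    m2
  }
  // Reference computation: mean first, then sum of squared deviations.
  def m2TwoPass(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.length
    xs.map(x => (x - mean) * (x - mean)).sum
  }
  def main(args: Array[String]): Unit = {
    val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    println(m2Online(xs))  // 32.0
    println(m2TwoPass(xs)) // 32.0
  }
}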
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
* Compute Pearson correlation between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
case class Corr(
left: Expression,
right: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate {
override def children: Seq[Expression] = Seq(left, right)
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override def inputAggBufferAttributes: Seq[AttributeReference] = {
aggBufferAttributes.map(_.newInstance())
}
override val aggBufferAttributes: Seq[AttributeReference] = Seq(
AttributeReference("xAvg", DoubleType)(),
AttributeReference("yAvg", DoubleType)(),
AttributeReference("Ck", DoubleType)(),
AttributeReference("MkX", DoubleType)(),
AttributeReference("MkY", DoubleType)(),
AttributeReference("count", LongType)())
// Local cache of mutableAggBufferOffset(s) that will be used in update and merge
private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5
// Local cache of inputAggBufferOffset(s) that will be used in update and merge
private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def initialize(buffer: MutableRow): Unit = {
buffer.setDouble(mutableAggBufferOffset, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
}
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val leftEval = left.eval(input)
val rightEval = right.eval(input)
if (leftEval != null && rightEval != null) {
val x = leftEval.asInstanceOf[Double]
val y = rightEval.asInstanceOf[Double]
var xAvg = buffer.getDouble(mutableAggBufferOffset)
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer.getLong(mutableAggBufferOffsetPlus5)
val deltaX = x - xAvg
val deltaY = y - yAvg
count += 1
xAvg += deltaX / count
yAvg += deltaY / count
Ck += deltaX * (y - yAvg)
MkX += deltaX * (x - xAvg)
MkY += deltaY * (y - yAvg)
buffer.setDouble(mutableAggBufferOffset, xAvg)
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer.setLong(mutableAggBufferOffsetPlus5, count)
}
}
// Merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
// We only merge the two buffers if there is at least one record aggregated in buffer2.
// We don't need to check the count in buffer1: if count2 is greater than zero, totalCount
// is greater than zero too, so we won't get a divide-by-zero exception.
if (count2 > 0) {
var xAvg = buffer1.getDouble(mutableAggBufferOffset)
var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
val totalCount = count + count2
val deltaX = xAvg - xAvg2
val deltaY = yAvg - yAvg2
Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
xAvg = (xAvg * count + xAvg2 * count2) / totalCount
yAvg = (yAvg * count + yAvg2 * count2) / totalCount
MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
count = totalCount
buffer1.setDouble(mutableAggBufferOffset, xAvg)
buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer1.setLong(mutableAggBufferOffsetPlus5, count)
}
}
override def eval(buffer: InternalRow): Any = {
val count = buffer.getLong(mutableAggBufferOffsetPlus5)
if (count > 0) {
val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
val corr = Ck / math.sqrt(MkX * MkY)
if (corr.isNaN) {
null
} else {
corr
}
} else {
null
}
}
}
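
Corr.update maintains the co-moment Ck = sum((x - xAvg)(y - yAvg)) in a single pass; note that each increment mixes the pre-update deltaX with the post-update yAvg. A self-contained sketch of that recurrence (plain Scala, not the Spark class; CorrSketch is a hypothetical name):

object CorrSketch {
  def corr(xs: Seq[Double], ys: Seq[Double]): Double = {
    var n = 0L; var xAvg = 0.0; var yAvg = 0.0
    var ck = 0.0; var mkX = 0.0; var mkY = 0.0
    for ((x, y) <- xs.zip(ys)) {
      val dx = x - xAvg; val dy = y - yAvg // deltas against the old averages
      n += 1
      xAvg += dx / n; yAvg += dy / n       // averages now updated
      ck += dx * (y - yAvg)                // old deltaX times new y-deviation
      mkX += dx * (x - xAvg)
      mkY += dy * (y - yAvg)
    }
    ck / math.sqrt(mkX * mkY)
  }
  def main(args: Array[String]): Unit =
    // A perfectly linear relation should give correlation 1.0.
    println(corr(Seq(1, 2, 3, 4).map(_.toDouble), Seq(2, 4, 6, 8).map(_.toDouble)))
}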
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Count(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val count = AttributeReference("count", LongType)()
override val aggBufferAttributes = count :: Nil
override val initialValues = Seq(
/* count = */ Literal(0L)
)
override val updateExpressions = Seq(
/* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
override val evaluateExpression = Cast(count, LongType)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
* Returns the first value of `child` for a group of rows. If the first value of `child`
 * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on an already
 * sorted column, if we do partial aggregation and final aggregation (when mergeExpressions
 * are used) its result will not be deterministic (unless the input table is sorted and has
 * a single partition, and we use a single reducer to do the aggregation).
*/
case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
private val ignoreNulls: Boolean = ignoreNullsExpr match {
case Literal(b: Boolean, BooleanType) => b
case _ =>
throw new AnalysisException("The second argument of First should be a boolean literal.")
}
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// First is not a deterministic function.
override def deterministic: Boolean = false
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val first = AttributeReference("first", child.dataType)()
private val valueSet = AttributeReference("valueSet", BooleanType)()
override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil
override val initialValues: Seq[Literal] = Seq(
/* first = */ Literal.create(null, child.dataType),
/* valueSet = */ Literal.create(false, BooleanType)
)
override val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* first = */ If(Or(valueSet, IsNull(child)), first, child),
/* valueSet = */ Or(valueSet, IsNotNull(child))
)
} else {
Seq(
/* first = */ If(valueSet, first, child),
/* valueSet = */ Literal.create(true, BooleanType)
)
}
}
override val mergeExpressions: Seq[Expression] = {
// For first, we can just check if valueSet.left is set to true. If it is set
// to true, we use first.left. If not, we use first.right (even if valueSet.right is
// false, we are safe to do so because first.right will be null in this case).
Seq(
/* first = */ If(valueSet.left, first.left, first.right),
/* valueSet = */ Or(valueSet.left, valueSet.right)
)
}
override val evaluateExpression: AttributeReference = first
override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
}
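
The two update branches above amount to "first non-null value" versus "value of the very first row, null included". A tiny sketch of those semantics on ordinary collections (FirstSketch is a hypothetical helper, not part of the patch):

object FirstSketch {
  def first[A](rows: Seq[Option[A]], ignoreNulls: Boolean): Option[A] =
    if (ignoreNulls) rows.flatten.headOption // skip nulls, take the first real value
    else rows.headOption.flatten             // take row one, even if it holds null

  def main(args: Array[String]): Unit = {
    val rows = Seq(None, Some("a"), Some("b"))
    println(first(rows, ignoreNulls = true))  // Some(a)
    println(first(rows, ignoreNulls = false)) // None
  }
}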
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
case class Kurtosis(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def prettyName: String = "kurtosis"
override protected val momentOrder = 4
// NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)
if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
n * m4 / (m2 * m2) - 3.0
}
}
}
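
A quick numeric check of the excess-kurtosis formula n * m4 / (m2 * m2) - 3, computing the central moments directly rather than online (illustrative code, not from this commit):

object KurtosisCheck {
  def excessKurtosis(xs: Seq[Double]): Double = {
    val n = xs.length.toDouble
    val mean = xs.sum / n
    // Unnormalized central moments, matching getStatistic's inputs.
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum
    val m4 = xs.map(x => math.pow(x - mean, 4)).sum
    n * m4 / (m2 * m2) - 3.0
  }
  def main(args: Array[String]): Unit =
    // mean = 2.5, m2 = 5.0, m4 = 10.25, so 4 * 10.25 / 25 - 3 = -1.36
    println(excessKurtosis(Seq(1.0, 2.0, 3.0, 4.0)))
}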
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
* Returns the last value of `child` for a group of rows. If the last value of `child`
 * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on an already
 * sorted column, if we do partial aggregation and final aggregation (when mergeExpressions
 * are used) its result will not be deterministic (unless the input table is sorted and has
 * a single partition, and we use a single reducer to do the aggregation).
*/
case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
private val ignoreNulls: Boolean = ignoreNullsExpr match {
case Literal(b: Boolean, BooleanType) => b
case _ =>
throw new AnalysisException("The second argument of First should be a boolean literal.")
}
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Last is not a deterministic function.
override def deterministic: Boolean = false
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val last = AttributeReference("last", child.dataType)()
override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
override val initialValues: Seq[Literal] = Seq(
/* last = */ Literal.create(null, child.dataType)
)
override val updateExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(child), last, child)
)
} else {
Seq(
/* last = */ child
)
}
}
override val mergeExpressions: Seq[Expression] = {
if (ignoreNulls) {
Seq(
/* last = */ If(IsNull(last.right), last.left, last.right)
)
} else {
Seq(
/* last = */ last.right
)
}
}
override val evaluateExpression: AttributeReference = last
override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Max(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val max = AttributeReference("max", child.dataType)()
override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
override val initialValues: Seq[Literal] = Seq(
/* max = */ Literal.create(null, child.dataType)
)
override val updateExpressions: Seq[Expression] = Seq(
/* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
)
override val mergeExpressions: Seq[Expression] = {
val greatest = Greatest(Seq(max.left, max.right))
Seq(
/* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
)
}
override val evaluateExpression: AttributeReference = max
}
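
The mergeExpressions above encode a null-safe greatest: a null buffer always loses to a non-null one, and two non-null buffers compare normally. The same logic on Options, as a sketch (MaxSketch is a hypothetical name; Min is the mirror image with math.min):

object MaxSketch {
  def mergeMax(l: Option[Int], r: Option[Int]): Option[Int] = (l, r) match {
    case (None, _) => r                     // IsNull(max.left)  => max.right
    case (_, None) => l                     // IsNull(max.right) => max.left
    case (Some(a), Some(b)) => Some(math.max(a, b)) // Greatest of the two
  }
  def main(args: Array[String]): Unit = {
    println(mergeMax(Some(3), None))    // Some(3)
    println(mergeMax(Some(3), Some(7))) // Some(7)
  }
}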
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Min(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val min = AttributeReference("min", child.dataType)()
override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
override val initialValues: Seq[Expression] = Seq(
/* min = */ Literal.create(null, child.dataType)
)
override val updateExpressions: Seq[Expression] = Seq(
/* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
)
override val mergeExpressions: Seq[Expression] = {
val least = Least(Seq(min.left, min.right))
Seq(
/* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
)
}
override val evaluateExpression: AttributeReference = min
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
case class Skewness(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def prettyName: String = "skewness"
override protected val momentOrder = 3
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)
if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
// Compute the population standard deviation of a column
case class StddevPop(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = false
override def prettyName: String = "stddev_pop"
}
// Compute the sample standard deviation of a column
case class StddevSamp(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = true
override def prettyName: String = "stddev_samp"
}
// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
def isSample: Boolean
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = resultType
// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that expressions like `select stddev(null)` resolve.
// We can use the analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions; then we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = DoubleType
private val count = AttributeReference("count", resultType)()
private val avg = AttributeReference("avg", resultType)()
private val mk = AttributeReference("mk", resultType)()
override val aggBufferAttributes = count :: avg :: mk :: Nil
override val initialValues: Seq[Expression] = Seq(
/* count = */ Cast(Literal(0), resultType),
/* avg = */ Cast(Literal(0), resultType),
/* mk = */ Cast(Literal(0), resultType)
)
override val updateExpressions: Seq[Expression] = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
// update average
// avg = avg + (value - avg)/count
val newAvg = avg + (value - avg) / newCount
// update sum of squared differences from the mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
val newMk = mk + (value - avg) * (value - newAvg)
Seq(
/* count = */ If(IsNull(child), count, newCount),
/* avg = */ If(IsNull(child), avg, newAvg),
/* mk = */ If(IsNull(child), mk, newMk)
)
}
override val mergeExpressions: Seq[Expression] = {
// count merge
val newCount = count.left + count.right
// average merge
val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
// update sum of square differences
val newMk = {
val avgDelta = avg.right - avg.left
val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
mk.left + mk.right + mkDelta
}
Seq(
/* count = */ If(IsNull(count.left), count.right,
If(IsNull(count.right), count.left, newCount)),
/* avg = */ If(IsNull(avg.left), avg.right,
If(IsNull(avg.right), avg.left, newAvg)),
/* mk = */ If(IsNull(mk.left), mk.right,
If(IsNull(mk.right), mk.left, newMk))
)
}
override val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count > 1:
//   stddev_samp = sqrt(mk / (count - 1))
//   stddev_pop  = sqrt(mk / count)
val varCol =
if (isSample) {
mk / Cast(count - Cast(Literal(1), resultType), resultType)
} else {
mk / count
}
If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
}
}
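
The mergeExpressions implement the combine rule mk = mk1 + mk2 + delta^2 * n1 * n2 / n from the Wikipedia reference above. A sketch, assuming nothing beyond the formulas in this file, that checks a split-and-merge run against a single pass (names are illustrative):

object StddevMergeCheck {
  final case class Buf(n: Double, avg: Double, mk: Double)
  // Welford-style single-row update, as in updateExpressions.
  def update(b: Buf, x: Double): Buf = {
    val n = b.n + 1
    val avg = b.avg + (x - b.avg) / n
    Buf(n, avg, b.mk + (x - b.avg) * (x - avg))
  }
  // Partition merge, as in mergeExpressions.
  def merge(l: Buf, r: Buf): Buf = {
    val n = l.n + r.n
    val d = r.avg - l.avg
    Buf(n, (l.avg * l.n + r.avg * r.n) / n, l.mk + r.mk + d * d * l.n * r.n / n)
  }
  def stddevSamp(b: Buf): Double = math.sqrt(b.mk / (b.n - 1))

  def main(args: Array[String]): Unit = {
    val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    val (a, b) = xs.splitAt(3)
    val whole = xs.foldLeft(Buf(0, 0, 0))(update)
    val merged = merge(a.foldLeft(Buf(0, 0, 0))(update), b.foldLeft(Buf(0, 0, 0))(update))
    println(stddevSamp(whole))  // ~2.138
    println(stddevSamp(merged)) // same value
  }
}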
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Sum(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = resultType
// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that expressions like `select sum(null)` resolve.
// We can use the analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions; then we will no longer need NullType here.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
// TODO: Remove this line once we remove the NullType from inputTypes.
case NullType => IntegerType
case _ => child.dataType
}
private val sumDataType = resultType
private val sum = AttributeReference("sum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType)
override val aggBufferAttributes = sum :: Nil
override val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)
override val updateExpressions: Seq[Expression] = Seq(
/* sum = */
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
override val mergeExpressions: Seq[Expression] = {
val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
Coalesce(Seq(add, sum.left))
)
}
override val evaluateExpression: Expression = Cast(sum, resultType)
}
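
The nested Coalesce calls keep the buffer null until the first non-null input arrives, so sum over an all-null column is null rather than 0. The same semantics on Options, as a sketch (SumSketch is an illustrative name):

object SumSketch {
  // Non-null input: add it to the running sum (treating a null buffer as 0).
  // Null input: fall back to the existing buffer, which may itself still be null.
  def update(sum: Option[Long], v: Option[Long]): Option[Long] =
    v.map(_ + sum.getOrElse(0L)).orElse(sum)

  def main(args: Array[String]): Unit = {
    println(Seq(None, Some(2L), Some(3L)).foldLeft(Option.empty[Long])(update)) // Some(5)
    println(Seq[Option[Long]](None, None).foldLeft(Option.empty[Long])(update)) // None
  }
}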
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions._
case class VarianceSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def prettyName: String = "var_samp"
override protected val momentOrder = 2
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
}
}
case class VariancePop(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def prettyName: String = "var_pop"
override protected val momentOrder = 2
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
if (n == 0.0) Double.NaN else moments(2) / n
}
}
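
As a sanity check of the two estimators above: for the sample 1, 2, 3, 4 the unnormalized second moment is m2 = 5.0, giving var_pop = 5.0 / 4 = 1.25 and var_samp = 5.0 / 3 ≈ 1.667. In code (illustrative only):

object VarianceCheck {
  // Unnormalized second central moment, the moments(2) passed to getStatistic.
  def centralM2(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.length
    xs.map(x => (x - mean) * (x - mean)).sum
  }
  def main(args: Array[String]): Unit = {
    val xs = Seq(1.0, 2.0, 3.0, 4.0)
    val m2 = centralM2(xs)
    println(m2 / xs.length)       // 1.25   (var_pop)
    println(m2 / (xs.length - 1)) // ~1.6667 (var_samp)
  }
}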
@@ -549,7 +549,7 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
case _ =>
child.dataType
}
override def toString: String = s"SUM(DISTINCT $child)"
override def toString: String = s"sum(distinct $child)"
override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
override def asPartial: SplitEvaluation = {
@@ -646,7 +646,7 @@ case class First(
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS"})"
override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})"
override def asPartial: SplitEvaluation = {
val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
@@ -707,7 +707,7 @@ case class Last(
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}"
override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
override def asPartial: SplitEvaluation = {
val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
@@ -756,7 +756,7 @@ case class Corr(left: Expression, right: Expression)
extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
override def nullable: Boolean = false
override def dataType: DoubleType.type = DoubleType
override def toString: String = s"CORRELATION($left, $right)"
override def toString: String = s"corr($left, $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
override def newInstance(): AggregateFunction1 = {
throw new UnsupportedOperationException(
@@ -788,14 +788,14 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia
// Compute the population standard deviation of a column
case class StddevPop(child: Expression) extends StddevAgg1(child) {
override def toString: String = s"STDDEV_POP($child)"
override def toString: String = s"stddev_pop($child)"
override def isSample: Boolean = false
}
// Compute the sample standard deviation of a column
case class StddevSamp(child: Expression) extends StddevAgg1(child) {
override def toString: String = s"STDDEV_SAMP($child)"
override def toString: String = s"stddev_samp($child)"
override def isSample: Boolean = true
}
@@ -1019,8 +1019,6 @@ case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExp
override def foldable: Boolean = false
override def prettyName: String = "kurtosis"
override def toString: String = s"KURTOSIS($child)"
}
// placeholder
@@ -1038,8 +1036,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp
override def foldable: Boolean = false
override def prettyName: String = "skewness"
override def toString: String = s"SKEWNESS($child)"
}
// placeholder
@@ -1056,9 +1052,7 @@ case class VariancePop(child: Expression) extends UnaryExpression with Aggregate
override def foldable: Boolean = false
override def prettyName: String = "variance_pop"
override def toString: String = s"VAR_POP($child)"
override def prettyName: String = "var_pop"
}
// placeholder
@@ -1075,7 +1069,5 @@ case class VarianceSamp(child: Expression) extends UnaryExpression with Aggregat
override def foldable: Boolean = false
override def prettyName: String = "variance_samp"
override def toString: String = s"VAR_SAMP($child)"
override def prettyName: String = "var_samp"
}