Commit 1748f824 authored by Reynold Xin, committed by Wenchen Fan

[SPARK-16391][SQL] Support partial aggregation for reduceGroups

## What changes were proposed in this pull request?
This patch introduces a new private ReduceAggregator class, a subclass of Aggregator that only requires a single associative and commutative reduce function. ReduceAggregator is then used to implement KeyValueGroupedDataset.reduceGroups so that it supports partial aggregation.
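For reference, a minimal sketch of the user-facing call path this change affects; the dataset contents, object name, and local-mode session below are illustrative, not taken from the patch:

```scala
import org.apache.spark.sql.SparkSession

object ReduceGroupsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("reduceGroupsExample").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 1), ("a", 2), ("b", 5)).toDS()

    // Group by the first field and reduce each group with an associative, commutative
    // function. With this patch the reduce is expressed as an Aggregator, so Spark can
    // apply it map-side before shuffling the groups.
    val reduced = ds.groupByKey(_._1).reduceGroups((x, y) => (x._1, x._2 + y._2))

    reduced.show()
    // The physical plan is expected to include partial-aggregation stages;
    // inspect it with reduced.explain().

    spark.stop()
  }
}
```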

Note that this pull request was initially authored by viirya.

## How was this patch tested?
Covered by original tests for reduceGroups, as well as a new test suite for ReduceAggregator.

Author: Reynold Xin <rxin@databricks.com>
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #14576 from rxin/reduceAggregator.
parent 3e6ef2e8
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.ReduceAggregator
/**
* :: Experimental ::
@@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
  def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
-   val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-   implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
-   flatMapGroups(func)
+   val vEncoder = encoderFor[V]
+   val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
+   agg(aggregator)
  }
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
/**
 * An aggregator that uses a single associative and commutative reduce function. This reduce
 * function can be used to go through all input values and reduce them to a single value.
 *
 * This class currently assumes there is at least one input row; if there is no input,
 * `finish` throws an IllegalStateException.
 */
private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
  extends Aggregator[T, (Boolean, T), T] {

  private val encoder = implicitly[Encoder[T]]

  // The aggregation buffer is (seenInput, currentValue): the Boolean records whether the
  // buffer has consumed any input yet, so empty buffers can be skipped in merge and
  // rejected in finish.
  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])

  override def bufferEncoder: Encoder[(Boolean, T)] =
    ExpressionEncoder.tuple(
      ExpressionEncoder[Boolean](),
      encoder.asInstanceOf[ExpressionEncoder[T]])

  override def outputEncoder: Encoder[T] = encoder

  override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
    if (b._1) {
      (true, func(b._2, a))
    } else {
      (true, a)
    }
  }

  override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
    if (!b1._1) {
      b2
    } else if (!b2._1) {
      b1
    } else {
      (true, func(b1._2, b2._2))
    }
  }

  override def finish(reduction: (Boolean, T)): T = {
    if (!reduction._1) {
      throw new IllegalStateException("ReduceAggregator requires at least one input row")
    }
    reduction._2
  }
}
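To see why the (Boolean, T) buffer and the merge function enable partial aggregation, here is a small pure-Scala sketch that drives zero, reduce, merge and finish by hand, outside Spark's execution engine. It is illustrative only: the object name and partition layout are made up, and because ReduceAggregator is private[sql] it assumes the code sits somewhere with access to that package (for example next to the test suite below).

```scala
package org.apache.spark.sql.expressions

import org.apache.spark.sql.Encoders

object PartialAggregationSketch {
  def main(args: Array[String]): Unit = {
    val agg = new ReduceAggregator[Int]((x, y) => x + y)(Encoders.scalaInt)

    // "Partitions" stand in for map-side tasks; one of them is empty.
    val partitions = Seq(Seq(1, 2), Seq(3), Seq.empty[Int])

    // Map side: fold each partition into a local (seenInput, value) buffer.
    val partials = partitions.map(_.foldLeft(agg.zero)((buf, v) => agg.reduce(buf, v)))

    // Reduce side: merge the partial buffers across partitions, then finish.
    // The empty partition contributes the (false, _) zero buffer, which merge skips.
    val result = agg.finish(partials.reduce((b1, b2) => agg.merge(b1, b2)))
    assert(result == 6)
  }
}
```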
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
class ReduceAggregatorSuite extends SparkFunSuite {

  test("zero value") {
    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
    val func = (v1: Int, v2: Int) => v1 + v2
    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
    assert(aggregator.zero == (false, null))
  }

  test("reduce, merge and finish") {
    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
    val func = (v1: Int, v2: Int) => v1 + v2
    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)

    val firstReduce = aggregator.reduce(aggregator.zero, 1)
    assert(firstReduce == (true, 1))

    val secondReduce = aggregator.reduce(firstReduce, 2)
    assert(secondReduce == (true, 3))

    val thirdReduce = aggregator.reduce(secondReduce, 3)
    assert(thirdReduce == (true, 6))

    val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
    assert(mergeWithZero1 == (true, 1))

    val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
    assert(mergeWithZero2 == (true, 3))

    val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
    assert(mergeTwoReduced == (true, 4))

    assert(aggregator.finish(firstReduce) == 1)
    assert(aggregator.finish(secondReduce) == 3)
    assert(aggregator.finish(thirdReduce) == 6)
    assert(aggregator.finish(mergeWithZero1) == 1)
    assert(aggregator.finish(mergeWithZero2) == 3)
    assert(aggregator.finish(mergeTwoReduced) == 4)
  }

  test("requires at least one input row") {
    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
    val func = (v1: Int, v2: Int) => v1 + v2
    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)

    intercept[IllegalStateException] {
      aggregator.finish(aggregator.zero)
    }
  }
}
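One more case, not part of this commit but derivable from merge and finish above: merging two zero buffers must stay in the "no input" state, so finish on the result still fails. A hypothetical extra test in the same style as the suite:

```scala
  test("merging two zero buffers is still zero") {
    val func = (v1: Int, v2: Int) => v1 + v2
    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)

    // merge returns the other buffer when one side has seen no input,
    // so zero merged with zero is still the zero buffer.
    val merged = aggregator.merge(aggregator.zero, aggregator.zero)
    assert(merged == aggregator.zero)

    intercept[IllegalStateException] {
      aggregator.finish(merged)
    }
  }
```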