Skip to content
Snippets Groups Projects
Commit 19c8fb02 authored by Michael Armbrust, committed by Reynold Xin
Browse files

[SQL] Improve SparkSQL Aggregates

* Add native MIN/MAX aggregates (these were previously delegated to Hive).
* Handle nulls correctly in Avg and Sum.

Author: Michael Armbrust <michael@databricks.com>

Closes #683 from marmbrus/aggFixes and squashes the following commits:

64fe30b [Michael Armbrust] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum.
parent 6ed7e2cd
No related branches found
No related tags found
No related merge requests found
......@@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val JOIN = Keyword("JOIN")
protected val LEFT = Keyword("LEFT")
protected val LIMIT = Keyword("LIMIT")
protected val MAX = Keyword("MAX")
protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
......@@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
......
......@@ -86,6 +86,67 @@ abstract class AggregateFunction
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
/**
 * Aggregate expression producing the smallest non-null value of `child`.
 *
 * The result has the same data type as `child`. Partial aggregation is expressed as a
 * `Min` over the input followed by a `Min` over the partial results.
 *
 * @param child the expression whose minimum is computed
 */
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def references = child.references

  // The aggregation buffer starts out as null and is returned unchanged when no row
  // contributes a value (for example, an empty input), so the result can be null even
  // when the child itself is declared non-nullable.
  override def nullable = true

  override def dataType = child.dataType
  override def toString = s"MIN($child)"

  override def asPartial: SplitEvaluation = {
    // The minimum of the per-partition minimums is the global minimum.
    val partialMin = Alias(Min(child), "PartialMin")()
    SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
  }

  override def newInstance() = new MinFunction(child, this)
}
/**
 * Mutable aggregation buffer that tracks the smallest non-null value of `expr` seen so far.
 *
 * @param expr the expression evaluated against every input row
 * @param base the aggregate expression this buffer was instantiated for
 */
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  // Remains null until a row produces a non-null value.
  var currentMin: Any = _

  override def update(input: Row): Unit = {
    // Evaluate the child exactly once per row so that the value that is compared is
    // also the value that gets stored.
    val candidate: Any = expr.eval(input)
    if (candidate != null) {
      // Null inputs are skipped above and never replace the running minimum.
      val replaces =
        currentMin == null ||
          GreaterThan(
            Literal(currentMin, expr.dataType),
            Literal(candidate, expr.dataType)).eval(input) == true
      if (replaces) {
        currentMin = candidate
      }
    }
  }

  /** Returns the running minimum, or null when no non-null value has been seen. */
  override def eval(input: Row): Any = currentMin
}
/**
 * Aggregate expression producing the largest non-null value of `child`.
 *
 * The result has the same data type as `child`. Partial aggregation is expressed as a
 * `Max` over the input followed by a `Max` over the partial results.
 *
 * @param child the expression whose maximum is computed
 */
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def references = child.references

  // The aggregation buffer starts out as null and is returned unchanged when no row
  // contributes a value (for example, an empty input), so the result can be null even
  // when the child itself is declared non-nullable.
  override def nullable = true

  override def dataType = child.dataType
  override def toString = s"MAX($child)"

  override def asPartial: SplitEvaluation = {
    // The maximum of the per-partition maximums is the global maximum.
    val partialMax = Alias(Max(child), "PartialMax")()
    SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
  }

  override def newInstance() = new MaxFunction(child, this)
}
/**
 * Mutable aggregation buffer that tracks the largest non-null value of `expr` seen so far.
 *
 * @param expr the expression evaluated against every input row
 * @param base the aggregate expression this buffer was instantiated for
 */
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  // Remains null until a row produces a non-null value.
  var currentMax: Any = _

  override def update(input: Row): Unit = {
    // Evaluate the child exactly once per row so that the value that is compared is
    // also the value that gets stored.
    val candidate: Any = expr.eval(input)
    if (candidate != null) {
      // Null inputs are skipped above and never replace the running maximum.
      val replaces =
        currentMax == null ||
          LessThan(
            Literal(currentMax, expr.dataType),
            Literal(candidate, expr.dataType)).eval(input) == true
      if (replaces) {
        currentMax = candidate
      }
    }
  }

  /** Returns the running maximum, or null when no non-null value has been seen. */
  override def eval(input: Row): Any = currentMax
}
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
......@@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}
override def newInstance()= new CountFunction(child, this)
override def newInstance() = new CountFunction(child, this)
}
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
......@@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
override def nullable = false
override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
override def newInstance()= new CountDistinctFunction(expressions, this)
override def newInstance() = new CountDistinctFunction(expressions, this)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
......@@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
}
override def newInstance()= new AverageFunction(child, this)
override def newInstance() = new AverageFunction(child, this)
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
......@@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
partialSum :: Nil)
}
override def newInstance()= new SumFunction(child, this)
override def newInstance() = new SumFunction(child, this)
}
case class SumDistinct(child: Expression)
......@@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"
override def newInstance()= new SumDistinctFunction(child, this)
override def newInstance() = new SumDistinctFunction(child, this)
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
......@@ -168,7 +229,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
override def newInstance()= new FirstFunction(child, this)
override def newInstance() = new FirstFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
......@@ -176,11 +237,13 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
private val zero = Cast(Literal(0), expr.dataType)
private var count: Long = _
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)
private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
......@@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
private val zero = Cast(Literal(0), expr.dataType)
private val sum = MutableLiteral(zero.eval(null))
private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
......
......@@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest {
Seq((1,3),(2,3),(3,3)))
}
// Runs every aggregate over the `nullInts` table and checks the single expected result row.
test("aggregates with nulls") {
  val query = "SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"
  val expectedRows = (1, 3, 2, 6, 3) :: Nil
  checkAnswer(sql(query), expectedRows)
}
test("select *") {
checkAnswer(
sql("SELECT * FROM testData"),
......
......@@ -84,4 +84,14 @@ object TestData {
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
nullableRepeatedData.registerAsTable("nullableRepeatedData")
// Single boxed-integer column; `Integer` (rather than `Int`) lets a row hold null.
case class NullInts(a: Integer)

// Three non-null values followed by one null row, registered as the "nullInts" table.
val nullInts =
  TestSQLContext.sparkContext.parallelize(
    Seq(NullInts(1), NullInts(2), NullInts(3), NullInts(null)))
nullInts.registerAsTable("nullInts")
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment