diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1f6526ef66c56a40021fbce160e4227960199465..566b34f7c3a6a92e8ecb2e3f6fa18099cc6dfefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } +case class MinOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable + + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 + } else { + if (ordering.compare(evalE1, evalE2) < 0) { + evalE1 + } else { + evalE2 + } + } + } + + override def toString: String = s"MinOf($left, $right)" +} + /** * A function that get the absolute value of the numeric value. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index aac56e15683324bc28c4c946c58e47abcd20156e..d141354a0f42793e233d1ec140afc4f57e49889a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -524,6 +524,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + case UnscaledValue(child) => val childEval = expressionEvaluator(child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index d2b1090a0cdd50e51b07b3097204b0952b00acfd..d4362a91d992ca47c10fab23f1c08ed24115b27d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -233,6 +233,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b510cf033c4a4818809031a0a1aac5790428a492..b1ef6556de1e9e3c38a5dd2cc8aae53333ac7f78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -164,6 +164,17 @@ case class GeneratedAggregate( updateMax :: Nil, currentMax) + case m @ Min(expr) => + val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() + val initialValue = Literal.create(null, expr.dataType) + val updateMin = MinOf(currentMin, expr) + + AggregateEvaluation( + currentMin :: Nil, + initialValue :: Nil, + updateMin :: Nil, + currentMin) + case CollectHashSet(Seq(expr)) => val set = AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() @@ -188,6 +199,8 @@ case class GeneratedAggregate( initialValue :: Nil, collectSets :: Nil, CountSet(set)) + + case o => sys.error(s"$o can't be codegened.") } val computationSchema = computeFunctions.flatMap(_.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f0d92ffffcda3830acdf0b1b3c52897e84776b5e..5b99e40c2f4918dbce9c0d26b1c7f9adf5560252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e453e05e2ac75090c98a0ac15c24a83be468027..73fb791c3ead795a76a63497a2981a749975b671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -172,6 +172,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { testCodeGen( "SELECT max(key) FROM testData3x", Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) // Some combinations. testCodeGen( """ @@ -179,16 +186,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | value, | sum(key), | max(key), + | min(key), | avg(key), | count(key), | count(distinct key) |FROM testData3x |GROUP BY value """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1))) + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) testCodeGen( - "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 50.5, 300, 100) :: Nil) + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( "SELECT sum('a'), avg('a'), count(null) FROM testData",