Skip to content
Snippets Groups Projects
Commit 5c6ec94d authored by Ximo Guanter Gonzalbez's avatar Ximo Guanter Gonzalbez Committed by Michael Armbrust
Browse files

SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG

**Description** This patch enables using the `.select()` function in SchemaRDD with functions such as `Sum`, `Count` and other.
**Testing** Unit tests added.

Author: Ximo Guanter Gonzalbez <ximo@tid.es>

Closes #1211 from edrevo/add-expression-support-in-select and squashes the following commits:

fe4a1e1 [Ximo Guanter Gonzalbez] Extend SQL DSL to functions
e1d344a [Ximo Guanter Gonzalbez] SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG
parent 6596392d
No related branches found
No related tags found
No related merge requests found
...@@ -108,6 +108,17 @@ package object dsl { ...@@ -108,6 +108,17 @@ package object dsl {
implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
def sum(e: Expression) = Sum(e)
def sumDistinct(e: Expression) = SumDistinct(e)
def count(e: Expression) = Count(e)
def countDistinct(e: Expression*) = CountDistinct(e)
def avg(e: Expression) = Average(e)
def first(e: Expression) = First(e)
def min(e: Expression) = Min(e)
def max(e: Expression) = Max(e)
def upper(e: Expression) = Upper(e)
def lower(e: Expression) = Lower(e)
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
// TODO more implicit class for literal? // TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators { implicit class DslString(val s: String) extends ImplicitOperators {
......
...@@ -133,8 +133,13 @@ class SchemaRDD( ...@@ -133,8 +133,13 @@ class SchemaRDD(
* *
* @group Query * @group Query
*/ */
def select(exprs: NamedExpression*): SchemaRDD = def select(exprs: Expression*): SchemaRDD = {
new SchemaRDD(sqlContext, Project(exprs, logicalPlan)) val aliases = exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
case (e, i) => Alias(e, s"c$i")()
}
new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
}
/** /**
* Filters the output, only returning those rows where `condition` evaluates to true. * Filters the output, only returning those rows where `condition` evaluates to true.
......
...@@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest { ...@@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest {
Seq(Seq("1"))) Seq(Seq("1")))
} }
test("select with functions") {
checkAnswer(
testData.select(sum('value), avg('value), count(1)),
Seq(Seq(5050.0, 50.5, 100)))
checkAnswer(
testData2.select('a + 'b, 'a < 'b),
Seq(
Seq(2, false),
Seq(3, true),
Seq(3, false),
Seq(4, false),
Seq(4, false),
Seq(5, false)))
checkAnswer(
testData2.select(sumDistinct('a)),
Seq(Seq(6)))
}
test("sorting") { test("sorting") {
checkAnswer( checkAnswer(
testData2.orderBy('a.asc, 'b.asc), testData2.orderBy('a.asc, 'b.asc),
...@@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest { ...@@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest {
test("average") { test("average") {
checkAnswer( checkAnswer(
testData2.groupBy()(Average('a)), testData2.groupBy()(avg('a)),
2.0) 2.0)
} }
test("null average") { test("null average") {
checkAnswer( checkAnswer(
testData3.groupBy()(Average('b)), testData3.groupBy()(avg('b)),
2.0) 2.0)
checkAnswer( checkAnswer(
testData3.groupBy()(Average('b), CountDistinct('b :: Nil)), testData3.groupBy()(avg('b), countDistinct('b)),
(2.0, 1) :: Nil) (2.0, 1) :: Nil)
} }
...@@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest { ...@@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest {
test("null count") { test("null count") {
checkAnswer( checkAnswer(
testData3.groupBy('a)('a, Count('b)), testData3.groupBy('a)('a, count('b)),
Seq((1,0), (2, 1)) Seq((1,0), (2, 1))
) )
checkAnswer( checkAnswer(
testData3.groupBy('a)('a, Count('a + 'b)), testData3.groupBy('a)('a, count('a + 'b)),
Seq((1,0), (2, 1)) Seq((1,0), (2, 1))
) )
checkAnswer( checkAnswer(
testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
(2, 1, 2, 2, 1) :: Nil (2, 1, 2, 2, 1) :: Nil
) )
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment