From 5c6ec94da1bacd8e65a43acb92b6721493484e7b Mon Sep 17 00:00:00 2001 From: Ximo Guanter Gonzalbez <ximo@tid.es> Date: Wed, 2 Jul 2014 10:03:44 -0700 Subject: [PATCH] SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG **Description** This patch enables using the `.select()` function in SchemaRDD with functions such as `Sum`, `Count` and others. **Testing** Unit tests added. Author: Ximo Guanter Gonzalbez <ximo@tid.es> Closes #1211 from edrevo/add-expression-support-in-select and squashes the following commits: fe4a1e1 [Ximo Guanter Gonzalbez] Extend SQL DSL to functions e1d344a [Ximo Guanter Gonzalbez] SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG --- .../spark/sql/catalyst/dsl/package.scala | 11 +++++++ .../org/apache/spark/sql/SchemaRDD.scala | 9 ++++-- .../org/apache/spark/sql/DslQuerySuite.scala | 32 +++++++++++++++---- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 26ad4837b0..1b503b957d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -108,6 +108,17 @@ package object dsl { implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) + def sum(e: Expression) = Sum(e) + def sumDistinct(e: Expression) = SumDistinct(e) + def count(e: Expression) = Count(e) + def countDistinct(e: Expression*) = CountDistinct(e) + def avg(e: Expression) = Average(e) + def first(e: Expression) = First(e) + def min(e: Expression) = Min(e) + def max(e: Expression) = Max(e) + def upper(e: Expression) = Upper(e) + def lower(e: Expression) = Lower(e) + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } // TODO more implicit class for literal? 
implicit class DslString(val s: String) extends ImplicitOperators { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 7c0efb4566..8f9f54f610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -133,8 +133,13 @@ class SchemaRDD( * * @group Query */ - def select(exprs: NamedExpression*): SchemaRDD = - new SchemaRDD(sqlContext, Project(exprs, logicalPlan)) + def select(exprs: Expression*): SchemaRDD = { + val aliases = exprs.zipWithIndex.map { + case (ne: NamedExpression, _) => ne + case (e, i) => Alias(e, s"c$i")() + } + new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) + } /** * Filters the output, only returning those rows where `condition` evaluates to true. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index e4a64a7a48..04ac008682 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest { Seq(Seq("1"))) } + test("select with functions") { + checkAnswer( + testData.select(sum('value), avg('value), count(1)), + Seq(Seq(5050.0, 50.5, 100))) + + checkAnswer( + testData2.select('a + 'b, 'a < 'b), + Seq( + Seq(2, false), + Seq(3, true), + Seq(3, false), + Seq(4, false), + Seq(4, false), + Seq(5, false))) + + checkAnswer( + testData2.select(sumDistinct('a)), + Seq(Seq(6))) + } + test("sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), @@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest { test("average") { checkAnswer( - testData2.groupBy()(Average('a)), + testData2.groupBy()(avg('a)), 2.0) } test("null average") { checkAnswer( - testData3.groupBy()(Average('b)), + testData3.groupBy()(avg('b)), 2.0) checkAnswer( - 
testData3.groupBy()(Average('b), CountDistinct('b :: Nil)), + testData3.groupBy()(avg('b), countDistinct('b)), (2.0, 1) :: Nil) } @@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest { test("null count") { checkAnswer( - testData3.groupBy('a)('a, Count('b)), + testData3.groupBy('a)('a, count('b)), Seq((1,0), (2, 1)) ) checkAnswer( - testData3.groupBy('a)('a, Count('a + 'b)), + testData3.groupBy('a)('a, count('a + 'b)), Seq((1,0), (2, 1)) ) checkAnswer( - testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), (2, 1, 2, 2, 1) :: Nil ) } -- GitLab