Skip to content
Snippets Groups Projects
Commit 9c57bc0e authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-11656][SQL] support typed aggregate in project list

insert `aEncoder` like we do in `agg`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9630 from cloud-fan/select.
parent c964fc10
No related branches found
No related tags found
No related merge requests found
...@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._ ...@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.api.java.function._ import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.execution.{Queryable, QueryExecution}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructType
/** /**
...@@ -359,7 +360,7 @@ class Dataset[T] private[sql]( ...@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
* @since 1.6.0 * @since 1.6.0
*/ */
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
} }
/** /**
...@@ -368,11 +369,12 @@ class Dataset[T] private[sql]( ...@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
* that cast appropriately for the user facing interface. * that cast appropriately for the user facing interface.
*/ */
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val withEncoders = columns.map(withEncoder)
val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
val unresolvedPlan = Project(aliases, logicalPlan) val unresolvedPlan = Project(aliases, logicalPlan)
val execution = new QueryExecution(sqlContext, unresolvedPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan)
// Rebind the encoders to the nested schema that will be produced by the select. // Rebind the encoders to the nested schema that will be produced by the select.
val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
case (e: ExpressionEncoder[_], a) if !e.flat => case (e: ExpressionEncoder[_], a) if !e.flat =>
e.nested(a.toAttribute).resolve(execution.analyzed.output) e.nested(a.toAttribute).resolve(execution.analyzed.output)
case (e, a) => case (e, a) =>
...@@ -381,6 +383,16 @@ class Dataset[T] private[sql]( ...@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
} }
private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
val e = c.expr transform {
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
ta.copy(
aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
children = queryExecution.analyzed.output)
}
new TypedColumn(e, c.encoder)
}
/** /**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0 * @since 1.6.0
......
...@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ...@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ComplexResultAgg.toColumn), ComplexResultAgg.toColumn),
("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
} }
test("typed aggregation: in project list") {
val ds = Seq(1, 3, 2, 5).toDS()
checkAnswer(
ds.select(sum((i: Int) => i)),
11)
checkAnswer(
ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
11 -> 22)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment