diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 54df96cd2446aa6328b221ea1415fb526eacdcf7..ec0c8b483a909e77c41e778e897c0be32ba0d406 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   )
 
   override val evaluateExpression = Cast(count, LongType)
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 644c6211d5f311f53320b5160b8c93fb8486434f..39010c3be6d4eec412bfb8146be6dd3341e1d082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
 
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
+      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
@@ -144,7 +145,8 @@ object Utils {
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })
+
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
@@ -156,6 +158,7 @@ object Utils {
       }
 
       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
     case other => None
   }
 }
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
+ * in a separate group. The results are then combined in a second aggregate.
+ *
+ * TODO Expression cannocalization
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
+ *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression2 => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Only continue to rewrite if there is more than one distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction2,
+          id: Literal,
+          attrs: Map[Expression, Expression]): AggregateFunction2 = {
+        af.withNewChildren(af.children.map { case afc =>
+          evalWithinGroup(id, attrs(afc))
+        }).asInstanceOf[AggregateFunction2]
+      }
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(
+          e.aggregateFunction,
+          regularGroupId,
+          regularAggChildAttrMap)
+        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+        // Get the result of the first aggregate in the last aggregate.
+        val b = AggregateExpression2(
+          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
+        val c = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(b, lit))
+          case None => b
+        }
+
+        (e, a, c)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates the all the children of
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children, in this case attribute reuse causes the input of the regular aggregate to bound to
+    // the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyName, e.dataType, true)()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index a2fab258fcac3cf16f4c0ea079be9a59b8469c4a..5c5b3d1ccd3cd0c677cc4d0930352602e15f2b24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
    */
   def supportsPartial: Boolean = true
 
+  /**
+   * Result of the aggregate function when the input is empty. This is currently only used for the
+   * proper rewriting of distinct aggregate functions.
+   */
+  def defaultResult: Option[Literal] = None
+
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 338c5193cb7a2d25ee22736e404efd2dd85ff37a..d222dfa33ad8ab5f2a904ed6d98c7bf8544a5cab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf33ee703da6629345bd9e5bcd9d83a1b..fb963e2f8f7e7f0ab8d43a72ce9125f8ea14af5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,33 +235,17 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }
 
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child       Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
   /**
-   * Extract attribute set according to the grouping id
+   * Extract attribute set according to the grouping id.
+   *
    * @param bitmask bitmask to represent the selected of the attribute sequence
    * @param exprs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-  : OpenHashSet[Expression] = {
+  private def buildNonSelectExprSet(
+      bitmask: Int,
+      exprs: Seq[Expression]): OpenHashSet[Expression] = {
     val set = new OpenHashSet[Expression](2)
 
     var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
   }
 
   /**
-   * Create an array of Projections for the child projection, and replace the projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those expressions
-   * are not set for this grouping set (according to the bit mask).
+   * Apply the all of the GroupExpressions to every input row, hence we will get
+   * multiple output rows for a input row.
+   *
+   * @param bitmasks The bitmask set represents the grouping sets
+   * @param groupByExprs The grouping by expressions
+   * @param gid Attribute of the grouping id
+   * @param child Child operator
    */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
+  def apply(
+    bitmasks: Seq[Int],
+    groupByExprs: Seq[Expression],
+    gid: Attribute,
+    child: LogicalPlan): Expand = {
+    // Create an array of Projections for the child projection, and replace the projections'
+    // expressions which equal GroupBy expressions with Literal(null), if those expressions
+    // are not set for this grouping set (according to the bit mask).
+    val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
 
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+      (child.output :+ gid).map(expr => expr transformDown {
+        // TODO this causes a problem when a column is used both for grouping and aggregation.
         case x: Expression if nonSelectedGroupExprSet.contains(x) =>
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
           // replace the groupingId with concrete value (the bit mask)
           Literal.create(bitmask, IntegerType)
       })
-
-      result += substitution
     }
-
-    result.toSeq
+    Expand(projections, child.output :+ gid, child)
   }
+}
 
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
+/**
+ * Apply a number of projections to every input row, hence we will get multiple output rows for
+ * a input row.
+ *
+ * @param projections to apply
+ * @param output of all projections.
+ * @param child operator.
+ */
+case class Expand(
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def statistics: Statistics = {
+    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
+    //      row?
+    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    Statistics(sizeInBytes = sizeInBytes)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f4464e0b916f8f2d7512780aeb664d6860e8cdfd..dd3bb33c5728704236ef1ef478e07ccb903ec791 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, _, child) =>
+      case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
         val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled