diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4952ba3b2b99d56aad310e4a22d7b1090df62213..9df8ce1fa3b28793298e61f87151dca00c80c516 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.java.function.FilterFunction
@@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
-import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
+import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
  * Combines all adjacent [[Union]] operators into a single [[Union]].
  */
 object CombineUnions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case Unions(children) => Union(children)
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case u: Union => flattenUnion(u, false)
+    case Distinct(u: Union) => Distinct(flattenUnion(u, true))
+  }
+
+  private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
+    val stack = mutable.Stack[LogicalPlan](union)
+    val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case Distinct(Union(children)) if flattenDistinct =>
+          stack.pushAll(children.reverse)
+        case Union(children) =>
+          stack.pushAll(children.reverse)
+        case child =>
+          flattened += child
+      }
+    }
+    Union(flattened)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 41cabb8cb3390113899979dd4286b4894741de35..bdae56881bf460d2ba9533181de004ea6b9124bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
   }
 }
 
-
-/**
- * A pattern that collects all adjacent unions and returns their children as a Seq.
- */
-object Unions {
-  def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
-    case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
-    case _ => None
-  }
-
-  // Doing a depth-first tree traversal to combine all the union children.
-  @tailrec
-  private def collectUnionChildren(
-      plans: mutable.Stack[LogicalPlan],
-      children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
-    if (plans.isEmpty) children
-    else {
-      plans.pop match {
-        case Union(grandchildren) =>
-          grandchildren.reverseMap(plans.push(_))
-          collectUnionChildren(plans, children)
-        case other => collectUnionChildren(plans, children :+ other)
-      }
-    }
-  }
-}
-
 /**
  * An extractor used when planning the physical execution of an aggregation. Compared with a logical
  * aggregation, the following transformations are performed:
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 7227706ab2b36d7e00c6ce535353b330bc64a91e..21b7f49e14bd58ae99093c37f0984b932e42e151 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest {
         testRelation3.select('g) :: Nil).analyze
     comparePlans(unionOptimized, unionCorrectAnswer)
   }
+
+  test("Remove unnecessary distincts in multiple unions") {
+    val query1 = OneRowRelation
+      .select(Literal(1).as('a))
+    val query2 = OneRowRelation
+      .select(Literal(2).as('b))
+    val query3 = OneRowRelation
+      .select(Literal(3).as('c))
+
+    // D - U - D - U - query1
+    //     |       |
+    //     query3  query2
+    val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze
+    val optimized1 = Optimize.execute(unionQuery1)
+    val distinctUnionCorrectAnswer1 =
+      Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze
+    comparePlans(distinctUnionCorrectAnswer1, optimized1)
+
+    //         query1
+    //         |
+    // D - U - U - query2
+    //     |
+    //     D - U - query2
+    //         |
+    //         query3
+    val unionQuery2 = Distinct(Union(Union(query1, query2),
+      Distinct(Union(query2, query3)))).analyze
+    val optimized2 = Optimize.execute(unionQuery2)
+    val distinctUnionCorrectAnswer2 =
+      Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze
+    comparePlans(distinctUnionCorrectAnswer2, optimized2)
+  }
+
+  test("Keep necessary distincts in multiple unions") {
+    val query1 = OneRowRelation
+      .select(Literal(1).as('a))
+    val query2 = OneRowRelation
+      .select(Literal(2).as('b))
+    val query3 = OneRowRelation
+      .select(Literal(3).as('c))
+    val query4 = OneRowRelation
+      .select(Literal(4).as('d))
+
+    // U - D - U - query1
+    // |       |
+    // query3  query2
+    val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze
+    val optimized1 = Optimize.execute(unionQuery1)
+    val distinctUnionCorrectAnswer1 =
+      Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze
+    comparePlans(distinctUnionCorrectAnswer1, optimized1)
+
+    //         query1
+    //         |
+    // U - D - U - query2
+    // |
+    // D - U - query3
+    //     |
+    //     query4
+    val unionQuery2 =
+      Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze
+    val optimized2 = Optimize.execute(unionQuery2)
+    val distinctUnionCorrectAnswer2 =
+      Union(Distinct(Union(query1 :: query2 :: Nil)),
+            Distinct(Union(query3 :: query4 :: Nil))).analyze
+    comparePlans(distinctUnionCorrectAnswer2, optimized2)
+  }
 }