diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3bc48c95c5653e77879f7c9512e473157d31cf65..fd58b9681ea241ac4e4b6991c22bf5eabe745a90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -50,6 +50,7 @@ object DefaultOptimizer extends Optimizer {
       CombineFilters,
       PushPredicateThroughProject,
       PushPredicateThroughJoin,
+      PushPredicateThroughGenerate,
       ColumnPruning) ::
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) :: Nil
@@ -455,6 +456,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Push [[Filter]] operators through [[Generate]] operators. Conjuncts that reference attributes
+ * produced by the [[Generate]] operator stay above it; the remaining conjuncts are pushed beneath.
+ */
+object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case filter @ Filter(condition,
+    generate @ Generate(generator, join, outer, alias, grandChild)) =>
+      // Predicates that reference attributes produced by the `Generate` operator cannot
+      // be pushed below the operator.
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+        conjunct => conjunct.references subsetOf grandChild.outputSet
+      }
+      if (pushDown.nonEmpty) {
+        val pushDownPredicate = pushDown.reduce(And)
+        val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Pushes down [[Filter]] operators where the `condition` can be
  * evaluated using only the attributes of the left or right side of a join.  Other
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ebb123c1f909ec15b18c5478889a782c764aaa6e..1158b5dfc61474608d6aa8256fe68f3891b434ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.expressions.Explode
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
 
 class FilterPushdownSuite extends PlanTest {
 
@@ -34,7 +36,8 @@ class FilterPushdownSuite extends PlanTest {
       Batch("Filter Pushdown", Once,
         CombineFilters,
         PushPredicateThroughProject,
-        PushPredicateThroughJoin) :: Nil
+        PushPredicateThroughJoin,
+        PushPredicateThroughGenerate) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -411,4 +414,62 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
   }
+
+  val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+
+  test("generate: predicate referenced no generated column") {
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .where(('b >= 5) && ('a > 6))
+    }
+    val optimized = Optimize(originalQuery.analyze)
+    val correctAnswer = {
+      testRelationWithArrayType
+        .where(('b >= 5) && ('a > 6))
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+    }
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("generate: part of conjuncts referenced generated column") {
+    val generator = Explode(Seq("c"), 'c_arr)
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(generator, true, false, Some("arr"))
+        .where(('b >= 5) && ('c > 6))
+    }
+    val optimized = Optimize(originalQuery.analyze)
+    val referenceResult = {
+      testRelationWithArrayType
+        .where('b >= 5)
+        .generate(generator, true, false, Some("arr"))
+        .where('c > 6).analyze
+    }
+
+    // Newly generated columns get different expression ids each time the plan is analyzed,
+    // so e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) would fail.
+    // Instead, we check the operators manually here.
+    // Filter("c" > 6)
+    assertResult(classOf[Filter])(optimized.getClass)
+    assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size)
+    assertResult("c"){
+      optimized.asInstanceOf[Filter].condition.references.toSeq(0).name
+    }
+
+    // the remaining (pushed-down) part of the plan
+    comparePlans(optimized.children(0), referenceResult.children(0))
+  }
+
+  test("generate: all conjuncts referenced generated column") {
+    val originalQuery = {
+      testRelationWithArrayType
+        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .where(('c > 6) || ('b > 5)).analyze
+    }
+    val optimized = Optimize(originalQuery)
+
+    comparePlans(optimized, originalQuery)
+  }
 }