diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 2ccc27324ac8c64994b0b41c86a3a5da5d655755..6fcf9e31543ed4ff156d5dc6192319503b8532f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -241,9 +241,10 @@ class DAGScheduler(
       callSite: CallSite)
     : Stage =
   {
+    // Create the parent stages first so they get lower stage ids than this stage.
+    val parentStages = getParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage =
-      new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+    val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite)
     stageIdToStage(id) = stage
     updateJobIdStageIdMaps(jobId, stage)
     stage
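The hunk above moves the `getParentStages` call ahead of `nextStageId.getAndIncrement()`. Resolving a stage's parents can itself create new stages, each of which draws an id from the same counter, so the old ordering handed the child stage a lower id than the parents created on its behalf. Below is a minimal sketch of the two orderings; the names (`newStageOld`, `newStageNew`, `makeParents`) are hypothetical and only model the counter behaviour, not the real scheduler code:

```scala
import java.util.concurrent.atomic.AtomicInteger

// Simplified model of stage id assignment (hypothetical names, not DAGScheduler code).
object StageIdOrdering {
  private val nextStageId = new AtomicInteger(0)

  case class Stage(id: Int, parents: Seq[Stage])

  // Old ordering: the child takes its id before its parents are created, so any
  // parent stages created recursively end up with higher ids than the child.
  def newStageOld(makeParents: () => Seq[Stage]): Stage = {
    val id = nextStageId.getAndIncrement()
    Stage(id, makeParents())
  }

  // New ordering (the fix): parents are created first and take the lower ids.
  def newStageNew(makeParents: () => Seq[Stage]): Stage = {
    val parents = makeParents()
    Stage(nextStageId.getAndIncrement(), parents)
  }

  def main(args: Array[String]): Unit = {
    val before = newStageOld(() => Seq(newStageOld(() => Nil)))
    println(s"old: child=${before.id}, parent=${before.parents.head.id}") // child=0, parent=1

    val after = newStageNew(() => Seq(newStageNew(() => Nil)))
    println(s"new: child=${after.id}, parent=${after.parents.head.id}")   // parent=2, child=3
  }
}
```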
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 0bb91febde9d722d1e533c6711e92d647864f71d..aa73469b6acd8b96ef0dd1255db0043b31ae2dc1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -97,10 +98,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
-    val successfulStages = new HashSet[Int]()
-    val failedStages = new ArrayBuffer[Int]()
+    val successfulStages = new HashSet[Int]
+    val failedStages = new ArrayBuffer[Int]
+    val stageByOrderOfExecution = new ArrayBuffer[Int]
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
+      stageByOrderOfExecution += stageInfo.stageId
       if (stageInfo.failureReason.isEmpty) {
         successfulStages += stageInfo.stageId
       } else {
@@ -231,6 +234,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     runEvent(JobCancelled(jobId))
   }
 
+  test("[SPARK-3353] parent stage should have lower stage id") {
+    sparkListener.stageByOrderOfExecution.clear()
+    sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
+    assert(sparkListener.stageByOrderOfExecution.length === 2)
+    assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1))
+  }
+
   test("zero split job") {
     var numResults = 0
     val fakeListener = new JobListener() {
@@ -457,7 +467,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
       null,
       null))
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
-    assert(sparkListener.failedStages.contains(0))
+    assert(sparkListener.failedStages.contains(1))
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
     runEvent(CompletionEvent(
@@ -515,8 +525,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     // Listener bus should get told about the map stage failing, but not the reduce stage
     // (since the reduce stage hasn't been started yet).
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
-    assert(sparkListener.failedStages.contains(1))
-    assert(sparkListener.failedStages.size === 1)
+    assert(sparkListener.failedStages.toSet === Set(0))
 
     assertDataStructuresEmpty
   }
@@ -563,14 +572,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     val stageFailureMessage = "Exception failure in map stage"
     failed(taskSets(0), stageFailureMessage)
 
-    assert(cancelledStages.contains(1))
+    assert(cancelledStages.toSet === Set(0, 2))
 
     // Make sure the listeners got told about both failed stages.
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     assert(sparkListener.successfulStages.isEmpty)
-    assert(sparkListener.failedStages.contains(1))
-    assert(sparkListener.failedStages.contains(3))
-    assert(sparkListener.failedStages.size === 2)
+    assert(sparkListener.failedStages.toSet === Set(0, 2))
 
     assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
     assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
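The updated expectations above follow from the renumbering: with parents assigned ids first, the shuffle map stage of a two-stage job is now stage 0 and the result stage is stage 1, and in the shared-stage failure test the failed stage ids shift from {1, 3} to {0, 2}. A self-contained way to observe the new ordering outside the suite might look like the sketch below; `StageOrderDemo`, the local master URL, and the sleep used in place of draining the asynchronous listener bus are all illustrative assumptions:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

// Sketch: run a two-stage shuffle job and print the stage ids in completion order.
// With the fix, the map stage (lower id) completes before the result stage.
object StageOrderDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("stage-order"))
    val completed = new ArrayBuffer[Int]
    sc.addSparkListener(new SparkListener {
      override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
        completed += stageCompleted.stageInfo.stageId
      }
    })
    sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
    Thread.sleep(1000) // crude stand-in for waiting until the listener bus drains
    println(completed.mkString(", ")) // expected: 0, 1
    sc.stop()
  }
}
```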
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 3b0b8e2f68c97fb224bc5553b699b30fa1a8329f..ab35e8edc4ebf3965587377b0ec5a234fd73ef48 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -180,7 +180,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
     rdd3.count()
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     listener.stageInfos.size should be {2} // Shuffle map stage + result stage
-    val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get
+    val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get
     stageInfo3.rddInfos.size should be {1} // ShuffledRDD
     stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
     stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
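In SparkListenerSuite the hard-coded stage id moves from 2 to 3 for the same reason: the lower ids are taken by stages from the jobs run earlier in the test, and the ShuffledRDD's result stage now gets the higher of the final job's two stage ids. A hypothetical, id-agnostic variant of the same lookup (a fragment meant to slot into this test, not standalone code) would select the stage by RDD name instead:

```scala
// Hypothetical variant: find the stage by the RDD name rather than a hard-coded id,
// so the assertion does not depend on how many stages earlier jobs have created.
val stageInfo3 = listener.stageInfos.keys.find(_.rddInfos.exists(_.name == "Trois")).get
stageInfo3.rddInfos.size should be {1} // ShuffledRDD
```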