diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8cf4073826192510967dc3fd207066a3710379b6..574f91b09912b411942588a80c8462d5e076f7bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -966,9 +966,9 @@ class Analyzer(
       case s @ Sort(orders, global, child)
         if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
         val newOrders = orders map {
-          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
+          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
             if (index > 0 && index <= child.output.size) {
-              SortOrder(child.output(index - 1), direction, nullOrdering)
+              SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
             } else {
               s.failAnalysis(
                 s"ORDER BY position $index is not in select list " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
index af0a565f73ae9302f1fe86461f16f4d861194829..38a3d3de1288e43d70938e28a4836ee27899202c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
@@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan]
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
       val newOrders = s.order.map {
-        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
+        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
           val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
           withOrigin(order.origin)(order.copy(child = newOrdinal))
         case other => other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 35ca2a0aa53a208bfb3fd851c21b3ba2da28a414..75bf780d41424796328c9e3bf9c46c04ac5111cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -109,9 +109,9 @@ package object dsl {
     def cast(to: DataType): Expression = Cast(expr, to)
 
     def asc: SortOrder = SortOrder(expr, Ascending)
-    def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast)
+    def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
     def desc: SortOrder = SortOrder(expr, Descending)
-    def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst)
+    def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
     def as(alias: String): NamedExpression = Alias(expr, alias)()
     def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3bebd552ef51a5cd3ddebe6242c352afd1405b6e..abcb9a2b939b4165a5c3242df76f6d650e6cd190 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{
 /**
  * An expression that can be used to sort a tuple.  This class extends expression primarily so that
  * transformations over expression will descend into its child.
+ * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is
+ * derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge
+ * join.
  */
-case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
+case class SortOrder(
+    child: Expression,
+    direction: SortDirection,
+    nullOrdering: NullOrdering,
+    sameOrderExpressions: Set[Expression])
   extends UnaryExpression with Unevaluable {
 
   /** Sort order is not foldable because we don't have an eval for it. */
@@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
   override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql
 
   def isAscending: Boolean = direction == Ascending
+
+  def satisfies(required: SortOrder): Boolean = {
+    (sameOrderExpressions + child).exists(required.child.semanticEquals) &&
+      direction == required.direction && nullOrdering == required.nullOrdering
+  }
 }
 
 object SortOrder {
-  def apply(child: Expression, direction: SortDirection): SortOrder = {
-    new SortOrder(child, direction, direction.defaultNullOrdering)
+  def apply(
+     child: Expression,
+     direction: SortDirection,
+     sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
+    new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4c9fb2ec2774a5b971667d842916d95be52c438b..cd238e05d4102bf77ba678510af5b84bf7f2ba74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     } else {
       direction.defaultNullOrdering
     }
-    SortOrder(expression(ctx.expression), direction, nullOrdering)
+    SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 38029552d13bdc045f1ba46869521f636e4d9ea5..ae0703513cf424c7f596780df5c8f72a72bce9a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
+  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }
 
   /**
    * Returns a descending ordering used in sorting, where null values appear after non-null values.
@@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }
+  def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }
 
   /**
    * Returns an ascending ordering used in sorting.
@@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }
+  def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }
 
   /**
    * Returns an ordering used in sorting, where null values appear after non-null values.
@@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
+  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }
 
   /**
    * Prints the expression to the console for debugging purpose.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index f17049949aa473b5044b5f54545d1cd6762dfd6f..b91d0774425579be2bc4998f42f2319096ce1175 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         } else {
           requiredOrdering.zip(child.outputOrdering).forall {
             case (requiredOrder, childOutputOrder) =>
-              requiredOrder.semanticEquals(childOutputOrder)
+              childOutputOrder.satisfies(requiredOrder)
           }
         }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 02f4f55c7999a6a1f8c68ce3340aea55f49dab39..c6aae1a4db2e41cee3f8ca79d967e103ac29b37b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -81,17 +81,37 @@ case class SortMergeJoinExec(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
+    // For inner join, orders of both sides keys should be kept.
+    case Inner =>
+      val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
+      val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
+      leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
+        // Also add the right key and its `sameOrderExpressions`
+        SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
+          .sameOrderExpressions)
+      }
     // For left and right outer joins, the output is ordered by the streamed input's join keys.
-    case LeftOuter => requiredOrders(leftKeys)
-    case RightOuter => requiredOrders(rightKeys)
+    case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
+    case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
     // There are null rows in both streams, so there is no order.
     case FullOuter => Nil
-    case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys)
+    case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
+  /**
+   * For SMJ, child's output must have been sorted on key or expressions with the same order as
+   * key, so we can get ordering for key from child's output ordering.
+   */
+  private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
+    : Seq[SortOrder] = {
+    keys.zip(childOutputOrdering).map { case (key, childOrder) =>
+      SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f2232fc489b78c62dd0d61e215b3a641a0b70deb..4d155d538d63720533a3cecf52d3637ce8a3f51d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext {
 
   private val exprA = Literal(1)
   private val exprB = Literal(2)
+  private val exprC = Literal(3)
   private val orderingA = SortOrder(exprA, Ascending)
   private val orderingB = SortOrder(exprB, Ascending)
+  private val orderingC = SortOrder(exprC, Ascending)
   private val planA = DummySparkPlan(outputOrdering = Seq(orderingA),
     outputPartitioning = HashPartitioning(exprA :: Nil, 5))
   private val planB = DummySparkPlan(outputOrdering = Seq(orderingB),
     outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+  private val planC = DummySparkPlan(outputOrdering = Seq(orderingC),
+    outputPartitioning = HashPartitioning(exprC :: Nil, 5))
 
-  assert(orderingA != orderingB)
+  assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC)
 
   private def assertSortRequirementsAreSatisfied(
       childPlan: SparkPlan,
@@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") {
+    val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
+    // Both left and right keys should be sorted after the SMJ.
+    Seq(orderingA, orderingB).foreach { ordering =>
+      assertSortRequirementsAreSatisfied(
+        childPlan = innerSmj,
+        requiredOrdering = Seq(ordering),
+        shouldHaveSort = false)
+    }
+  }
+
+  test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " +
+    "child SMJ") {
+    val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
+    val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC)
+    // After the second SMJ, exprA, exprB and exprC should all be sorted.
+    Seq(orderingA, orderingB, orderingC).foreach { ordering =>
+      assertSortRequirementsAreSatisfied(
+        childPlan = parentSmj,
+        requiredOrdering = Seq(ordering),
+        shouldHaveSort = false)
+    }
+  }
+
   test("EnsureRequirements for sort operator after left outer sort merge join") {
     // Only left key is sorted after left outer SMJ (thus doesn't need a sort).
     val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)