diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 214e8d309de1108831d5846ce4f60d229fa2daf2..7063b08f7c64406c70ec43afbc46a1eb9e9b0ea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -42,7 +42,9 @@ case class InMemoryTableScanExec(
   override def output: Seq[Attribute] = attributes
 
   private def updateAttribute(expr: Expression): Expression = {
-    val attrMap = AttributeMap(relation.child.output.zip(output))
+    // This scan's attributes may have been pruned, so map to relation.output instead.
+    // E.g., relation.output is [id, item] while this scan's output can be just [item].
+    val attrMap = AttributeMap(relation.child.output.zip(relation.output))
     expr.transform {
       case attr: Attribute => attrMap.getOrElse(attr, attr)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 1e6a6a8ba3362ea133df94a329d71802a6f8e8e1..109b1d9db60d2552bbc1f69212badf4ebf9e935f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -414,4 +414,21 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet))
     }
   }
+
+  test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") {
+    withSQLConf("spark.sql.shuffle.partitions" -> "200") {
+      val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group")
+      val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id")
+      val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct()
+
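+      // Baseline: the same aggregate over the uncached df3.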
+      df3.unpersist()
+      val agg_without_cache = df3.groupBy($"item").count()
+
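+      // Cache df3 and re-run; the pruned in-memory scan must report correct partitioning.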
+      df3.cache()
+      val agg_with_cache = df3.groupBy($"item").count()
+      checkAnswer(agg_without_cache, agg_with_cache)
+    }
+  }
 }
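
Reviewer note: below is a minimal, self-contained sketch of the mis-pairing the fix addresses. The `Attr` case class and `ZipBugSketch` object are hypothetical stand-ins for illustration, not Spark's actual `Attribute`/`AttributeMap` types. It shows why zipping `relation.child.output` against a pruned scan output truncates and mis-pairs attributes, while `relation.output` always lines up one-to-one:

```scala
// Hypothetical stand-in for Spark's Attribute; for illustration only.
final case class Attr(name: String, exprId: Int)

object ZipBugSketch extends App {
  val childOutput    = Seq(Attr("id", 1), Attr("item", 2)) // relation.child.output
  val relationOutput = Seq(Attr("id", 3), Attr("item", 4)) // relation.output
  val scanOutput     = Seq(Attr("item", 4))                // this (pruned) scan's output

  // Before the fix: zip stops at the shorter sequence, silently pairing
  // the child's "id" with the scan's "item".
  val buggyMap = childOutput.zip(scanOutput).toMap
  println(s"buggy: $buggyMap") // Map(Attr(id,1) -> Attr(item,4))

  // After the fix: relation.output is never pruned, so the pairing is 1:1.
  val fixedMap = childOutput.zip(relationOutput).toMap
  println(s"fixed: $fixedMap") // Map(Attr(id,1) -> Attr(id,3), Attr(item,2) -> Attr(item,4))
}
```

With the mis-paired map, an expression referring to `id` would be rewritten to the scan's `item` attribute, so the pruned scan could report an outputPartitioning/outputOrdering it does not actually have, which is what the regression test above guards against.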