diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1049915986d9b114bfb5e29cc2d6e4eb7a039d77..bb1273f5c3d842bf2868441c99018de34a04399d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     })
   }
 
-  private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
-    case dt if dt == from => identity[Any]
-    case StringType => castToString(from)
-    case BinaryType => castToBinary(from)
-    case DateType => castToDate(from)
-    case decimal: DecimalType => castToDecimal(from, decimal)
-    case TimestampType => castToTimestamp(from)
-    case CalendarIntervalType => castToInterval(from)
-    case BooleanType => castToBoolean(from)
-    case ByteType => castToByte(from)
-    case ShortType => castToShort(from)
-    case IntegerType => castToInt(from)
-    case FloatType => castToFloat(from)
-    case LongType => castToLong(from)
-    case DoubleType => castToDouble(from)
-    case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
-    case map: MapType => castMap(from.asInstanceOf[MapType], map)
-    case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
-    case udt: UserDefinedType[_]
-      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
-      identity[Any]
-    case _: UserDefinedType[_] =>
-      throw new SparkException(s"Cannot cast $from to $to.")
+  private[this] def cast(from: DataType, to: DataType): Any => Any = {
+    // If the cast does not change the structure, we don't need to cast anything; we can simply
+    // return what the child returns. The same short-circuit is applied in the codegen path.
+    if (DataType.equalsStructurally(from, to)) {
+      identity
+    } else {
+      to match {
+        case dt if dt == from => identity[Any]
+        case StringType => castToString(from)
+        case BinaryType => castToBinary(from)
+        case DateType => castToDate(from)
+        case decimal: DecimalType => castToDecimal(from, decimal)
+        case TimestampType => castToTimestamp(from)
+        case CalendarIntervalType => castToInterval(from)
+        case BooleanType => castToBoolean(from)
+        case ByteType => castToByte(from)
+        case ShortType => castToShort(from)
+        case IntegerType => castToInt(from)
+        case FloatType => castToFloat(from)
+        case LongType => castToLong(from)
+        case DoubleType => castToDouble(from)
+        case array: ArrayType =>
+          castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
+        case map: MapType => castMap(from.asInstanceOf[MapType], map)
+        case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+        case udt: UserDefinedType[_]
+          if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+          identity[Any]
+        case _: UserDefinedType[_] =>
+          throw new SparkException(s"Cannot cast $from to $to.")
+      }
+    }
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
   protected override def nullSafeEval(input: Any): Any = cast(input)
 
+  override def genCode(ctx: CodegenContext): ExprCode = {
+    // If the cast does not change the structure, we don't need to cast anything; we can simply
+    // return what the child returns. The same short-circuit is applied in the interpreted path.
+    if (DataType.equalsStructurally(child.dataType, dataType)) {
+      child.genCode(ctx)
+    } else {
+      super.genCode(ctx)
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
     val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
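For reference, a minimal sketch (not part of the patch) of what the Cast.scala change buys: two struct types that differ only in field names are structurally equal, so the cast collapses to identity in both the interpreted and codegen paths. Only Cast, Literal, DataType.equalsStructurally and the type constructors from the Catalyst API are assumed; the field names and values are made up for illustration.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types._

    // Same shape (types and nullability), different field names.
    val from = new StructType().add("a", IntegerType).add("b", LongType)
    val to   = new StructType().add("x", IntegerType).add("y", LongType)
    assert(DataType.equalsStructurally(from, to))

    // cast() now returns identity and genCode() reuses the child's generated code,
    // so the value comes back untouched (in the literal's internal representation).
    val result = Cast(Literal.create(Row(1, 2L), from), to).eval()
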
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 520aff5e2b6775e386826b24be983ce5d5fcd4dc..30745c6a9d42a4fc3b94ecdb7262efa20eccff0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -288,4 +288,30 @@ object DataType {
       case (fromDataType, toDataType) => fromDataType == toDataType
     }
   }
+
+  /**
+   * Returns true if the two data types share the same "shape", i.e. the types (including
+   * nullability) are the same, but the field names don't need to be the same.
+   */
+  def equalsStructurally(from: DataType, to: DataType): Boolean = {
+    (from, to) match {
+      case (left: ArrayType, right: ArrayType) =>
+        equalsStructurally(left.elementType, right.elementType) &&
+          left.containsNull == right.containsNull
+
+      case (left: MapType, right: MapType) =>
+        equalsStructurally(left.keyType, right.keyType) &&
+          equalsStructurally(left.valueType, right.valueType) &&
+          left.valueContainsNull == right.valueContainsNull
+
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.length == toFields.length &&
+          fromFields.zip(toFields)
+            .forall { case (l, r) =>
+              equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+            }
+
+      case (fromDataType, toDataType) => fromDataType == toDataType
+    }
+  }
 }
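A hedged sketch of how equalsStructurally recurses through container types; the types below are illustrative and not taken from the patch. Field names are ignored at every nesting level, while element and value nullability still has to match.

    import org.apache.spark.sql.types._

    // Map values are arrays of single-field structs; only the struct field name varies.
    val a = MapType(StringType, ArrayType(new StructType().add("x", IntegerType)))
    val b = MapType(StringType, ArrayType(new StructType().add("y", IntegerType)))
    val c = MapType(StringType, ArrayType(new StructType().add("y", IntegerType)),
      valueContainsNull = false)

    assert(DataType.equalsStructurally(a, b))   // only the nested field name differs
    assert(!DataType.equalsStructurally(a, c))  // valueContainsNull differs
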
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 8eccadbdd8afbfa95028adc932ebd9348cd75cde..a7ffa884d2286a54d08a72a1cf32c96beac7b6f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
     assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
   }
+
+  test("SPARK-20302 cast with same structure") {
+    val from = new StructType()
+      .add("a", IntegerType)
+      .add("b", new StructType().add("b1", LongType))
+
+    val to = new StructType()
+      .add("a1", IntegerType)
+      .add("b1", new StructType().add("b11", LongType))
+
+    val input = Row(10, Row(12L))
+
+    checkEvaluation(cast(Literal.create(input, from), to), input)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index f078ef013387bf32bbae5d0ce1adda13b57e7faf..c4635c8f126af1698c3b020425c9073c6fb0e5f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
   checkCatalogString(ArrayType(createStruct(40)))
   checkCatalogString(MapType(IntegerType, StringType))
   checkCatalogString(MapType(IntegerType, createStruct(40)))
+
+  def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
+    val testName = s"equalsStructurally: (from: $from, to: $to)"
+    test(testName) {
+      assert(DataType.equalsStructurally(from, to) === expected)
+    }
+  }
+
+  checkEqualsStructurally(BooleanType, BooleanType, true)
+  checkEqualsStructurally(IntegerType, IntegerType, true)
+  checkEqualsStructurally(IntegerType, LongType, false)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType, false),
+    false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    false)
 }
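Finally, an illustrative end-to-end usage sketch (assumes a SparkSession named spark; the column and field names are made up): renaming nested struct fields with Column.cast. When the target type differs only in field names, the cast now degenerates to a pass-through of the child's value in both execution paths.

    import org.apache.spark.sql.functions.{col, lit, struct}
    import org.apache.spark.sql.types._

    val df = spark.range(1).select(struct(lit(1).as("a"), lit(2L).as("b")).as("s"))

    // lit() columns are non-nullable, so the renamed fields are declared non-nullable too;
    // nullability has to match for the structural-equality short-circuit to apply.
    val renamed = new StructType()
      .add("id", IntegerType, nullable = false)
      .add("payload", LongType, nullable = false)

    df.select(col("s").cast(renamed).as("s")).printSchema()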