diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4c63abb071e3b392730cfcb7e61cd5d8848d6c86..761f0447943e8e4ecd5033bdb287bf81679574e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -30,19 +30,18 @@ import org.apache.spark.unsafe.types.UTF8String;
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
- * Each tuple has two parts: [offsets] [values]
+ * Each tuple has three parts: [numElements] [offsets] [values]
  *
- * In the `offsets` region, we store 4 bytes per element, represents the start address of this
- * element in `values` region. We can get the length of this element by subtracting next offset.
+ * The `numElements` is 4 bytes storing the number of elements of this array.
+ *
+ * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the
+ * base address of the array) of this element in `values` region. We can get the length of this
+ * element by subtracting next offset.
  * Note that offset can by negative which means this element is null.
  *
  * In the `values` region, we store the content of elements. As we can get length info, so elements
  * can be variable-length.
  *
- * Note that when we write out this array, we should write out the `numElements` at first 4 bytes,
- * then follows content. When we read in an array, we should read first 4 bytes as `numElements`
- * and take the rest as content.
- *
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
@@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData {
   // The number of elements in this array
   private int numElements;
 
-  // The size of this array's backing data, in bytes
+  // The size of this array's backing data, in bytes.
+  // The 4-bytes header of `numElements` is also included.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
+
   private int getElementOffset(int ordinal) {
-    return Platform.getInt(baseObject, baseOffset + ordinal * 4L);
+    return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
   }
 
   private int getElementSize(int offset, int ordinal) {
@@ -85,10 +89,6 @@ public class UnsafeArrayData extends ArrayData {
    */
   public UnsafeArrayData() { }
 
-  public Object getBaseObject() { return baseObject; }
-  public long getBaseOffset() { return baseOffset; }
-  public int getSizeInBytes() { return sizeInBytes; }
-
   @Override
   public int numElements() { return numElements; }
 
@@ -97,10 +97,13 @@ public class UnsafeArrayData extends ArrayData {
    *
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
-   * @param sizeInBytes the size of this row's backing data, in bytes
+   * @param sizeInBytes the size of this array's backing data, in bytes
    */
-  public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the number of elements from the first 4 bytes.
+    final int numElements = Platform.getInt(baseObject, baseOffset);
     assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+
     this.numElements = numElements;
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
@@ -277,7 +280,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+    final UnsafeArrayData array = new UnsafeArrayData();
+    array.pointTo(baseObject, baseOffset + offset, size);
+    return array;
   }
 
   @Override
@@ -286,7 +291,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+    final UnsafeMapData map = new UnsafeMapData();
+    map.pointTo(baseObject, baseOffset + offset, size);
+    return map;
   }
 
   @Override
@@ -328,7 +335,7 @@ public class UnsafeArrayData extends ArrayData {
     final byte[] arrayDataCopy = new byte[sizeInBytes];
     Platform.copyMemory(
       baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
-    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index e9dab9edb6bd17cd7a4354878488e7ec605d1632..5bebe2a96e391e66ddabd38a1c118aba167c8cbe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,41 +17,73 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.nio.ByteBuffer;
+
 import org.apache.spark.sql.types.MapData;
+import org.apache.spark.unsafe.Platform;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
- *
- * Note that when we write out this map, we should write out the `numElements` at first 4 bytes,
- * and numBytes of key array at second 4 bytes, then follows key array content and value array
- * content without `numElements` header.
- * When we read in a map, we should read first 4 bytes as `numElements` and second 4 bytes as
- * numBytes of key array, and construct unsafe key array and value array with these 2 information.
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head
+ * to indicate the number of bytes of the unsafe key array.
+ * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
+// TODO: Use a more efficient format which doesn't depend on unsafe array.
 public class UnsafeMapData extends MapData {
 
-  private final UnsafeArrayData keys;
-  private final UnsafeArrayData values;
-  // The number of elements in this array
-  private int numElements;
-  // The size of this array's backing data, in bytes
+  private Object baseObject;
+  private long baseOffset;
+
+  // The size of this map's backing data, in bytes.
+  // The 4-bytes header of key array `numBytes` is also included, so it's actually equal to
+  // 4 + key array numBytes + value array numBytes.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
   public int getSizeInBytes() { return sizeInBytes; }
 
-  public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
+
+  /**
+   * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until
+   * `pointTo()` has been called, since the value returned by this constructor is equivalent
+   * to a null pointer.
+   */
+  public UnsafeMapData() {
+    keys = new UnsafeArrayData();
+    values = new UnsafeArrayData();
+  }
+
+  /**
+   * Update this UnsafeMapData to point to different backing data.
+   *
+   * @param baseObject the base object
+   * @param baseOffset the offset within the base object
+   * @param sizeInBytes the size of this map's backing data, in bytes
+   */
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the numBytes of key array from the first 4 bytes.
+    final int keyArraySize = Platform.getInt(baseObject, baseOffset);
+    final int valueArraySize = sizeInBytes - keyArraySize - 4;
+    assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+    assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
+
+    keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
+    values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+
     assert keys.numElements() == values.numElements();
-    this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
-    this.numElements = keys.numElements();
-    this.keys = keys;
-    this.values = values;
+
+    this.baseObject = baseObject;
+    this.baseOffset = baseOffset;
+    this.sizeInBytes = sizeInBytes;
   }
 
   @Override
   public int numElements() {
-    return numElements;
+    return keys.numElements();
   }
 
   @Override
@@ -64,8 +96,26 @@ public class UnsafeMapData extends MapData {
     return values;
   }
 
+  public void writeToMemory(Object target, long targetOffset) {
+    Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
+  }
+
+  public void writeTo(ByteBuffer buffer) {
+    assert(buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + sizeInBytes);
+  }
+
   @Override
   public UnsafeMapData copy() {
-    return new UnsafeMapData(keys.copy(), values.copy());
+    UnsafeMapData mapCopy = new UnsafeMapData();
+    final byte[] mapDataCopy = new byte[sizeInBytes];
+    Platform.copyMemory(
+      baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+    mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+    return mapCopy;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
deleted file mode 100644
index 6c5fcbca63fd79f08d38c86aef0cc4ba41f27f0b..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.unsafe.Platform;
-
-public class UnsafeReaders {
-
-  /**
-   * Reads in unsafe array according to the format described in `UnsafeArrayData`.
-   */
-  public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    final UnsafeArrayData array = new UnsafeArrayData();
-    // Skip the first 4 bytes.
-    array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
-    return array;
-  }
-
-  /**
-   * Reads in unsafe map according to the format described in `UnsafeMapData`.
-   */
-  public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    // Read the numBytes of key array in second 4 bytes.
-    final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4);
-    final int valueArraySize = numBytes - 8 - keyArraySize;
-
-    final UnsafeArrayData keyArray = new UnsafeArrayData();
-    keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
-
-    final UnsafeArrayData valueArray = new UnsafeArrayData();
-    valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
-
-    return new UnsafeMapData(keyArray, valueArray);
-  }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 36859fbab97449efa36f65dc776de5d3609d72ef..366615f6fe69fc96425414637005dcdbb202cd23 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -461,7 +461,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+      final UnsafeArrayData array = new UnsafeArrayData();
+      array.pointTo(baseObject, baseOffset + offset, size);
+      return array;
     }
   }
 
@@ -473,7 +475,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+      final UnsafeMapData map = new UnsafeMapData();
+      map.pointTo(baseObject, baseOffset + offset, size);
+      return map;
     }
   }
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 138178ce99d853d1a510286d6e7048011b980c61..7f2a1cb07af0179b50d20a1f84438b0d9e0ddd5d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,17 +30,19 @@ import org.apache.spark.unsafe.types.UTF8String;
 public class UnsafeArrayWriter {
 
   private BufferHolder holder;
+
   // The offset of the global buffer where we start to write this array.
   private int startingOffset;
 
   public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
-    // We need 4 bytes each element to store offset.
-    final int fixedSize = 4 * numElements;
+    // We need 4 bytes to store numElements and 4 bytes each element to store offset.
+    final int fixedSize = 4 + 4 * numElements;
 
     this.holder = holder;
     this.startingOffset = holder.cursor;
 
     holder.grow(fixedSize);
+    Platform.putInt(holder.buffer, holder.cursor, numElements);
     holder.cursor += fixedSize;
 
     // Grows the global buffer ahead for fixed size data.
@@ -48,7 +50,7 @@ public class UnsafeArrayWriter {
   }
 
   private long getElementOffset(int ordinal) {
-    return startingOffset + 4 * ordinal;
+    return startingOffset + 4 + 4 * ordinal;
   }
 
   public void setNullAt(int ordinal) {
@@ -132,20 +134,4 @@ public class UnsafeArrayWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this array is already an UnsafeArray, we don't need to go through all elements, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
-    final int numBytes = input.getSizeInBytes();
-
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-
-    // Writes the array content to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-
-    holder.cursor += numBytes;
-  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 8b7debd440031be801cef3bfa6c493add33e70a1..e1f5a05d1d446f4b8cc332b0fd690df42f5bd629 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -181,19 +181,4 @@ public class UnsafeRowWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this struct is already an UnsafeRow, we don't need to go through all fields, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeRow input) {
-    // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
-    final int numBytes = input.getSizeInBytes();
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-    // Write the bytes to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-    // move the cursor forward.
-    holder.cursor += numBytes;
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1b957a508d10e2db49e6c2e22248fadea6285235..dbe92d6a8350200a3ad500e06a783a1f14d95cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     s"""
       if ($input instanceof UnsafeRow) {
-        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
       } else {
         ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
       }
@@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       ctx: CodeGenContext,
       input: String,
       elementType: DataType,
-      bufferHolder: String,
-      needHeader: Boolean = true): String = {
+      bufferHolder: String): String = {
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"this.$arrayWriter = new $arrayWriterClass();")
@@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       case _ => s"$arrayWriter.write($index, $element);"
     }
 
-    val writeHeader = if (needHeader) {
-      // If header is required, we need to write the number of elements into first 4 bytes.
-      s"""
-        $bufferHolder.grow(4);
-        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
-        $bufferHolder.cursor += 4;
-      """
-    } else ""
-
     s"""
-      final int $numElements = $input.numElements();
-      $writeHeader
       if ($input instanceof UnsafeArrayData) {
-        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
       } else {
+        final int $numElements = $input.numElements();
         $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
 
         for (int $index = 0; $index < $numElements; $index++) {
@@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     // Writes out unsafe map according to the format described in `UnsafeMapData`.
     s"""
-      final ArrayData $keys = $input.keyArray();
-      final ArrayData $values = $input.valueArray();
+      if ($input instanceof UnsafeMapData) {
+        ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+      } else {
+        final ArrayData $keys = $input.keyArray();
+        final ArrayData $values = $input.valueArray();
 
-      $bufferHolder.grow(8);
+        // preserve 4 bytes to write the key array numBytes later.
+        $bufferHolder.grow(4);
+        $bufferHolder.cursor += 4;
 
-      // Write the numElements into first 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+        // Remember the current cursor so that we can write numBytes of key array later.
+        final int $tmpCursor = $bufferHolder.cursor;
 
-      $bufferHolder.cursor += 8;
-      // Remember the current cursor so that we can write numBytes of key array later.
-      final int $tmpCursor = $bufferHolder.cursor;
+        ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+        // Write the numBytes of key array into the first 4 bytes.
+        Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
 
-      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
-      // Write the numBytes of key array into second 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+        ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+      }
+    """
+  }
 
-      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+  /**
+   * If the input is already in unsafe format, we don't need to go through all elements/fields,
+   * we can directly write it.
+   */
+  private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = {
+    val sizeInBytes = ctx.freshName("sizeInBytes")
+    s"""
+      final int $sizeInBytes = $input.getSizeInBytes();
+      // grow the global buffer before writing data.
+      $bufferHolder.grow($sizeInBytes);
+      $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
+      $bufferHolder.cursor += $sizeInBytes;
     """
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c991cd86d28c8ac0070c3017e5893b04ff8de2b3..c6aad34e972b598732db8104497c3f18e2d985a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
   }
 
-  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
-
-  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
-
   private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
     assert(array.numElements == values.length)
-    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
     values.zipWithIndex.foreach {
       case (value, index) => assert(array.getInt(index) == value)
     }
@@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     testArrayInt(map.keyArray, keys)
     testArrayInt(map.valueArray, values)
 
-    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+    assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
   }
 
   test("basic conversion with array type") {
@@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val nestedArray = unsafeArray2.getArray(0)
     testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
+    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
 
-    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
-    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
+    val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
+    val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
@@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       val nestedMap = valueArray.getMap(0)
       testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
 
-      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
+      assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
     }
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
-    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
-    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
+    val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
   }
 
@@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerArray = field1.getArray(0)
     testArrayInt(innerArray, Seq(1))
 
-    assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes))
 
     val field2 = unsafeRow.getArray(1)
     assert(field2.numElements == 1)
@@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(innerStruct.getLong(0) == 2L)
     }
 
-    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with struct and map") {
@@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes))
 
     val field2 = unsafeRow.getMap(1)
 
@@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(innerStruct.getSizeInBytes == 8 + 8)
       assert(innerStruct.getLong(0) == 4L)
 
-      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+      assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with array and map") {
@@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
 
     val field2 = unsafeRow.getMap(1)
     assert(field2.numElements == 1)
@@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2bc2c96b61634f630f34b9931ec7e6725dd62432..a41f04dd3b59a9d8849810e37dcee919c4cabbd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
   override def extract(buffer: ByteBuffer): UnsafeRow = {
     val sizeInBytes = buffer.getInt()
     assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
     buffer.position(cursor + sizeInBytes)
     val unsafeRow = new UnsafeRow
-    unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
+    unsafeRow.pointTo(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numOfFields,
+      sizeInBytes)
     unsafeRow
   }
 
@@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeArray = getField(row, ordinal)
-    4 + 4 + unsafeArray.getSizeInBytes
+    4 + unsafeArray.getSizeInBytes
   }
 
   override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(4 + value.getSizeInBytes)
-    buffer.putInt(value.numElements())
+    buffer.putInt(value.getSizeInBytes)
     value.writeTo(buffer)
   }
 
@@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readArray(
+    val array = new UnsafeArrayData
+    array.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
       numBytes)
+    array
   }
 
   override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeMap = getField(row, ordinal)
-    12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+    4 + unsafeMap.getSizeInBytes
   }
 
   override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
-    buffer.putInt(value.numElements())
-    buffer.putInt(value.keyArray().getSizeInBytes)
-    value.keyArray().writeTo(buffer)
-    value.valueArray().writeTo(buffer)
+    buffer.putInt(value.getSizeInBytes)
+    value.writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UnsafeMapData = {
@@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readMap(
+    val map = new UnsafeMapData
+    map.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
       numBytes)
+    map
   }
 
   override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 0e6e1bcf72896e71e6da6ce84891bc993375b635..63bc39bfa0307fad3c0ade6878fe0199a208401d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
     checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
-    checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+    checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
     checkActualSize(STRUCT_TYPE, Row("hello"), 28)
   }