diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 501dff090313cef44b0b3a2d15dda5d4a607e56e..da9538b3f13cc2f4d448ac16447e3058afe0e634 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -256,7 +255,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public InternalRow getStruct(int ordinal, int numFields) {
+  public UnsafeRow getStruct(int ordinal, int numFields) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -267,7 +266,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -276,7 +275,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 46216054ab38bee3de846632b8bfa27bf14b3240..e9dab9edb6bd17cd7a4354878488e7ec605d1632 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,18 +17,23 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import org.apache.spark.sql.types.ArrayData;
 import org.apache.spark.sql.types.MapData;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
  * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
+ *
+ * Note that when we write out this map, we write the `numElements` into the first 4 bytes and
+ * the numBytes of the key array into the second 4 bytes, followed by the key array content and
+ * the value array content, each written without its own `numElements` header.
+ * When we read a map back in, we read the first 4 bytes as `numElements` and the second 4 bytes
+ * as the numBytes of the key array, and reconstruct the unsafe key and value arrays from these
+ * two values.
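+ *
+ * For example, a map with two int keys and two int values is laid out as:
+ *
+ *   [numElements: 4 bytes][numBytes of key array: 4 bytes][key array: 16 bytes][value array: 16 bytes]
+ *
+ * where each 16-byte array is its 4-byte per-element offsets followed by the element data.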
  */
 public class UnsafeMapData extends MapData {
 
-  public final UnsafeArrayData keys;
-  public final UnsafeArrayData values;
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
  // The number of elements in this map
  private int numElements;
  // The size of this map's backing data, in bytes
@@ -50,12 +55,12 @@ public class UnsafeMapData extends MapData {
   }
 
   @Override
-  public ArrayData keyArray() {
+  public UnsafeArrayData keyArray() {
     return keys;
   }
 
   @Override
-  public ArrayData valueArray() {
+  public UnsafeArrayData valueArray() {
     return values;
   }
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
index 7b03185a30e3c48bdca66291ae658c8545edae7c..6c5fcbca63fd79f08d38c86aef0cc4ba41f27f0b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
@@ -21,6 +21,9 @@ import org.apache.spark.unsafe.Platform;
 
 public class UnsafeReaders {
 
+  /**
+   * Reads in an unsafe array according to the format described in `UnsafeArrayData`.
+   */
   public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
    // Read the number of elements from the first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);
@@ -30,6 +33,9 @@ public class UnsafeReaders {
     return array;
   }
 
+  /**
+   * Reads in an unsafe map according to the format described in `UnsafeMapData`.
+   */
   public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
    // Read the number of elements from the first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c020045c311ac343bfa1bd1f33c098c25ec6a31..e8ac2999c2d29569507ccdb6e66966c66e70e91c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
@@ -458,7 +458,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 2f43db68a750e757686eae249d6b4e18b422d791..0f1e0202aace15a303e83040b747ad7b8ef00c91 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -233,8 +233,8 @@ public class UnsafeRowWriters {
 
     public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
       final long offset = target.getBaseOffset() + cursor;
-      final UnsafeArrayData keyArray = input.keys;
-      final UnsafeArrayData valueArray = input.values;
+      final UnsafeArrayData keyArray = input.keyArray();
+      final UnsafeArrayData valueArray = input.valueArray();
       final int keysNumBytes = keyArray.getSizeInBytes();
       final int valuesNumBytes = valueArray.getSizeInBytes();
       final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
index cd83695fca033566f7eb26fcf170fbe0eec9500f..ce2d9c4ffbf9f1860db507fffccc4bc985c29dd1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
@@ -168,8 +168,8 @@ public class UnsafeWriters {
     }
 
     public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
-      final UnsafeArrayData keyArray = input.keys;
-      final UnsafeArrayData valueArray = input.values;
+      final UnsafeArrayData keyArray = input.keyArray();
+      final UnsafeArrayData valueArray = input.valueArray();
       final int keysNumBytes = keyArray.getSizeInBytes();
       final int valuesNumBytes = valueArray.getSizeInBytes();
       final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
new file mode 100644
index 0000000000000000000000000000000000000000..9c9468678065d43ca299ba39ec767c9931d03d0e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A helper class to manage the row buffer used in `GenerateUnsafeProjection`.
+ *
+ * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables
+ * public for ease of use.
+ */
+public class BufferHolder {
+  public byte[] buffer = new byte[64];
+  public int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+  public void grow(int neededSize) {
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      // This will not happen frequently, because the buffer is re-used.
+      final byte[] tmp = new byte[length * 2];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+    }
+  }
+
+  public void reset() {
+    cursor = Platform.BYTE_ARRAY_OFFSET;
+  }
+
+  public int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+}
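
As a quick illustration of the contract above, a hypothetical hand-written equivalent of what
the generated code does with a BufferHolder (the class name BufferHolderExample is made up for
illustration; everything it calls is defined in this patch):

    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.unsafe.Platform;

    class BufferHolderExample {
      static int writeOneLong() {
        BufferHolder holder = new BufferHolder();
        holder.reset();             // cursor back to Platform.BYTE_ARRAY_OFFSET
        holder.grow(8);             // ensure room for one 8-byte word
        Platform.putLong(holder.buffer, holder.cursor, 42L);
        holder.cursor += 8;
        return holder.totalSize();  // 8: bytes written since the last reset()
      }
    }
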
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..138178ce99d853d1a510286d6e7048011b980c61
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeArrayData` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeArrayWriter {
+
+  private BufferHolder holder;
+  // The offset into the global buffer at which we start to write this array.
+  private int startingOffset;
+
+  public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
+    // We need 4 bytes per element to store its offset.
+    final int fixedSize = 4 * numElements;
+
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // Grow the global buffer ahead of time for the fixed-size element data.
+    holder.grow(fixedElementSize * numElements);
+  }
+
+  private long getElementOffset(int ordinal) {
+    return startingOffset + 4 * ordinal;
+  }
+
+  public void setNullAt(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    // Write a negative offset value to represent a null element.
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
+  }
+
+  public void setOffset(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
+  }
+
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+      setOffset(ordinal);
+      holder.cursor += 8;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+      holder.grow(bytes.length);
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffset(ordinal);
+      holder.cursor += bytes.length;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+
+  public void write(int ordinal, byte[] input) {
+    // grow the global buffer before writing data.
+    holder.grow(input.length);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(
+      input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += input.length;
+  }
+
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+
+  // If the input array is already an UnsafeArrayData, we don't need to go through all the
+  // elements; we can write its bytes out directly.
+  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
+    final int numBytes = input.getSizeInBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Writes the array content to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    holder.cursor += numBytes;
+  }
+}
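
A hypothetical hand-written sketch of the code GenerateUnsafeProjection emits for a primitive
array, writing the int array [7, 8, null] with this writer (the 4-byte numElements header is
written by the caller, as in writeArrayToBuffer; the class name is made up):

    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter;
    import org.apache.spark.unsafe.Platform;

    class UnsafeArrayWriterExample {
      static void writeIntArray(BufferHolder holder) {
        // Header: the number of elements.
        holder.grow(4);
        Platform.putInt(holder.buffer, holder.cursor, 3);
        holder.cursor += 4;

        UnsafeArrayWriter arrayWriter = new UnsafeArrayWriter();
        arrayWriter.initialize(holder, 3, 4);  // 3 elements, 4 fixed bytes each

        for (int i = 0; i < 2; i++) {
          arrayWriter.setOffset(i);            // record this element's relative offset
          Platform.putInt(holder.buffer, holder.cursor, 7 + i);
          holder.cursor += 4;
        }
        arrayWriter.setNullAt(2);              // a negative offset marks the null element
      }
    }
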
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..8b7debd440031be801cef3bfa6c493add33e70a1
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeRow` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriter {
+
+  private BufferHolder holder;
+  // The offset into the global buffer at which we start to write this row.
+  private int startingOffset;
+  private int nullBitsSize;
+
+  public void initialize(BufferHolder holder, int numFields) {
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+    this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+
+    // grow the global buffer to make sure it has enough space to write fixed-length data.
+    final int fixedSize = nullBitsSize + 8 * numFields;
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // zero-out the null bits region
+    for (int i = 0; i < nullBitsSize; i += 8) {
+      Platform.putLong(holder.buffer, startingOffset + i, 0L);
+    }
+  }
+
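+  // Zero-out the last, possibly partial word before variable-length data is copied into it,
+  // so the padding bytes left after the copy are deterministic and equal rows stay
+  // binary-equal.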
+  private void zeroOutPaddingBytes(int numBytes) {
+    if ((numBytes & 0x07) > 0) {
+      Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+    }
+  }
+
+  public void setNullAt(int ordinal) {
+    BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+  }
+
+  public long getFieldOffset(int ordinal) {
+    return startingOffset + nullBitsSize + 8 * ordinal;
+  }
+
+  public void setOffsetAndSize(int ordinal, long size) {
+    setOffsetAndSize(ordinal, holder.cursor, size);
+  }
+
+  public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+    final long relativeOffset = currentCursor - startingOffset;
+    final long fieldOffset = getFieldOffset(ordinal);
+    final long offsetAndSize = (relativeOffset << 32) | size;
+
+    Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+  }
+
+  // Do word alignment for this row and grow the row buffer if needed.
+  // TODO: remove this once we make unsafe array data word-aligned.
+  public void alignToWords(int numBytes) {
+    final int remainder = numBytes & 0x07;
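+    // e.g. numBytes = 13 gives remainder 5, so 3 zero bytes of padding reach the next word.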
+
+    if (remainder > 0) {
+      final int paddingBytes = 8 - remainder;
+      holder.grow(paddingBytes);
+
+      for (int i = 0; i < paddingBytes; i++) {
+        Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
+        holder.cursor++;
+      }
+    }
+  }
+
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // zero-out the bytes
+    Platform.putLong(holder.buffer, holder.cursor, 0L);
+    Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+    // Make sure Decimal object has the same scale as DecimalType.
+    // Note that a null Decimal object may be passed in to set this field to null.
+    if (input == null || !input.changePrecision(precision, scale)) {
+      BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+      // keep the offset for future update
+      setOffsetAndSize(ordinal, 0L);
+    } else {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffsetAndSize(ordinal, bytes.length);
+    }
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, byte[] input) {
+    final int numBytes = input.length;
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffsetAndSize(ordinal, 16);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+
+  // If the input struct is already an UnsafeRow, we don't need to go through all the fields;
+  // we can write its bytes out directly.
+  public static void directWrite(BufferHolder holder, UnsafeRow input) {
+    // No need to zero-out the bytes, as an UnsafeRow is always word-aligned.
+    final int numBytes = input.getSizeInBytes();
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+}
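
Similarly, a hypothetical hand-written sketch of what the generated projection does with this
writer for a row of one long field and one string field (mirroring createCode in
GenerateUnsafeProjection; the class name is made up):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.types.UTF8String;

    class UnsafeRowWriterExample {
      static UnsafeRow writeRow() {
        BufferHolder holder = new BufferHolder();
        UnsafeRowWriter rowWriter = new UnsafeRowWriter();
        UnsafeRow result = new UnsafeRow();

        holder.reset();
        rowWriter.initialize(holder, 2);  // null bits plus two fixed-length 8-byte slots

        // Field 0: fixed-length long, written directly into its slot.
        Platform.putLong(holder.buffer, rowWriter.getFieldOffset(0), 123L);

        // Field 1: variable-length string, appended after the fixed-length region;
        // its slot stores (relative offset << 32) | size.
        rowWriter.write(1, UTF8String.fromString("abc"));

        result.pointTo(holder.buffer, 2, holder.totalSize());
        return result;
      }
    }
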
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index da3103b4ebb6b7242f73392b898030d7f04f8dae..9a2878113304cdd7ea88461858c28a8c409d61d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -272,6 +272,7 @@ class CodeGenContext {
    * 64kb code size limit in JVM
    *
   * @param row the variable name of the row that is used by the expressions
+   * @param expressions the code snippets that evaluate the expressions.
    */
   def splitExpressions(row: String, expressions: Seq[String]): String = {
     val blocks = new ArrayBuffer[String]()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 55562facf9652fcf40ca4b11e30780a1a2333dea..99bf50a84571b727815828d5900c12a71f02bf5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -393,10 +393,292 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => input
   }
 
+  private val rowWriterClass = classOf[UnsafeRowWriter].getName
+  private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+
+  // TODO: if the nullability of the field is correct, we can use it to skip the null check.
+  private def writeStructToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      fieldTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
+      val fieldName = ctx.freshName("fieldName")
+      val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
+      val isNull = s"$input.isNullAt($i)"
+      GeneratedExpressionCode(code, isNull, fieldName)
+    }
+
+    s"""
+      if ($input instanceof UnsafeRow) {
+        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+      } else {
+        ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
+      }
+    """
+  }
+
+  private def writeExpressionsToBuffer(
+      ctx: CodeGenContext,
+      row: String,
+      inputs: Seq[GeneratedExpressionCode],
+      inputTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val rowWriter = ctx.freshName("rowWriter")
+    ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+
+    val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
+      case ((input, dt), index) =>
+        val tmpCursor = ctx.freshName("tmpCursor")
+
+        val setNull = dt match {
+          case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+            // Can't call setNullAt() for DecimalType with precision larger than 18.
+            s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});"
+          case _ => s"$rowWriter.setNullAt($index);"
+        }
+
+        val writeField = dt match {
+          case t: StructType =>
+            s"""
+              // Remember the current cursor so that we can later calculate how many
+              // bytes were written.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeStructToBuffer(ctx, input.primitive, t.map(_.dataType), bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+            """
+
+          case a @ ArrayType(et, _) =>
+            s"""
+              // Remember the current cursor so that we can later calculate how many
+              // bytes were written.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeArrayToBuffer(ctx, input.primitive, et, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case m @ MapType(kt, vt, _) =>
+            s"""
+              // Remember the current cursor so that we can later calculate how many
+              // bytes were written.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeMapToBuffer(ctx, input.primitive, kt, vt, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case _ if ctx.isPrimitiveType(dt) =>
+            val fieldOffset = ctx.freshName("fieldOffset")
+            s"""
+              final long $fieldOffset = $rowWriter.getFieldOffset($index);
+              Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L);
+              ${writePrimitiveType(ctx, input.primitive, dt, s"$bufferHolder.buffer", fieldOffset)}
+            """
+
+          case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+            s"$rowWriter.writeCompactDecimal($index, ${input.primitive}, " +
+              s"${t.precision}, ${t.scale});"
+
+          case t: DecimalType =>
+            s"$rowWriter.write($index, ${input.primitive}, ${t.precision}, ${t.scale});"
+
+          case NullType => ""
+
+          case _ => s"$rowWriter.write($index, ${input.primitive});"
+        }
+
+        s"""
+          ${input.code}
+          if (${input.isNull}) {
+            $setNull
+          } else {
+            $writeField
+          }
+        """
+    }
+
+    s"""
+      $rowWriter.initialize($bufferHolder, ${inputs.length});
+      ${ctx.splitExpressions(row, writeFields)}
+    """
+  }
+
+  // TODO: if the nullability of array elements is correct, we can use it to skip the null check.
+  private def writeArrayToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      elementType: DataType,
+      bufferHolder: String,
+      needHeader: Boolean = true): String = {
+    val arrayWriter = ctx.freshName("arrayWriter")
+    ctx.addMutableState(arrayWriterClass, arrayWriter,
+      s"this.$arrayWriter = new $arrayWriterClass();")
+    val numElements = ctx.freshName("numElements")
+    val index = ctx.freshName("index")
+    val element = ctx.freshName("element")
+
+    val jt = ctx.javaType(elementType)
+
+    val fixedElementSize = elementType match {
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
+      case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+      case _ => 0
+    }
+
+    val writeElement = elementType match {
+      case t: StructType =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
+        """
+
+      case a @ ArrayType(et, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
+        """
+
+      case m @ MapType(kt, vt, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
+        """
+
+      case _ if ctx.isPrimitiveType(elementType) =>
+        // Should we word-align primitive elements here?
+        val dataSize = elementType.defaultSize
+
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writePrimitiveType(ctx, element, elementType,
+            s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
+          $bufferHolder.cursor += $dataSize;
+        """
+
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+        s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});"
+
+      case t: DecimalType =>
+        s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
+
+      case NullType => ""
+
+      case _ => s"$arrayWriter.write($index, $element);"
+    }
+
+    val writeHeader = if (needHeader) {
+      // If a header is required, we need to write the number of elements into the first 4 bytes.
+      s"""
+        $bufferHolder.grow(4);
+        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
+        $bufferHolder.cursor += 4;
+      """
+    } else ""
+
+    s"""
+      final int $numElements = $input.numElements();
+      $writeHeader
+      if ($input instanceof UnsafeArrayData) {
+        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+      } else {
+        $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
+
+        for (int $index = 0; $index < $numElements; $index++) {
+          if ($input.isNullAt($index)) {
+            $arrayWriter.setNullAt($index);
+          } else {
+            final $jt $element = ${ctx.getValue(input, elementType, index)};
+            $writeElement
+          }
+        }
+      }
+    """
+  }
+
+  // TODO: if the nullability of map values is correct, we can use it to skip the null check.
+  private def writeMapToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      keyType: DataType,
+      valueType: DataType,
+      bufferHolder: String): String = {
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val tmpCursor = ctx.freshName("tmpCursor")
+
+    // Writes out an unsafe map according to the format described in `UnsafeMapData`.
+    s"""
+      final ArrayData $keys = $input.keyArray();
+      final ArrayData $values = $input.valueArray();
+
+      $bufferHolder.grow(8);
+
+      // Write the numElements into the first 4 bytes; the second 4 bytes are reserved
+      // for the numBytes of the key array, which is filled in below.
+      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+
+      $bufferHolder.cursor += 8;
+      // Remember the current cursor so that we can write the numBytes of the key array later.
+      final int $tmpCursor = $bufferHolder.cursor;
+
+      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
+      // Write the numBytes of the key array into the reserved second 4 bytes.
+      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+
+      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+    """
+  }
+
+  private def writePrimitiveType(
+      ctx: CodeGenContext,
+      input: String,
+      dt: DataType,
+      buffer: String,
+      offset: String) = {
+    assert(ctx.isPrimitiveType(dt))
+
+    val putMethod = s"put${ctx.primitiveTypeName(dt)}"
+
+    dt match {
+      case FloatType | DoubleType =>
+        val normalized = ctx.freshName("normalized")
+        val boxedType = ctx.boxedType(dt)
+        val handleNaN =
+          s"""
+            final ${ctx.javaType(dt)} $normalized;
+            if ($boxedType.isNaN($input)) {
+              $normalized = $boxedType.NaN;
+            } else {
+              $normalized = $input;
+            }
+          """
+
+        s"""
+          $handleNaN
+          Platform.$putMethod($buffer, $offset, $normalized);
+        """
+      case _ => s"Platform.$putMethod($buffer, $offset, $input);"
+    }
+  }
+
   def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
     val exprEvals = expressions.map(e => e.gen(ctx))
     val exprTypes = expressions.map(_.dataType)
-    createCodeForStruct(ctx, "i", exprEvals, exprTypes)
+
+    val result = ctx.freshName("result")
+    ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+    val bufferHolder = ctx.freshName("bufferHolder")
+    val holderClass = classOf[BufferHolder].getName
+    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+
+    val code =
+      s"""
+        $bufferHolder.reset();
+        ${writeExpressionsToBuffer(ctx, "i", exprEvals, exprTypes, bufferHolder)}
+        $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+      """
+    GeneratedExpressionCode(code, "false", result)
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
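
As context for the Float/Double branch of writePrimitiveType above: NaN has many bit patterns,
and normalizing them all to the canonical NaN keeps semantically equal rows byte-wise equal. A
standalone JDK-only illustration (not part of this patch; run with assertions enabled):

    class NaNBitsExample {
      public static void main(String[] args) {
        long canonical = Double.doubleToRawLongBits(Double.NaN);        // 0x7ff8000000000000L
        double otherNaN = Double.longBitsToDouble(0x7ff8000000000001L); // a different NaN payload
        assert Double.isNaN(otherNaN);
        assert Double.doubleToRawLongBits(otherNaN) != canonical;       // raw bits differ
        // After normalization both values are written as Double.NaN, so their buffer bytes match.
      }
    }
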
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 8c7220319363088017a696206b5986ab3c5a9c01..c991cd86d28c8ac0070c3017e5893b04ff8de2b3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
-import java.util.Arrays
 
 import org.scalatest.Matchers
 
@@ -43,7 +42,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.setInt(2, 2)
 
     val unsafeRow: UnsafeRow = converter.apply(row)
-    assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow.getSizeInBytes === 8 + (3 * 8))
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
@@ -62,6 +61,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRowCopy.getLong(0) === 0)
     assert(unsafeRowCopy.getLong(1) === 1)
     assert(unsafeRowCopy.getInt(2) === 2)
+
+    // Make sure the converter can be reused, i.e. that we correctly reset all of its state.
+    val unsafeRow2: UnsafeRow = converter.apply(row)
+    assert(unsafeRow2.getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow2.getLong(0) === 0)
+    assert(unsafeRow2.getLong(1) === 1)
+    assert(unsafeRow2.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -176,7 +182,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       r
     }
 
-    // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
     val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -192,7 +197,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       rowWithNoNullColumns.getDecimal(10, 10, 0))
     assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
       rowWithNoNullColumns.getDecimal(11, 38, 18))
-    // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
 
     for (i <- fieldTypes.indices) {
      // Can't call setNullAt() on DecimalType
@@ -202,8 +206,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
         setToNullAfterCreation.setNullAt(i)
       }
     }
-    // There are some garbage left in the var-length area
-    assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
 
     setToNullAfterCreation.setNullAt(0)
     setToNullAfterCreation.setBoolean(1, false)
@@ -251,107 +253,274 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
   }
 
+  test("basic conversion with struct type") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("i", IntegerType),
+      new StructType().add("nest", new StructType().add("l", LongType))
+    )
+
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(1))
+    row.update(1, InternalRow(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val row1 = unsafeRow.getStruct(0, 1)
+    assert(row1.getSizeInBytes == 8 + 1 * 8)
+    assert(row1.numFields == 1)
+    assert(row1.getInt(0) == 1)
+
+    val row2 = unsafeRow.getStruct(1, 1)
+    assert(row2.numFields() == 1)
+
+    val innerRow = row2.getStruct(0, 1)
+
+    {
+      assert(innerRow.getSizeInBytes == 8 + 1 * 8)
+      assert(innerRow.numFields == 1)
+      assert(innerRow.getLong(0) == 2L)
+    }
+
+    assert(row2.getSizeInBytes == 8 + 1 * 8 + innerRow.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes == 8 + 2 * 8 + row1.getSizeInBytes + row2.getSizeInBytes)
+  }
+
+  private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+  private def createMap(keys: Any*)(values: Any*): MapData = {
+    assert(keys.length == values.length)
+    new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
+  }
+
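+  // Size accounting inside a row: an array field carries a 4-byte numElements header, a map
+  // field an 8-byte header (numElements + numBytes of key array), and the whole field is then
+  // rounded up to a word boundary.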
+  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
+
+  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
+
+  private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
+    assert(array.numElements == values.length)
+    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    values.zipWithIndex.foreach {
+      case (value, index) => assert(array.getInt(index) == value)
+    }
+  }
+
+  private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: Seq[Int]): Unit = {
+    assert(keys.length == values.length)
+    assert(map.numElements == keys.length)
+
+    testArrayInt(map.keyArray, keys)
+    testArrayInt(map.valueArray, values)
+
+    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+  }
+
   test("basic conversion with array type") {
     val fieldTypes: Array[DataType] = Array(
-      ArrayType(LongType),
-      ArrayType(ArrayType(LongType))
+      ArrayType(IntegerType),
+      ArrayType(ArrayType(IntegerType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val array1 = new GenericArrayData(Array[Any](1L, 2L))
-    val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, array1)
-    row.update(1, array2)
+    row.update(0, createArray(1, 2))
+    row.update(1, createArray(createArray(3, 4)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(unsafeArray1.numElements() == 2)
-    assert(unsafeArray1.getLong(0) == 1L)
-    assert(unsafeArray1.getLong(1) == 2L)
+    val unsafeArray1 = unsafeRow.getArray(0)
+    testArrayInt(unsafeArray1, Seq(1, 2))
 
-    val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray2.numElements() == 1)
+    val unsafeArray2 = unsafeRow.getArray(1)
+    assert(unsafeArray2.numElements == 1)
 
-    val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(nestedArray.numElements() == 2)
-    assert(nestedArray.getLong(0) == 3L)
-    assert(nestedArray.getLong(1) == 4L)
+    val nestedArray = unsafeArray2.getArray(0)
+    testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
 
-    val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
-    val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
+    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
   test("basic conversion with map type") {
-    def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+    val fieldTypes: Array[DataType] = Array(
+      MapType(IntegerType, IntegerType),
+      MapType(IntegerType, MapType(IntegerType, IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
-    def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
-      val numElements = keys.length
-      assert(map.numElements() == numElements)
+    val map1 = createMap(1, 2)(3, 4)
 
-      val keyArray = map.keys
-      assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
-      assert(keyArray.numElements() == numElements)
-      keys.zipWithIndex.foreach { case (key, i) =>
-        assert(keyArray.getInt(i) == key)
-      }
+    val innerMap = createMap(5, 6)(7, 8)
+    val map2 = createMap(9)(innerMap)
 
-      val valueArray = map.values
-      assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
-      assert(valueArray.numElements() == numElements)
-      values.zipWithIndex.foreach { case (value, i) =>
-        assert(valueArray.getLong(i) == value)
-      }
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, map1)
+    row.update(1, map2)
 
-      assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val unsafeMap1 = unsafeRow.getMap(0)
+    testMapInt(unsafeMap1, Seq(1, 2), Seq(3, 4))
+
+    val unsafeMap2 = unsafeRow.getMap(1)
+    assert(unsafeMap2.numElements == 1)
+
+    val keyArray = unsafeMap2.keyArray
+    testArrayInt(keyArray, Seq(9))
+
+    val valueArray = unsafeMap2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val nestedMap = valueArray.getMap(0)
+      testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
+
+      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
     }
 
+    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
+    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+  }
+
+  test("basic conversion with struct and array") {
     val fieldTypes: Array[DataType] = Array(
-      MapType(IntegerType, LongType),
-      MapType(IntegerType, MapType(IntegerType, LongType))
+      new StructType().add("arr", ArrayType(IntegerType)),
+      ArrayType(new StructType().add("l", LongType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(createArray(1)))
+    row.update(1, createArray(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
+
+    val innerArray = field1.getArray(0)
+    testArrayInt(innerArray, Seq(1))
 
-    val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
-    val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+    assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+
+    val field2 = unsafeRow.getArray(1)
+    assert(field2.numElements == 1)
+
+    val innerStruct = field2.getStruct(0, 1)
+
+    {
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 2L)
+    }
+
+    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with struct and map") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("map", MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, new StructType().add("l", LongType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, map1)
-    row.update(1, map2)
+    row.update(0, InternalRow(createMap(1)(2)))
+    row.update(1, createMap(3)(InternalRow(4L)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
 
-    val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
-    assert(unsafeMap2.numElements() == 1)
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
 
-    val keyArray = unsafeMap2.keys
-    assert(keyArray.getSizeInBytes == 4 + 4)
-    assert(keyArray.numElements() == 1)
-    assert(keyArray.getInt(0) == 9)
+    assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
 
-    val valueArray = unsafeMap2.values
-    assert(valueArray.numElements() == 1)
-    val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
-    assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+    val field2 = unsafeRow.getMap(1)
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
 
-    val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
-    val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
-    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerStruct = valueArray.getStruct(0, 1)
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 4L)
+
+      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with array and map") {
+    val fieldTypes: Array[DataType] = Array(
+      ArrayType(MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, ArrayType(IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, createArray(createMap(1)(2)))
+    row.update(1, createMap(3)(createArray(4)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getArray(0)
+    assert(field1.numElements == 1)
+
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
+
+    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+
+    val field2 = unsafeRow.getMap(1)
+    assert(field2.numElements == 1)
+
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
+
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerArray = valueArray.getArray(0)
+      testArrayInt(innerArray, Seq(4))
+
+      assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
   }
 }