diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index c119758d68b36d0a94eda600c5575c12e80d9011..a0bf8734b6545ef05ef1fdb66969712240d57d0a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -210,104 +210,6 @@ public abstract class ColumnVector {
     }
   }
 
-  /**
-   * Holder object to return a struct. This object is intended to be reused.
-   */
-  public static final class Struct extends InternalRow {
-    // The fields that make up this struct. For example, if the struct had 2 int fields, the access
-    // to it would be:
-    //   int f1 = fields[0].getInt[rowId]
-    //   int f2 = fields[1].getInt[rowId]
-    public final ColumnVector[] fields;
-
-    @Override
-    public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); }
-
-    @Override
-    public boolean getBoolean(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); }
-
-    @Override
-    public short getShort(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); }
-    public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); }
-
-    @Override
-    public float getFloat(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); }
-
-    @Override
-    public Decimal getDecimal(int ordinal, int precision, int scale) {
-      throw new NotImplementedException();
-    }
-
-    @Override
-    public UTF8String getUTF8String(int ordinal) {
-      Array a = getByteArray(ordinal);
-      return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
-    }
-
-    @Override
-    public byte[] getBinary(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    @Override
-    public CalendarInterval getInterval(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    @Override
-    public InternalRow getStruct(int ordinal, int numFields) {
-      return fields[ordinal].getStruct(rowId);
-    }
-
-    public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); }
-
-    @Override
-    public MapData getMap(int ordinal) {
-      throw new NotImplementedException();
-    }
-
-    @Override
-    public Object get(int ordinal, DataType dataType) {
-      throw new NotImplementedException();
-    }
-
-    public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); }
-    public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); }
-
-    @Override
-    public final int numFields() {
-      return fields.length;
-    }
-
-    @Override
-    public InternalRow copy() {
-      throw new NotImplementedException();
-    }
-
-    @Override
-    public boolean anyNull() {
-      throw new NotImplementedException();
-    }
-
-    protected int rowId;
-
-    protected Struct(ColumnVector[] fields) {
-      this.fields = fields;
-    }
-  }
-
   /**
    * Returns the data type of this column.
    */
@@ -494,7 +396,7 @@ public abstract class ColumnVector {
   /**
    * Returns a utility object to get structs.
    */
-  public Struct getStruct(int rowId) {
+  public ColumnarBatch.Row getStruct(int rowId) {
     resultStruct.rowId = rowId;
     return resultStruct;
   }
@@ -749,7 +651,7 @@ public abstract class ColumnVector {
   /**
    * Reusable Struct holder for getStruct().
    */
-  protected final Struct resultStruct;
+  protected final ColumnarBatch.Row resultStruct;
 
   /**
    * Sets up the common state and also handles creating the child columns if this is a nested
@@ -779,7 +681,7 @@ public abstract class ColumnVector {
         this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode);
       }
       this.resultArray = null;
-      this.resultStruct = new Struct(this.childColumns);
+      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
     } else {
       this.childColumns = null;
       this.resultArray = null;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index d558dae50c227b21d7cfac5c32c93e388b1b4e2f..5a575811fa8963d49cfc3436da976bf5a5d85685 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -86,13 +86,23 @@ public final class ColumnarBatch {
    * performance is lost with this translation.
    */
   public static final class Row extends InternalRow {
-    private int rowId;
+    protected int rowId;
     private final ColumnarBatch parent;
     private final int fixedLenRowSize;
+    private final ColumnVector[] columns;
 
+    // Ctor used if this is a top level row.
     private Row(ColumnarBatch parent) {
       this.parent = parent;
       this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
+      this.columns = parent.columns;
+    }
+
+    // Ctor used if this is a struct.
+    protected Row(ColumnVector[] columns) {
+      this.parent = null;
+      this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length);
+      this.columns = columns;
     }
 
     /**
@@ -103,23 +113,23 @@ public final class ColumnarBatch {
       parent.markFiltered(rowId);
     }
 
+    public ColumnVector[] columns() { return columns; }
+
     @Override
-    public final int numFields() {
-      return parent.numCols();
-    }
+    public final int numFields() { return columns.length; }
 
     @Override
     /**
      * Revisit this. This is expensive.
      */
     public final InternalRow copy() {
-      UnsafeRow row = new UnsafeRow(parent.numCols());
+      UnsafeRow row = new UnsafeRow(numFields());
       row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize);
-      for (int i = 0; i < parent.numCols(); i++) {
+      for (int i = 0; i < numFields(); i++) {
         if (isNullAt(i)) {
           row.setNullAt(i);
         } else {
-          DataType dt = parent.schema.fields()[i].dataType();
+          DataType dt = columns[i].dataType();
           if (dt instanceof IntegerType) {
             row.setInt(i, getInt(i));
           } else if (dt instanceof LongType) {
@@ -141,7 +151,7 @@ public final class ColumnarBatch {
 
     @Override
     public final boolean isNullAt(int ordinal) {
-      return parent.column(ordinal).getIsNull(rowId);
+      return columns[ordinal].getIsNull(rowId);
     }
 
     @Override
@@ -150,7 +160,7 @@ public final class ColumnarBatch {
     }
 
     @Override
-    public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); }
+    public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
 
     @Override
     public final short getShort(int ordinal) {
@@ -159,11 +169,11 @@ public final class ColumnarBatch {
 
     @Override
     public final int getInt(int ordinal) {
-      return parent.column(ordinal).getInt(rowId);
+      return columns[ordinal].getInt(rowId);
     }
 
     @Override
-    public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); }
+    public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
 
     @Override
     public final float getFloat(int ordinal) {
@@ -172,7 +182,7 @@ public final class ColumnarBatch {
 
     @Override
     public final double getDouble(int ordinal) {
-      return parent.column(ordinal).getDouble(rowId);
+      return columns[ordinal].getDouble(rowId);
     }
 
     @Override
@@ -182,7 +192,7 @@ public final class ColumnarBatch {
 
     @Override
     public final UTF8String getUTF8String(int ordinal) {
-      ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId);
+      ColumnVector.Array a = columns[ordinal].getByteArray(rowId);
       return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
     }
 
@@ -198,12 +208,12 @@ public final class ColumnarBatch {
 
     @Override
     public final InternalRow getStruct(int ordinal, int numFields) {
-      return parent.column(ordinal).getStruct(rowId);
+      return columns[ordinal].getStruct(rowId);
     }
 
     @Override
     public final ArrayData getArray(int ordinal) {
-      return parent.column(ordinal).getArray(rowId);
+      return columns[ordinal].getArray(rowId);
     }
 
     @Override
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 215ca9ab6b77021b10118dff1fda284db15c441f..67cc08b6fc8ba73c277a159f29568cd106afaae5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -439,10 +439,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
       c2.putDouble(1, 5.67)
 
       val s = column.getStruct(0)
-      assert(s.fields(0).getInt(0) == 123)
-      assert(s.fields(0).getInt(1) == 456)
-      assert(s.fields(1).getDouble(0) == 3.45)
-      assert(s.fields(1).getDouble(1) == 5.67)
+      assert(s.columns()(0).getInt(0) == 123)
+      assert(s.columns()(0).getInt(1) == 456)
+      assert(s.columns()(1).getDouble(0) == 3.45)
+      assert(s.columns()(1).getDouble(1) == 5.67)
 
       assert(s.getInt(0) == 123)
       assert(s.getDouble(1) == 3.45)