From d380f324c6d38ffacfda83a525a1a7e23347e5b8 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Mon, 16 Feb 2015 20:42:57 -0800
Subject: [PATCH] [SPARK-5853][SQL] Schema support in Row.

Author: Reynold Xin <rxin@databricks.com>

Closes #4640 from rxin/SPARK-5853 and squashes the following commits:

9c6f569 [Reynold Xin] [SPARK-5853][SQL] Schema support in Row.
---
 sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 7 ++++++-
 .../org/apache/spark/sql/catalyst/ScalaReflection.scala    | 6 +++---
 .../org/apache/spark/sql/catalyst/expressions/rows.scala   | 6 +++++-
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala   | 6 ++++++
 4 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 3a70d25534..d794f034f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import scala.util.hashing.MurmurHash3
 
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.DateUtils
+import org.apache.spark.sql.types.{StructType, DateUtils}
 
 object Row {
   /**
@@ -122,6 +122,11 @@ trait Row extends Serializable {
   /** Number of elements in the Row. */
   def length: Int
 
+  /**
+   * Schema for the row.
+   */
+  def schema: StructType = null
+
   /**
    * Returns the value at position i. If the value is null, null is returned. The following
    * is a mapping between Spark SQL types and return types:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 11fd443733..d6126c24fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import java.sql.Timestamp
 
 import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 
@@ -91,9 +91,9 @@ trait ScalaReflection {
 
   def convertRowToScala(r: Row, schema: StructType): Row = {
     // TODO: This is very slow!!!
-    new GenericRow(
+    new GenericRowWithSchema(
       r.toSeq.zip(schema.fields.map(_.dataType))
-        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
+        .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema)
   }
 
   /** Returns a Sequence of attributes for the given case class type. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 73ec7a6d11..faa3667718 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.{StructType, NativeType}
 
 
 /**
@@ -149,6 +149,10 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
   def copy() = this
 }
 
+class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
+  extends GenericRow(values) {
+}
+
 class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
   /** No-arg constructor for serialization. */
   def this() = this(null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 524571d9cc..0da619def1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -89,6 +89,12 @@ class DataFrameSuite extends QueryTest {
       testData.collect().toSeq)
   }
 
+  test("head and take") {
+    assert(testData.take(2) === testData.collect().take(2))
+    assert(testData.head(2) === testData.collect().take(2))
+    assert(testData.head(2).head.schema === testData.schema)
+  }
+
   test("self join") {
     val df1 = testData.select(testData("key")).as('df1)
     val df2 = testData.select(testData("key")).as('df2)
-- 
GitLab