diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c25161ee81b6614762c34ad2fc59b7fa39ba7a76..9cbb7c2ffdc764334fc2cb7855aa7cd2b25b57d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,6 +146,10 @@ trait ScalaReflection {
    * row with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments.  Nested classes will have their fields accessed
    * using UnresolvedExtractValue.
+   *
+   * When used on a primitive type, the constructor will instead default to extracting the value
+   * from ordinal 0 (since there are no names to map to).  The actual location can be moved by
+   * calling unbind/bind with a new schema.
    */
   def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
 
@@ -159,8 +163,14 @@ trait ScalaReflection {
         .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
         .getOrElse(UnresolvedAttribute(part))
 
+    /** Returns the current path with the field at the given ordinal extracted. */
+    def addToPathOrdinal(ordinal: Int, dataType: DataType) =
+      path
+        .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+
     /** Returns the current path or throws an error. */
-    def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+    def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true))
 
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
@@ -387,12 +397,17 @@ trait ScalaReflection {
         val className: String = t.erasure.typeSymbol.asClass.fullName
         val cls = Utils.classForName(className)
 
-        val arguments = params.head.map { p =>
+        val arguments = params.head.zipWithIndex.map { case (p, i) =>
           val fieldName = p.name.toString
           val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-          val dataType = dataTypeFor(fieldType)
+          val dataType = schemaFor(fieldType).dataType
 
-          constructorFor(fieldType, Some(addToPath(fieldName)))
+          // For tuples, we grab the inner fields by ordinal instead of name.
+          if (className startsWith "scala.Tuple") {
+            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+          } else {
+            constructorFor(fieldType, Some(addToPath(fieldName)))
+          }
         }
 
         val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
@@ -413,7 +428,10 @@ trait ScalaReflection {
   /** Returns expressions for extracting all the fields from the given type. */
   def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
     ScalaReflectionLock.synchronized {
-      extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
+      extractorFor(inputObject, typeTag[T].tpe) match {
+        case s: CreateNamedStruct => s
+        case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
+      }
     }
   }
 
@@ -602,6 +620,21 @@ trait ScalaReflection {
         case t if t <:< localTypeOf[java.lang.Boolean] =>
           Invoke(inputObject, "booleanValue", BooleanType)
 
+        case t if t <:< definitions.IntTpe =>
+          BoundReference(0, IntegerType, false)
+        case t if t <:< definitions.LongTpe =>
+          BoundReference(0, LongType, false)
+        case t if t <:< definitions.DoubleTpe =>
+          BoundReference(0, DoubleType, false)
+        case t if t <:< definitions.FloatTpe =>
+          BoundReference(0, FloatType, false)
+        case t if t <:< definitions.ShortTpe =>
+          BoundReference(0, ShortType, false)
+        case t if t <:< definitions.ByteTpe =>
+          BoundReference(0, ByteType, false)
+        case t if t <:< definitions.BooleanTpe =>
+          BoundReference(0, BooleanType, false)
+
         case other =>
           throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
       }
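
With the new primitive cases in extractorFor and the BoundReference fallback in getPath, a bare primitive such as Int now round-trips through a single-slot row instead of failing with "Constructors must start at a class type". A minimal sketch of the expression shapes this produces (illustrative only; the real trees come from extractorsFor/constructorFor):

    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.types.IntegerType

    // Extraction of a bare Int: there are no fields to address by name, so the
    // extractor is simply a reference to ordinal 0 of the single-slot input row.
    val extract = BoundReference(0, IntegerType, nullable = false)

    // Construction: with no path supplied, getPath now defaults to ordinal 0 as well.
    // The ordinal can later be remapped by unbind/bind against a different schema.
    val construct = BoundReference(0, IntegerType, nullable = true)
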
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
deleted file mode 100644
index b484b8fde63692b7c73cfedb4a69144d8b4dea9d..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
-    schema: StructType,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
-    clsTag: ClassTag[T])
-  extends Encoder[T] {
-
-  @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
-  private val inputRow = new GenericMutableRow(1)
-
-  @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
-  private val dataType = ObjectType(clsTag.runtimeClass)
-
-  override def toRow(t: T): InternalRow = {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  }
-
-  override def fromRow(row: InternalRow): T = {
-    constructProjection(row).get(0, dataType).asInstanceOf[T]
-  }
-
-  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    val resolvedExpression = analyzedPlan.expressions.head.children.head
-    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
-    copy(constructExpression = boundExpression)
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(oldSchema)
-    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
-    copy(constructExpression = constructExpression transform {
-      case r: BoundReference =>
-        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
-    })
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
-    var remaining = schema
-    copy(constructExpression = constructExpression transform {
-      case u: UnresolvedAttribute =>
-        val pos = remaining.head
-        remaining = remaining.drop(1)
-        pos
-    })
-  }
-
-  protected val attrs = extractExpressions.map(_.collect {
-    case a: Attribute => s"#${a.exprId}"
-    case b: BoundReference => s"[${b.ordinal}]"
-  }.headOption.getOrElse(""))
-
-
-  protected val schemaString =
-    schema
-      .zip(attrs)
-      .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
-
-  override def toString: String = s"class[$schemaString]"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index efb872ddb81e52ba1ca35222aefb5af36979739c..329a132d3d8b241d9c2878b874c43769238adbc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.sql.catalyst.encoders
 
 
+
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType
  * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
  * and reuse internal buffers to improve performance.
  */
-trait Encoder[T] {
+trait Encoder[T] extends Serializable {
 
   /** Returns the schema of encoding this type of object as a Row. */
   def schema: StructType
 
   /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
   def clsTag: ClassTag[T]
-
-  /**
-   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
-   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
-   * copy the result before making another call if required.
-   */
-  def toRow(t: T): InternalRow
-
-  /**
-   * Returns an object of type `T`, extracting the required values from the provided row.  Note that
-   * you must `bind` an encoder to a specific schema before you can call this function.
-   */
-  def fromRow(row: InternalRow): T
-
-  /**
-   * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
-   * given schema.
-   */
-  def bind(schema: Seq[Attribute]): Encoder[T]
-
-  /**
-   * Binds this encoder to the given schema positionally.  In this binding, the first reference to
-   * any input is mapped to `schema(0)`, and so on for each input that is encountered.
-   */
-  def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
-
-  /**
-   * Given an encoder that has already been bound to a given schema, returns a new encoder that
-   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
-   * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
-   * row, but now you have projected out only the key expressions.
-   */
-  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
 }
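
After this change the public Encoder[T] contract is only a serializable schema plus a ClassTag; all row conversion and binding machinery moves to ExpressionEncoder. A minimal sketch of the remaining surface (illustrative only; operators recover the expression machinery via encoderFor, which rejects anything that is not an ExpressionEncoder):

    import scala.reflect.ClassTag
    import org.apache.spark.sql.catalyst.encoders.Encoder
    import org.apache.spark.sql.types.{LongType, StructField, StructType}

    // The entire surface a custom Encoder[Long] has to provide now.
    object LongSchemaOnly extends Encoder[Long] {
      override def schema: StructType =
        StructType(StructField("value", LongType, nullable = false) :: Nil)
      override def clsTag: ClassTag[Long] = ClassTag.Long
    }
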
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c287aebeeee059ae107ddd5c6af7b2251d00f06f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert objects and primitives to and from the
+ * internal row format using catalyst expressions and code generation.  By default, the
+ * expressions used to retrieve values from an input row when producing an object will be created as
+ * follows:
+ *  - Classes will have their subfields extracted by name using [[UnresolvedAttribute]] expressions
+ *    and [[UnresolvedExtractValue]] expressions.
+ *  - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
+ *  - Primitives will have their values extracted from ordinal 0, using a single-column schema
+ *    whose field name defaults to `value`.
+ */
+object ExpressionEncoder {
+  def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
+    val constructExpression = ScalaReflection.constructorFor[T]
+
+    new ExpressionEncoder[T](
+      extractExpression.dataType,
+      flat,
+      extractExpression.flatten,
+      constructExpression,
+      ClassTag[T](cls))
+  }
+
+  /**
+   * Given a set of N encoders, constructs a new encoder that produces objects as items in an
+   * N-tuple.  Note that these encoders should first be bound correctly to the combined input
+   * schema.
+   */
+  def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    val schema =
+      StructType(
+        encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }
+    val constructExpression =
+      NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+
+  /** A helper for producing encoders of Tuple2 from other encoders. */
+  def tuple[T1, T2](
+      e1: ExpressionEncoder[T1],
+      e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
+    tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ *                           extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ExpressionEncoder[T](
+    schema: StructType,
+    flat: Boolean,
+    extractExpressions: Seq[Expression],
+    constructExpression: Expression,
+    clsTag: ClassTag[T])
+  extends Encoder[T] {
+
+  if (flat) require(extractExpressions.size == 1)
+
+  @transient
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private val inputRow = new GenericMutableRow(1)
+
+  @transient
+  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+
+  /**
+   * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
+   * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
+   * copy the result before making another call if required.
+   */
+  def toRow(t: T): InternalRow = {
+    inputRow(0) = t
+    extractProjection(inputRow)
+  }
+
+  /**
+   * Returns an object of type `T`, extracting the required values from the provided row.  Note that
+   * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+   * function.
+   */
+  def fromRow(row: InternalRow): T = try {
+    constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+  } catch {
+    case e: Exception =>
+      throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+  }
+
+  /**
+   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+   * given schema.
+   */
+  def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+    val analyzedPlan = SimpleAnalyzer.execute(plan)
+    copy(constructExpression = analyzedPlan.expressions.head.children.head)
+  }
+
+  /**
+   * Returns a copy of this encoder where the expressions used to construct an object from an input
+   * row have been bound to the ordinals of the given schema.  Note that `resolve` must be called
+   * before `bind`.
+   */
+  def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+  }
+
+  /**
+   * Replaces any bound references in the construct expression with the attribute at the
+   * corresponding ordinal in the provided schema.  This can be used to "relocate" an encoder so
+   * that it pulls values from a different schema than it was initially bound to.  It can also be
+   * used to assign attributes to ordinal-based extraction (i.e. when the input data was a tuple).
+   */
+  def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val positionToAttribute = AttributeMap.toIndex(schema)
+    copy(constructExpression = constructExpression transform {
+      case b: BoundReference => positionToAttribute(b.ordinal)
+    })
+  }
+
+  /**
+   * Given an encoder that has already been bound to a given schema, returns a new encoder
+   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
+   * when you are trying to use an encoder on grouping keys that were originally part of a larger
+   * row, but now you have projected out only the key expressions.
+   */
+  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
+    val positionToAttribute = AttributeMap.toIndex(oldSchema)
+    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+    copy(constructExpression = constructExpression transform {
+      case r: BoundReference =>
+        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+    })
+  }
+
+  /**
+   * Returns a copy of this encoder where the expressions used to create an object given an
+   * input row have been modified to pull the object out from a nested struct, instead of the
+   * top level fields.
+   */
+  def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
+    copy(constructExpression = constructExpression transform {
+      case u: Attribute if u != input =>
+        UnresolvedExtractValue(input, Literal(u.name))
+      case b: BoundReference if b != input =>
+        GetStructField(
+          input,
+          StructField(s"i[${b.ordinal}]", b.dataType),
+          b.ordinal)
+    })
+  }
+
+  protected val attrs = extractExpressions.flatMap(_.collect {
+    case _: UnresolvedAttribute => ""
+    case a: Attribute => s"#${a.exprId}"
+    case b: BoundReference => s"[${b.ordinal}]"
+  })
+
+  protected val schemaString =
+    schema
+      .zip(attrs)
+      .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+  override def toString: String = s"class[$schemaString]"
+}
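
A minimal round-trip sketch of the new ExpressionEncoder, mirroring what ExpressionEncoderSuite does further down (the Person class and its values are hypothetical):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    case class Person(name: String, age: Int)

    val enc = ExpressionEncoder[Person]()
    val row = enc.toRow(Person("Ada", 36))      // object -> InternalRow

    // fromRow requires resolving and then binding against a concrete schema first.
    val attrs = enc.schema.toAttributes
    val back  = enc.resolve(attrs).bind(attrs).fromRow(row)   // Person("Ada", 36)
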
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 34f5e6c030f5889ecf32ebb433b5b0156aeca208..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
- * internal binary representation.
- */
-object ProductEncoder {
-  def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
-    // We convert the not-serializable TypeTag into StructType and ClassTag.
-    val mirror = typeTag[T].mirror
-    val cls = mirror.runtimeClass(typeTag[T].tpe)
-
-    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
-    val constructExpression = ScalaReflection.constructorFor[T]
-
-    new ClassEncoder[T](
-      extractExpression.dataType,
-      extractExpression.flatten,
-      constructExpression,
-      ClassTag[T](cls))
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e9cc00a2b64ce054899432b999f6f40f13916c77..0b42130a013b2575959d6169cafdeb3c06cd1523 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
  * internal binary representation.
  */
 object RowEncoder {
-  def apply(schema: StructType): ClassEncoder[Row] = {
+  def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
     val extractExpressions = extractorsFor(inputObject, schema)
     val constructExpression = constructorFor(schema)
-    new ClassEncoder[Row](
+    new ExpressionEncoder[Row](
       schema,
+      flat = false,
       extractExpressions.asInstanceOf[CreateStruct].children,
       constructExpression,
       ClassTag(cls))
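
RowEncoder keeps its entry point but now yields an ExpressionEncoder[Row] with flat = false. A minimal usage sketch (the schema and values are illustrative):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    val rowEnc   = RowEncoder(schema)              // ExpressionEncoder[Row]
    val internal = rowEnc.toRow(Row("Ada", 36))    // external Row -> InternalRow
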
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
similarity index 56%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 52f8383faca92fc43eeb0d3e1bafadbcc85a0702..d4642a500672ef82629f05a94d563be005a44ed0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -15,29 +15,12 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql.catalyst
 
-import org.apache.spark.SparkFunSuite
-
-class PrimitiveEncoderSuite extends SparkFunSuite {
-  test("long encoder") {
-    val enc = new LongEncoder()
-    val row = enc.toRow(10)
-    assert(row.getLong(0) == 10)
-    assert(enc.fromRow(row) == 10)
-  }
-
-  test("int encoder") {
-    val enc = new IntEncoder()
-    val row = enc.toRow(10)
-    assert(row.getInt(0) == 10)
-    assert(enc.fromRow(row) == 10)
-  }
-
-  test("string encoder") {
-    val enc = new StringEncoder()
-    val row = enc.toRow("test")
-    assert(row.getString(0) == "test")
-    assert(enc.fromRow(row) == "test")
+package object encoders {
+  private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
+    case e: ExpressionEncoder[A] => e
+    case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
+
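
The new package object makes encoderFor the internal bridge from the public Encoder type class back to the expression-based machinery; any other Encoder implementation fails fast. A minimal sketch of how the operators below use it (note that encoderFor is private[sql], so this only compiles inside the org.apache.spark.sql package):

    import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
    import org.apache.spark.sql.catalyst.expressions.Attribute

    // Recover the full ExpressionEncoder from a context bound in order to reach its
    // schema and expressions, as AppendColumn and MapGroups do in basicOperators.scala.
    def outputAttributes[T : Encoder]: Seq[Attribute] = encoderFor[T].schema.toAttributes
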
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
deleted file mode 100644
index a93f2d7c6115dad6363617123275827163b3bfbe..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-
-/** An encoder for primitive Long types. */
-case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
-  private val row = UnsafeRow.createFromByteArray(64, 1)
-
-  override def clsTag: ClassTag[Long] = ClassTag.Long
-  override def schema: StructType =
-    StructType(StructField(fieldName, LongType) :: Nil)
-
-  override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
-
-  override def toRow(t: Long): InternalRow = {
-    row.setLong(ordinal, t)
-    row
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
-  override def bind(schema: Seq[Attribute]): Encoder[Long] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
-}
-
-/** An encoder for primitive Integer types. */
-case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
-  private val row = UnsafeRow.createFromByteArray(64, 1)
-
-  override def clsTag: ClassTag[Int] = ClassTag.Int
-  override def schema: StructType =
-    StructType(StructField(fieldName, IntegerType) :: Nil)
-
-  override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
-
-  override def toRow(t: Int): InternalRow = {
-    row.setInt(ordinal, t)
-    row
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
-  override def bind(schema: Seq[Attribute]): Encoder[Int] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
-}
-
-/** An encoder for String types. */
-case class StringEncoder(
-    fieldName: String = "value",
-    ordinal: Int = 0) extends Encoder[String] {
-
-  val record = new SpecificMutableRow(StringType :: Nil)
-
-  @transient
-  lazy val projection =
-    GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
-
-  override def schema: StructType =
-    StructType(
-      StructField("value", StringType, nullable = false) :: Nil)
-
-  override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
-
-
-  override final def fromRow(row: InternalRow): String = {
-    row.getString(ordinal)
-  }
-
-  override final def toRow(value: String): InternalRow = {
-    val utf8String = UTF8String.fromString(value)
-    record(0) = utf8String
-    // TODO: this is a bit of a hack to produce UnsafeRows
-    projection(record)
-  }
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
-  override def bind(schema: Seq[Attribute]): Encoder[String] = this
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
deleted file mode 100644
index a48eeda7d2e6f034e6a5e006d6ec200bff31e5ff..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructField, StructType}
-
-// Most of this file is codegen.
-// scalastyle:off
-
-/**
- * A set of composite encoders that take sub encoders and map each of their objects to a
- * Scala tuple.  Note that currently the implementation is fairly limited and only supports going
- * from an internal row to a tuple.
- */
-object TupleEncoder {
-
-  /** Code generator for composite tuple encoders. */
-  def main(args: Array[String]): Unit = {
-    (2 to 5).foreach { i =>
-      val types = (1 to i).map(t => s"T$t").mkString(", ")
-      val tupleType = s"($types)"
-      val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
-      val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
-      val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
-
-      println(
-        s"""
-          |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
-          |  val schema = StructType(Array($fields))
-          |
-          |  def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
-          |
-          |  def fromRow(row: InternalRow): $tupleType = {
-          |    ($fromRow)
-          |  }
-          |
-          |  override def toRow(t: $tupleType): InternalRow =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-          |
-          |  override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
-          |    this
-          |  }
-          |
-          |  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-          |
-          |
-          |  override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
-          |    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-          |}
-        """.stripMargin)
-    }
-  }
-}
-
-class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
-
-  def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
-
-  def fromRow(row: InternalRow): (T1, T2) = {
-    (e1.fromRow(row), e2.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3, T4)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
-  val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
-
-  def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
-
-  def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
-    (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
-  }
-
-  override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
-    throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
-  override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
-    this
-  }
-
-  override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
-  override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
-    throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21a55a53718410d5d70d6fbbf90b87034628a794..d2d3db0a4448474000b750474c0ea534dacb0dfd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
 import org.apache.spark.sql.catalyst.plans._
@@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode {
  */
 case class MapPartitions[T, U](
     func: Iterator[T] => Iterator[U],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def missingInput: AttributeSet = AttributeSet.empty
@@ -460,8 +460,8 @@ case class MapPartitions[T, U](
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumn {
   def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
-    val attrs = implicitly[Encoder[U]].schema.toAttributes
-    new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+    val attrs = encoderFor[U].schema.toAttributes
+    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
   }
 }
 
@@ -472,8 +472,8 @@ object AppendColumn {
  */
 case class AppendColumn[T, U](
     func: T => U,
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output ++ newColumns
@@ -488,11 +488,11 @@ object MapGroups {
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
       func,
-      implicitly[Encoder[K]],
-      implicitly[Encoder[T]],
-      implicitly[Encoder[U]],
+      encoderFor[K],
+      encoderFor[T],
+      encoderFor[U],
       groupingAttributes,
-      implicitly[Encoder[U]].schema.toAttributes,
+      encoderFor[U].schema.toAttributes,
       child)
   }
 }
@@ -504,9 +504,9 @@ object MapGroups {
  */
 case class MapGroups[K, T, U](
     func: (K, Iterator[T]) => Iterator[U],
-    kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    kEncoder: ExpressionEncoder[K],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
similarity index 91%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 008d0bea8a941499438c0b47ce53e22a85203283..a374da4da1f081f75b61d761acc6ea74e5115bdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -47,7 +47,16 @@ case class RepeatedData(
 
 case class SpecificCollection(l: List[Int])
 
-class ProductEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends SparkFunSuite {
+
+  encodeDecodeTest(1)
+  encodeDecodeTest(1L)
+  encodeDecodeTest(1.toDouble)
+  encodeDecodeTest(1.toFloat)
+  encodeDecodeTest(true)
+  encodeDecodeTest(false)
+  encodeDecodeTest(1.toShort)
+  encodeDecodeTest(1.toByte)
 
   encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 
@@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite {
     { (l, r) => l._2.toString == r._2.toString }
 
   /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
-  protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
+  protected def encodeDecodeTest[T : TypeTag](inputData: T) =
     encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
 
   /**
    * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
    * matches the original.
    */
-  protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+  protected def encodeDecodeTestCustom[T : TypeTag](
       inputData: T)(
       c: (T, T) => Boolean) = {
-    test(s"encode/decode: $inputData") {
-      val encoder = try ProductEncoder[T] catch {
+    test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
+      val encoder = try ExpressionEncoder[T]() catch {
         case e: Exception =>
           fail(s"Exception thrown generating encoder", e)
       }
       val convertedData = encoder.toRow(inputData)
       val schema = encoder.schema.toAttributes
-      val boundEncoder = encoder.bind(schema)
+      val boundEncoder = encoder.resolve(schema).bind(schema)
       val convertedBack = try boundEncoder.fromRow(convertedData) catch {
         case e: Exception =>
           fail(
@@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite {
               |Schema: ${schema.mkString(",")}
               |${encoder.schema.treeString}
               |
-              |Construct Expressions:
-              |${boundEncoder.constructExpression.treeString}
+              |Encoder:
+              |$boundEncoder
               |
             """.stripMargin, e)
       }
 
       if (!c(inputData, convertedBack)) {
-        val types =
-          convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+        val types = convertedBack match {
+          case c: Product =>
+            c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+          case other => other.getClass.getName
+        }
+
 
         val encodedData = try {
           convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
@@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite {
              |${encoder.schema.treeString}
              |
              |Extract Expressions:
-             |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
-             |
-             |Construct Expressions:
-             |${boundEncoder.constructExpression.treeString}
-             |
+             |$boundEncoder
          """.stripMargin)
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b1d9888f062018fa11a471719c0d1cb306..aa817a037ef5e93e6088b86391095799b9e0b129 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+  def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
 
   /**
    * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7630400c2e93f77a98a1b4c13517e7be1d..e0ab5f593e933212327adc75b5e1cbfdf52d0f08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
  * @since 1.6.0
  */
 @Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
     @transient val sqlContext: SQLContext,
-    @transient val queryExecution: QueryExecution)(
-    implicit val encoder: Encoder[T]) extends Serializable {
+    @transient val queryExecution: QueryExecution,
+    unresolvedEncoder: Encoder[T]) extends Serializable {
+
+  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+  private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+    case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+    case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+  }
 
   private implicit def classTag = encoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
-    this(sqlContext, new QueryExecution(sqlContext, plan))
+    this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
 
   /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
   def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
    * TODO: document binding rules
    * @since 1.6.0
    */
-  def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+  def as[U : Encoder]: Dataset[U] = {
+    new Dataset(sqlContext, queryExecution, encoderFor[U])
+  }
 
   /**
    * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = implicitly[Encoder[T]]
+    val tEnc = encoderFor[T]
     val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
       val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
       sqlContext,
       MapPartitions[T, U](
         func,
-        implicitly[Encoder[T]],
-        implicitly[Encoder[U]],
-        implicitly[Encoder[U]].schema.toAttributes,
+        encoderFor[T],
+        encoderFor[U],
+        encoderFor[U].schema.toAttributes,
         logicalPlan))
   }
 
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
-      implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
-      implicitly[Encoder[T]].bind(inputPlan.output),
+      encoderFor[K].resolve(withGroupingKey.newColumns),
+      encoderFor[T].bind(inputPlan.output),
       executed,
       inputPlan.output,
       withGroupingKey.newColumns)
@@ -220,6 +229,18 @@ class Dataset[T] private[sql](
    *  Typed Relational  *
    * ****************** */
 
+  /**
+   * Selects a set of column based expressions.
+   * {{{
+   *   df.select($"colA", $"colB" + 1)
+   * }}}
+   * @group dfops
+   * @since 1.3.0
+   */
+  // Copied from DataFrame to make sure we don't have invalid overloads.
+  @scala.annotation.varargs
+  def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
    *
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
     new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
   }
 
-  // Codegen
-  // scalastyle:off
-
-  /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
-  private def genSelect: String = {
-    (2 to 5).map { n =>
-      val types = (1 to n).map(i =>s"U$i").mkString(", ")
-      val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
-      val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
-      val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
-      s"""
-         |/**
-         | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-         | * @since 1.6.0
-         | */
-         |def select[$types]($args): Dataset[($types)] = {
-         |  implicit val te = new Tuple${n}Encoder($encoders)
-         |  new Dataset[($types)](sqlContext,
-         |    Project(
-         |      $schema :: Nil,
-         |      logicalPlan))
-         |}
-         |
-       """.stripMargin
-    }.mkString("\n")
+  /**
+   * Internal helper function for building typed selects that return tuples.  For simplicity and
+   * code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user-facing interface.
+   */
+  protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+    val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+    val unresolvedPlan = Project(aliases, logicalPlan)
+    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+    // Rebind the encoders to the nested schema that will be produced by the select.
+    val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+      case (e: ExpressionEncoder[_], a) if !e.flat =>
+        e.nested(a.toAttribute).resolve(execution.analyzed.output)
+      case (e, a) =>
+        e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+    }
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
-    implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
-    new Dataset[(U1, U2)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+    selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
-    implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
-    new Dataset[(U1, U2, U3)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2, U3](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+    selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
-    implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
-    new Dataset[(U1, U2, U3, U4)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
-        logicalPlan))
-  }
-
-
+  def select[U1, U2, U3, U4](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3],
+      c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+    selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
 
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
    */
-  def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
-    implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
-    new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
-      Project(
-        Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
-        logicalPlan))
-  }
-
-  // scalastyle:on
+  def select[U1, U2, U3, U4, U5](
+      c1: TypedColumn[U1],
+      c2: TypedColumn[U2],
+      c3: TypedColumn[U3],
+      c4: TypedColumn[U4],
+      c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+    selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
 
   /* **************** *
    *  Set operations  *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
    */
   def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
 
+  /* ****** *
+   *  Joins *
+   * ****** */
+
+  /**
+   * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+   * true.
+   *
+   * This is similar to the relational `join` function with one important difference in the
+   * result schema. Since `joinWith` preserves objects present on either side of the join, the
+   * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+   *
+   * This type of join can be useful both for preserving type-safety with the original object
+   * types and for working with relational data where either side of the join has column
+   * names in common.
+   */
+  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+    val left = this.logicalPlan
+    val right = other.logicalPlan
+
+    val leftData = this.encoder match {
+      case e if e.flat => Alias(left.output.head, "_1")()
+      case _ => Alias(CreateStruct(left.output), "_1")()
+    }
+    val rightData = other.encoder match {
+      case e if e.flat => Alias(right.output.head, "_2")()
+      case _ => Alias(CreateStruct(right.output), "_2")()
+    }
+    val leftEncoder =
+      if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+    val rightEncoder =
+      if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+    withPlan[(T, U)](other) { (left, right) =>
+      Project(
+        leftData :: rightData :: Nil,
+        Join(left, right, Inner, Some(condition.expr)))
+    }
+  }
+
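As a usage sketch (mirroring the `joinWith, expression condition` test added below, and assuming the suite's `ClassData(a: String, b: Int)` case class plus `sqlContext.implicits._` are in scope):

    val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
    val ds2 = Seq(("a", 1), ("b", 2)).toDS()

    // Unlike a relational join, each matching pair comes back as a typed Tuple2 nested
    // under the column names _1 and _2, preserving both original object types.
    val joined: Dataset[(ClassData, (String, Int))] = ds1.joinWith(ds2, $"_1" === $"a")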
   /* ************************** *
    *  Gather to Driver Actions  *
    * ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
   private[sql] def logicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
-    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
 
   private[sql] def withPlan[R : Encoder](
       other: Dataset[_])(
       f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](
-      sqlContext,
-      sqlContext.executePlan(
-        f(logicalPlan, other.logicalPlan)))
+    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f974389d3ec2c0554892a6d82eff099472..2cb94430e6178e57e0e7d22f668eaec599d7536d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
 
 
   def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
-    val enc = implicitly[Encoder[T]]
+    val enc = encoderFor[T]
     val attributes = enc.schema.toAttributes
     val encoded = data.map(d => enc.toRow(d).copy())
     val plan = new LocalRelation(attributes, encoded)
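The switch from `implicitly[Encoder[T]]` to `encoderFor[T]` matters because `createDataset` needs the concrete `ExpressionEncoder` (for `toRow` and `schema`), not just the public `Encoder` trait. A rough sketch only, assuming the shape of `encoderFor` rather than quoting its definition from the catalyst encoders package:

    // Hypothetical sketch; the real definition lives alongside the encoders package object.
    def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
      case e: ExpressionEncoder[A] => e   // narrow to the concrete encoder implementation
      case _ => sys.error("Only ExpressionEncoder is supported here.")
    }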
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474df0de808b1ea6ffd23adf53cb70c7b3432..f460a86414c4159f3506dd0c6a6bb60db36a5b37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
 abstract class SQLImplicits {
   protected def _sqlContext: SQLContext
 
-  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
 
-  implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
-  implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
-  implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+  implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+  implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+  implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+  implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+  implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+  implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+  implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+  implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
 
   implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
     DatasetHolder(_sqlContext.createDataset(s))
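With these implicits in scope, both primitive and Product sequences convert to Datasets directly, which is what the new DatasetSuite tests rely on (assumes a `sqlContext` is available):

    import sqlContext.implicits._

    val ints: Dataset[Int] = Seq(1, 2, 3).toDS()                        // via newIntEncoder
    val pairs: Dataset[(String, Int)] = Seq(("a", 1), ("b", 2)).toDS()  // via newProductEncoder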
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba5bd2baf6f7c58d132a714590c67f32fe4..89938471ee381aea994b7491688a8728ea46b26a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
  */
 case class MapPartitions[T, U](
     func: Iterator[T] => Iterator[U],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
  */
 case class AppendColumns[T, U](
     func: T => U,
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     newColumns: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
  */
 case class MapGroups[K, T, U](
     func: (K, Iterator[T]) => Iterator[U],
-    kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
-    uEncoder: Encoder[U],
+    kEncoder: ExpressionEncoder[K],
+    tEncoder: ExpressionEncoder[T],
+    uEncoder: ExpressionEncoder[U],
     groupingAttributes: Seq[Attribute],
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
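These physical operators now take `ExpressionEncoder` rather than the public `Encoder` trait so they can convert between rows and objects at execution time. Illustrative only, with method names (`bind`, `fromRow`) assumed from this era of the API rather than quoted from the patch:

    // Hypothetical sketch of how MapPartitions might use its encoders inside doExecute.
    child.execute().mapPartitions { iter =>
      val boundT = tEncoder.bind(child.output)   // resolve input ordinals against the child schema
      func(iter.map(boundT.fromRow))             // decode InternalRow -> T and apply the user func
        .map(uEncoder.toRow)                     // re-encode each U back into an InternalRow
    }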
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c60cc0e55c07de1926e74caf796c8a50..aebb390a1d15de736e5903ba619929d619a6efa6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       data: _*)
   }
 
+  test("as tuple") {
+    val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+    checkAnswer(
+      data.as[(String, Int)],
+      ("a", 1), ("b", 2))
+  }
+
   test("as case class / collect") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
     checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
-  test("select 3") {
+  test("select 2") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
       ds.select(
         expr("_1").as[String],
-        expr("_2").as[Int],
-        expr("_2 + 1").as[Int]),
-      ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+        expr("_2").as[Int]) : Dataset[(String, Int)],
+      ("a", 1), ("b", 2), ("c", 3))
+  }
+
+  test("select 2, primitive and tuple") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("struct(_2, _2)").as[(Int, Int)]),
+      ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+  }
+
+  test("select 2, primitive and class") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+  }
+
+  test("select 2, primitive and class, fields reordered") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkDecoding(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
   }
 
   test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
   }
 
+  test("joinWith, flat schema") {
+    val ds1 = Seq(1, 2, 3).toDS().as("a")
+    val ds2 = Seq(1, 2).toDS().as("b")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a.value" === $"b.value"),
+      (1, 1), (2, 2))
+  }
+
+  test("joinWith, expression condition") {
+    val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"_1" === $"a"),
+      (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+  }
+
+  test("joinWith tuple with primitive, expression") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"_2"),
+      (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+  }
+
+  test("joinWith class with primitive, toDF") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+      Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+  }
+
+  test("multi-level joinWith") {
+    val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+    val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+      ((("a", 1), ("a", 1)), ("a", 1)),
+      ((("b", 2), ("b", 2)), ("b", 2)))
+  }
+
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupBy(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512fe328d45dd1f721e20846c8eaf91a94..73e02eb0d957481479db119f5042fcd0388eb413 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
 import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
 
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
 
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
     }
   }
 
-  protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer.
+   *  - Special handling is done based on whether the query plan is expected to return
+   *    the results in sorted order.
+   *  - This function also checks to make sure that the schema for serializing the expected answer
+   *    matches that produced by the dataset (i.e. does manually constructing an object match
+   *    the constructed encoder for cases like joins, etc).  Note that this means it will fail
+   *    for cases where fields are reordered.  For such tests, use `checkDecoding` instead,
+   *    which performs a subset of the checks done by this function.
+   */
+  protected def checkAnswer[T : Encoder](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
       sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+    checkDecoding(ds, expectedAnswer: _*)
+  }
+
+  protected def checkDecoding[T](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
+    val decoded = try ds.collect().toSet catch {
+      case e: Exception =>
+        fail(
+          s"""
+             |Exception collecting dataset as objects
+             |${ds.encoder}
+             |${ds.encoder.constructExpression.treeString}
+             |${ds.queryExecution}
+           """.stripMargin, e)
+    }
+
+    if (decoded != expectedAnswer.toSet) {
+      fail(
+        s"""Decoded objects do not match expected objects:
+           |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+           |Actual:   ${decoded.toSeq.map((a: Any) => a.toString).sorted}
+           |${ds.encoder.constructExpression.treeString}
+         """.stripMargin)
+    }
   }
 
   /**