diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 03aa25eda807fa9df19434e19e574d0b450ace84..c40061ae0aafd6a87ec372584f43ee2093683f5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -97,6 +97,24 @@ object Encoders {
    */
   def STRING: Encoder[java.lang.String] = ExpressionEncoder()
 
+  /**
+   * Creates an encoder for Java Bean of type T.
+   *
+   * T must be publicly accessible.
+   *
+   * supported types for java bean field:
+   *  - primitive types: boolean, int, double, etc.
+   *  - boxed types: Boolean, Integer, Double, etc.
+   *  - String
+   *  - java.math.BigDecimal
+   *  - time related: java.sql.Date, java.sql.Timestamp
+   *  - collection types: only array and java.util.List currently, map support is in progress
+   *  - nested java bean.
+   *
+   * @since 1.6.0
+   */
+  def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
+
   /**
    * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
    * This encoder maps T into a single byte array (binary) field.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7d4cfbe6faecba06f371e9895afca86a5f74ef98..c8ee87e8819f2cfb5184b9e5e5ca4a905fe320a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -17,14 +17,20 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.beans.Introspector
+import java.beans.{PropertyDescriptor, Introspector}
 import java.lang.{Iterable => JIterable}
-import java.util.{Iterator => JIterator, Map => JMap}
+import java.util.{Iterator => JIterator, Map => JMap, List => JList}
 
 import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
+
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
+import org.apache.spark.unsafe.types.UTF8String
+
 
 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -33,13 +39,14 @@ object JavaTypeInference {
 
   private val iterableType = TypeToken.of(classOf[JIterable[_]])
   private val mapType = TypeToken.of(classOf[JMap[_, _]])
+  private val listType = TypeToken.of(classOf[JList[_]])
   private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
   private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
   private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
   private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
 
   /**
-   * Infers the corresponding SQL data type of a JavaClean class.
+   * Infers the corresponding SQL data type of a JavaBean class.
    * @param beanClass Java type
    * @return (SQL data type, nullable)
    */
@@ -58,6 +65,8 @@ object JavaTypeInference {
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
 
       case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+      case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
+
       case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
       case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
       case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
@@ -87,15 +96,14 @@ object JavaTypeInference {
         (ArrayType(dataType, nullable), true)
 
       case _ if mapType.isAssignableFrom(typeToken) =>
-        val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
-        val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
-        val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
-        val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
+        val (keyType, valueType) = mapKeyValueType(typeToken)
         val (keyDataType, _) = inferDataType(keyType)
         val (valueDataType, nullable) = inferDataType(valueType)
         (MapType(keyDataType, valueDataType, nullable), true)
 
       case _ =>
+        // TODO: we should only collect properties that have getter and setter. However, some tests
+        // pass in scala case class as java bean class which doesn't have getter and setter.
         val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
         val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
         val fields = properties.map { property =>
@@ -107,11 +115,294 @@ object JavaTypeInference {
     }
   }
 
+  private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+    val beanInfo = Introspector.getBeanInfo(beanClass)
+    beanInfo.getPropertyDescriptors
+      .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+  }
+
   private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
     val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
-    val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
-    val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
-    val itemType = iteratorType.resolveType(nextReturnType)
-    itemType
+    val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
+    val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
+    iteratorType.resolveType(nextReturnType)
+  }
+
+  private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
+    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
+    val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
+    val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
+    val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
+    keyType -> valueType
+  }
+
+  /**
+   * Returns the Spark SQL DataType for a given java class.  Where this is not an exact mapping
+   * to a native type, an ObjectType is returned.
+   *
+   * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
+   * system.  As a result, ObjectType will be returned for things like boxed Integers.
+   */
+  private def inferExternalType(cls: Class[_]): DataType = cls match {
+    case c if c == java.lang.Boolean.TYPE => BooleanType
+    case c if c == java.lang.Byte.TYPE => ByteType
+    case c if c == java.lang.Short.TYPE => ShortType
+    case c if c == java.lang.Integer.TYPE => IntegerType
+    case c if c == java.lang.Long.TYPE => LongType
+    case c if c == java.lang.Float.TYPE => FloatType
+    case c if c == java.lang.Double.TYPE => DoubleType
+    case c if c == classOf[Array[Byte]] => BinaryType
+    case _ => ObjectType(cls)
+  }
+
+  /**
+   * Returns an expression that can be used to construct an object of java bean `T` given an input
+   * row with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
+   * of the same name as the constructor arguments.  Nested classes will have their fields accessed
+   * using UnresolvedExtractValue.
+   */
+  def constructorFor(beanClass: Class[_]): Expression = {
+    constructorFor(TypeToken.of(beanClass), None)
+  }
+
+  private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+    /** Returns the current path with a sub-field extracted. */
+    def addToPath(part: String): Expression = path
+      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+      .getOrElse(UnresolvedAttribute(part))
+
+    /** Returns the current path or `BoundReference`. */
+    def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
+
+    typeToken.getRawType match {
+      case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
+
+      case c if c == classOf[java.lang.Short] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Integer] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Long] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Double] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Byte] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Float] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+      case c if c == classOf[java.lang.Boolean] =>
+        NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+
+      case c if c == classOf[java.sql.Date] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(c),
+          "toJavaDate",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case c if c == classOf[java.sql.Timestamp] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(c),
+          "toJavaTimestamp",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case c if c == classOf[java.lang.String] =>
+        Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+      case c if c == classOf[java.math.BigDecimal] =>
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+      case c if c.isArray =>
+        val elementType = c.getComponentType
+        val primitiveMethod = elementType match {
+          case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
+          case c if c == java.lang.Byte.TYPE => Some("toByteArray")
+          case c if c == java.lang.Short.TYPE => Some("toShortArray")
+          case c if c == java.lang.Integer.TYPE => Some("toIntArray")
+          case c if c == java.lang.Long.TYPE => Some("toLongArray")
+          case c if c == java.lang.Float.TYPE => Some("toFloatArray")
+          case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
+          case _ => None
+        }
+
+        primitiveMethod.map { method =>
+          Invoke(getPath, method, ObjectType(c))
+        }.getOrElse {
+          Invoke(
+            MapObjects(
+              p => constructorFor(typeToken.getComponentType, Some(p)),
+              getPath,
+              inferDataType(elementType)._1),
+            "array",
+            ObjectType(c))
+        }
+
+      case c if listType.isAssignableFrom(typeToken) =>
+        val et = elementType(typeToken)
+        val array =
+          Invoke(
+            MapObjects(
+              p => constructorFor(et, Some(p)),
+              getPath,
+              inferDataType(et)._1),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
+
+      case _ if mapType.isAssignableFrom(typeToken) =>
+        val (keyType, valueType) = mapKeyValueType(typeToken)
+        val keyDataType = inferDataType(keyType)._1
+        val valueDataType = inferDataType(valueType)._1
+
+        val keyData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(keyType, Some(p)),
+              Invoke(getPath, "keyArray", ArrayType(keyDataType)),
+              keyDataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        val valueData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(valueType, Some(p)),
+              Invoke(getPath, "valueArray", ArrayType(valueDataType)),
+              valueDataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(
+          ArrayBasedMapData,
+          ObjectType(classOf[JMap[_, _]]),
+          "toJavaMap",
+          keyData :: valueData :: Nil)
+
+      case other =>
+        val properties = getJavaBeanProperties(other)
+        assert(properties.length > 0)
+
+        val setters = properties.map { p =>
+          val fieldName = p.getName
+          val fieldType = typeToken.method(p.getReadMethod).getReturnType
+          p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+        }.toMap
+
+        val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
+        val result = InitializeJavaBean(newInstance, setters)
+
+        if (path.nonEmpty) {
+          expressions.If(
+            IsNull(getPath),
+            expressions.Literal.create(null, ObjectType(other)),
+            result
+          )
+        } else {
+          result
+        }
+    }
+  }
+
+  /**
+   * Returns expressions for extracting all the fields from the given type.
+   */
+  def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+    val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
+    extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+  }
+
+  private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+
+    def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
+      val (dataType, nullable) = inferDataType(elementType)
+      if (ScalaReflection.isNativeType(dataType)) {
+        NewInstance(
+          classOf[GenericArrayData],
+          input :: Nil,
+          dataType = ArrayType(dataType, nullable))
+      } else {
+        MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+      }
+    }
+
+    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+      inputObject
+    } else {
+      typeToken.getRawType match {
+        case c if c == classOf[String] =>
+          StaticInvoke(
+            classOf[UTF8String],
+            StringType,
+            "fromString",
+            inputObject :: Nil)
+
+        case c if c == classOf[java.sql.Timestamp] =>
+          StaticInvoke(
+            DateTimeUtils,
+            TimestampType,
+            "fromJavaTimestamp",
+            inputObject :: Nil)
+
+        case c if c == classOf[java.sql.Date] =>
+          StaticInvoke(
+            DateTimeUtils,
+            DateType,
+            "fromJavaDate",
+            inputObject :: Nil)
+
+        case c if c == classOf[java.math.BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case c if c == classOf[java.lang.Boolean] =>
+          Invoke(inputObject, "booleanValue", BooleanType)
+        case c if c == classOf[java.lang.Byte] =>
+          Invoke(inputObject, "byteValue", ByteType)
+        case c if c == classOf[java.lang.Short] =>
+          Invoke(inputObject, "shortValue", ShortType)
+        case c if c == classOf[java.lang.Integer] =>
+          Invoke(inputObject, "intValue", IntegerType)
+        case c if c == classOf[java.lang.Long] =>
+          Invoke(inputObject, "longValue", LongType)
+        case c if c == classOf[java.lang.Float] =>
+          Invoke(inputObject, "floatValue", FloatType)
+        case c if c == classOf[java.lang.Double] =>
+          Invoke(inputObject, "doubleValue", DoubleType)
+
+        case _ if typeToken.isArray =>
+          toCatalystArray(inputObject, typeToken.getComponentType)
+
+        case _ if listType.isAssignableFrom(typeToken) =>
+          toCatalystArray(inputObject, elementType(typeToken))
+
+        case _ if mapType.isAssignableFrom(typeToken) =>
+          // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
+          // not guarantee they have same iteration order(which is different from scala map).
+          // A possible solution is creating a new `MapObjects` that can iterate a map directly.
+          throw new UnsupportedOperationException("map type is not supported currently")
+
+        case other =>
+          val properties = getJavaBeanProperties(other)
+          if (properties.length > 0) {
+            CreateNamedStruct(properties.flatMap { p =>
+              val fieldName = p.getName
+              val fieldType = typeToken.method(p.getReadMethod).getReturnType
+              val fieldValue = Invoke(
+                inputObject,
+                p.getReadMethod.getName,
+                inferExternalType(fieldType.getRawType))
+              expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+            })
+          } else {
+            throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
+          }
+      }
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 06ffe864552fd45aecef6b3c6f788f9359262535..3e8420ecb9ccf64ce8cf3f0665cd31965f2cdcbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
 import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
 
 /**
@@ -68,6 +67,22 @@ object ExpressionEncoder {
       ClassTag[T](cls))
   }
 
+  // TODO: improve error message for java bean encoder.
+  def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
+    val schema = JavaTypeInference.inferDataType(beanClass)._1
+    assert(schema.isInstanceOf[StructType])
+
+    val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
+    val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+
+    new ExpressionEncoder[T](
+      schema.asInstanceOf[StructType],
+      flat = false,
+      toRowExpression.flatten,
+      fromRowExpression,
+      ClassTag[T](beanClass))
+  }
+
   /**
    * Given a set of N encoders, constructs a new encoder that produce objects as items in an
    * N-tuple.  Note that these encoders should be unresolved so that information about
@@ -216,7 +231,7 @@ case class ExpressionEncoder[T](
    */
   def assertUnresolved(): Unit = {
     (fromRowExpression +:  toRowExpressions).foreach(_.foreach {
-      case a: AttributeReference =>
+      case a: AttributeReference if a.name != "loopVar" =>
         sys.error(s"Unresolved encoder expected, but $a was found.")
       case _ =>
     })
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 62d09f0f55105f65b4405a7fdd48e3acfe8cb28c..e6ab9a31be59edee7774633426d5a9123ab7917c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
  * as an ArrayType.  This is similar to a typical map operation, but where the lambda function
  * is expressed using catalyst expressions.
  *
- * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
+ * The following collection ObjectTypes are currently supported:
+ *   Seq, Array, ArrayData, java.util.List
  *
  * @param function A function that returns an expression, given an attribute that can be used
  *                 to access the current value.  This is does as a lambda function so that
@@ -386,6 +387,8 @@ case class MapObjects(
       (".size()", (i: String) => s".apply($i)", false)
     case ObjectType(cls) if cls.isArray =>
       (".length", (i: String) => s"[$i]", false)
+    case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+      (".size()", (i: String) => s".get($i)", false)
     case ArrayType(t, _) =>
       val (sqlType, primitiveElement) = t match {
         case m: MapType => (m, false)
@@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
 
   override def dataType: DataType = ObjectType(tag.runtimeClass)
 }
+
+/**
+ * Initialize a Java Bean instance by setting its field values via setters.
+ */
+case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
+  extends Expression {
+
+  override def nullable: Boolean = beanInstance.nullable
+  override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
+  override def dataType: DataType = beanInstance.dataType
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val instanceGen = beanInstance.gen(ctx)
+
+    val initialize = setters.map {
+      case (setterMethod, fieldValue) =>
+        val fieldGen = fieldValue.gen(ctx)
+        s"""
+           ${fieldGen.code}
+           ${instanceGen.value}.$setterMethod(${fieldGen.value});
+         """
+    }
+
+    ev.isNull = instanceGen.isNull
+    ev.value = instanceGen.value
+
+    s"""
+      ${instanceGen.code}
+      if (!${instanceGen.isNull}) {
+        ${initialize.mkString("\n")}
+      }
+     """
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 35f087baccdeebe6d677c510274bf84e82ec5b1e..f1cea07976a3765d086bb6fd729cf5bb1a6886b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import scala.collection.Map
+
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.types.{StructType, DataType}
 
@@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         case nonChild: AnyRef => nonChild
         case null => null
       }
+      case m: Map[_, _] => m.mapValues {
+        case arg: TreeNode[_] if containsChild(arg) =>
+          val newChild = remainingNewChildren.remove(0)
+          val oldChild = remainingOldChildren.remove(0)
+          if (newChild fastEquals oldChild) {
+            oldChild
+          } else {
+            changed = true
+            newChild
+          }
+        case nonChild: AnyRef => nonChild
+        case null => null
+      }.view.force // `mapValues` is lazy and we need to force it to materialize
       case arg: TreeNode[_] if containsChild(arg) =>
         val newChild = remainingNewChildren.remove(0)
         val oldChild = remainingOldChildren.remove(0)
@@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         } else {
           Some(arg)
         }
-      case m: Map[_, _] => m
+      case m: Map[_, _] => m.mapValues {
+        case arg: TreeNode[_] if containsChild(arg) =>
+          val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+          if (!(newChild fastEquals arg)) {
+            changed = true
+            newChild
+          } else {
+            arg
+          }
+        case other => other
+      }.view.force // `mapValues` is lazy and we need to force it to materialize
       case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if containsChild(arg) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index 70b028d2b3f7c356480dd68f8806f9abfaa0d95b..d85b72ed83def8d89bd5ef91afe8345a7be26dc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -70,4 +70,9 @@ object ArrayBasedMapData {
   def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
     keys.zip(values).toMap
   }
+
+  def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
+    import scala.collection.JavaConverters._
+    keys.zip(values).toMap.asJava
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 96588bb5dc1bc033cae775bdba4416514806898e..2b8cdc1e23ab374dc2fec25f8dd963453076acdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.{DataType, Decimal}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 class GenericArrayData(val array: Array[Any]) extends ArrayData {
 
   def this(seq: Seq[Any]) = this(seq.toArray)
+  def this(list: java.util.List[Any]) = this(list.asScala)
 
   // TODO: This is boxing.  We should specialize.
   def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 8fff39906b3426fd71aa90627aa135aeef15672d..965bdb1515e5580cc18ab2726ae9105e392e4815 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
   override def output: Seq[Attribute] = Nil
 }
 
+case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
+  override def children: Seq[Expression] = map.values.toSeq
+  override def nullable: Boolean = true
+  override def dataType: NullType = NullType
+  override lazy val resolved = true
+}
+
 class TreeNodeSuite extends SparkFunSuite {
   test("top node changed") {
     val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite {
     val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
     assert(expected === actual)
   }
+
+  test("expressions inside a map") {
+    val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2)))
+
+    {
+      val actual = expression.transform {
+        case Literal(i: Int, _) => Literal(i + 1)
+      }
+      val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
+      assert(actual === expected)
+    }
+
+    {
+      val actual = expression.withNewChildren(Seq(Literal(2), Literal(3)))
+      val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
+      assert(actual === expected)
+    }
+  }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 67a3190cb7d4f31504ba9f5f7b222e346f9de2be..ae47f4fe0e231f779ef5a145782d119d72f99139 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -31,14 +31,15 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.Encoders;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.*;
 import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.catalyst.encoders.OuterScopes;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
+import static org.apache.spark.sql.types.DataTypes.*;
 
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
@@ -506,4 +507,169 @@ public class JavaDatasetSuite implements Serializable {
   public void testKryoEncoderErrorMessageForPrivateClass() {
     Encoders.kryo(PrivateClassTest.class);
   }
+
+  public class SimpleJavaBean implements Serializable {
+    private boolean a;
+    private int b;
+    private byte[] c;
+    private String[] d;
+    private List<String> e;
+    private List<Long> f;
+
+    public boolean isA() {
+      return a;
+    }
+
+    public void setA(boolean a) {
+      this.a = a;
+    }
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public byte[] getC() {
+      return c;
+    }
+
+    public void setC(byte[] c) {
+      this.c = c;
+    }
+
+    public String[] getD() {
+      return d;
+    }
+
+    public void setD(String[] d) {
+      this.d = d;
+    }
+
+    public List<String> getE() {
+      return e;
+    }
+
+    public void setE(List<String> e) {
+      this.e = e;
+    }
+
+    public List<Long> getF() {
+      return f;
+    }
+
+    public void setF(List<Long> f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      SimpleJavaBean that = (SimpleJavaBean) o;
+
+      if (a != that.a) return false;
+      if (b != that.b) return false;
+      if (!Arrays.equals(c, that.c)) return false;
+      if (!Arrays.equals(d, that.d)) return false;
+      if (!e.equals(that.e)) return false;
+      return f.equals(that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      int result = (a ? 1 : 0);
+      result = 31 * result + b;
+      result = 31 * result + Arrays.hashCode(c);
+      result = 31 * result + Arrays.hashCode(d);
+      result = 31 * result + e.hashCode();
+      result = 31 * result + f.hashCode();
+      return result;
+    }
+  }
+
+  public class NestedJavaBean implements Serializable {
+    private SimpleJavaBean a;
+
+    public SimpleJavaBean getA() {
+      return a;
+    }
+
+    public void setA(SimpleJavaBean a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      NestedJavaBean that = (NestedJavaBean) o;
+
+      return a.equals(that.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return a.hashCode();
+    }
+  }
+
+  @Test
+  public void testJavaBeanEncoder() {
+    OuterScopes.addOuterScope(this);
+    SimpleJavaBean obj1 = new SimpleJavaBean();
+    obj1.setA(true);
+    obj1.setB(3);
+    obj1.setC(new byte[]{1, 2});
+    obj1.setD(new String[]{"hello", null});
+    obj1.setE(Arrays.asList("a", "b"));
+    obj1.setF(Arrays.asList(100L, null, 200L));
+    SimpleJavaBean obj2 = new SimpleJavaBean();
+    obj2.setA(false);
+    obj2.setB(30);
+    obj2.setC(new byte[]{3, 4});
+    obj2.setD(new String[]{null, "world"});
+    obj2.setE(Arrays.asList("x", "y"));
+    obj2.setF(Arrays.asList(300L, null, 400L));
+
+    List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
+    Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
+    Assert.assertEquals(data, ds.collectAsList());
+
+    NestedJavaBean obj3 = new NestedJavaBean();
+    obj3.setA(obj1);
+
+    List<NestedJavaBean> data2 = Arrays.asList(obj3);
+    Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
+    Assert.assertEquals(data2, ds2.collectAsList());
+
+    Row row1 = new GenericRow(new Object[]{
+      true,
+      3,
+      new byte[]{1, 2},
+      new String[]{"hello", null},
+      Arrays.asList("a", "b"),
+      Arrays.asList(100L, null, 200L)});
+    Row row2 = new GenericRow(new Object[]{
+      false,
+      30,
+      new byte[]{3, 4},
+      new String[]{null, "world"},
+      Arrays.asList("x", "y"),
+      Arrays.asList(300L, null, 400L)});
+    StructType schema = new StructType()
+      .add("a", BooleanType, false)
+      .add("b", IntegerType, false)
+      .add("c", BinaryType)
+      .add("d", createArrayType(StringType))
+      .add("e", createArrayType(StringType))
+      .add("f", createArrayType(LongType));
+    Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
+      .as(Encoders.bean(SimpleJavaBean.class));
+    Assert.assertEquals(data, ds3.collectAsList());
+  }
 }