diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index dba69659afc804ac9aed5fa81c32b37fb609e8fe..c8c6676f24c1708dc2b1dee713cb35556fcf8308 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
     val input = children.flatMap(_.output)
     productIterator.map {
       // Children are checked using sameResult above.
-      case tn: TreeNode[_] if children contains tn => null
+      case tn: TreeNode[_] if containsChild(tn) => null
       case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
       case s: Option[_] => s.map {
         case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 5964e3dc3d77e6edeab5a87ebc37f5a453fc72b5..f304597bc978e3a1fcd94cecd25687b0cd79c27b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
 
   val origin: Origin = CurrentOrigin.get
 
-  /** Returns a Seq of the children of this node */
+  /**
+   * Returns a Seq of the children of this node.
+   * Children should not change. Immutability required for containsChild optimization
+   */
   def children: Seq[BaseType]
 
+  lazy val containsChild: Set[TreeNode[_]] = children.toSet
+
   /**
    * Faster version of equality which short-circuits when two treeNodes are the same instance.
    * We don't just override Object.equals, as doing so prevents the scala compiler from
@@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
   def mapChildren(f: BaseType => BaseType): this.type = {
     var changed = false
     val newArgs = productIterator.map {
-      case arg: TreeNode[_] if children contains arg =>
+      case arg: TreeNode[_] if containsChild(arg) =>
         val newChild = f(arg.asInstanceOf[BaseType])
         if (newChild fastEquals arg) {
           arg
@@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
     val newArgs = productIterator.map {
       // Handle Seq[TreeNode] in TreeNode parameters.
       case s: Seq[_] => s.map {
-        case arg: TreeNode[_] if children contains arg =>
+        case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = remainingNewChildren.remove(0)
           val oldChild = remainingOldChildren.remove(0)
           if (newChild fastEquals oldChild) {
@@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
         case nonChild: AnyRef => nonChild
         case null => null
       }
-      case arg: TreeNode[_] if children contains arg =>
+      case arg: TreeNode[_] if containsChild(arg) =>
         val newChild = remainingNewChildren.remove(0)
         val oldChild = remainingOldChildren.remove(0)
         if (newChild fastEquals oldChild) {
@@ -238,7 +243,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
   def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
     var changed = false
     val newArgs = productIterator.map {
-      case arg: TreeNode[_] if children contains arg =>
+      case arg: TreeNode[_] if containsChild(arg) =>
         val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
         if (!(newChild fastEquals arg)) {
           changed = true
@@ -246,7 +251,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
         } else {
           arg
         }
-      case Some(arg: TreeNode[_]) if children contains arg =>
+      case Some(arg: TreeNode[_]) if containsChild(arg) =>
         val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
         if (!(newChild fastEquals arg)) {
           changed = true
@@ -257,7 +262,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
-        case arg: TreeNode[_] if children contains arg =>
+        case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
           if (!(newChild fastEquals arg)) {
             changed = true
@@ -295,7 +300,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
   def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
     var changed = false
     val newArgs = productIterator.map {
-      case arg: TreeNode[_] if children contains arg =>
+      case arg: TreeNode[_] if containsChild(arg) =>
         val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
         if (!(newChild fastEquals arg)) {
           changed = true
@@ -303,7 +308,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
         } else {
           arg
         }
-      case Some(arg: TreeNode[_]) if children contains arg =>
+      case Some(arg: TreeNode[_]) if containsChild(arg) =>
         val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
         if (!(newChild fastEquals arg)) {
           changed = true
@@ -314,7 +319,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
-        case arg: TreeNode[_] if children contains arg =>
+        case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
           if (!(newChild fastEquals arg)) {
             changed = true
@@ -383,7 +388,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
 
   /** Returns a string representing the arguments to this node, minus any children */
   def argString: String = productIterator.flatMap {
-    case tn: TreeNode[_] if children contains tn => Nil
+    case tn: TreeNode[_] if containsChild(tn) => Nil
     case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil
     case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil
     case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil