diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c18d7858f0a43e80b79736cb711ecdae06880289..4a9524074132e9bb0f196df07e95d66050188219 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
         val resolved = unresolved.flatMap(child.resolveChildren)
-        val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+        val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
 
         val missingInProject = requiredAttributes -- p.output
         if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
         )
 
         logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
-        val missingInAggs = resolved -- a.outputSet
+        val resolved = unresolved.flatMap(groupingRelation.resolve)
+        val missingInAggs = resolved.filterNot(a.outputSet.contains)
         logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
         if (missingInAggs.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
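Both hunks above move membership tests from plain Scala sets, which compare attributes by object equality, to the exprId-keyed semantics of AttributeSet (introduced below). A standalone sketch of the difference, using an invented mini `Attr` type in place of Spark's attributes:

    object ExprIdMembershipSketch extends App {
      case class Attr(name: String, exprId: Long)

      val resolved  = Seq(Attr("A", 1), Attr("b", 2))   // resolved sort expressions
      val outputSet = Seq(Attr("a", 1))                 // same column, different case

      // Object equality: Attr("A", 1) != Attr("a", 1), so nothing is subtracted.
      println(resolved.toSet -- outputSet)              // Set(Attr(A,1), Attr(b,2))

      // exprId-keyed membership recognizes Attr("A", 1) as already covered.
      val ids = outputSet.map(_.exprId).toSet
      println(resolved.filterNot(a => ids(a.exprId)))   // List(Attr(b,2))
    }
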
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a0e25775da6ddef3a11e9353ebd62701fec996b0..a2c61c65487cb458f3f5d3f8a3effb59cd54370a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
   override def dataType = throw new UnresolvedException(this, "dataType")
   override def foldable = throw new UnresolvedException(this, "foldable")
   override def nullable = throw new UnresolvedException(this, "nullable")
-  override def references = children.flatMap(_.references).toSet
   override lazy val resolved = false
 
   // Unresolved functions are transient at compile time and don't get evaluated during execution.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c3a08bbdb6bc7d438987edbee9b68dfd2a3cf256
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+protected class AttributeEquals(val a: Attribute) {
+  override def hashCode() = a.exprId.hashCode()
+  override def equals(other: Any) = other match {
+    case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
+    case _ => false
+  }
+}
+
+object AttributeSet {
+  /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
+  def apply(baseSet: Seq[Attribute]) = {
+    new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
+  }
+}
+
+/**
+ * A Set designed to hold [[AttributeReference]] objects that performs equality checking by
+ * expression id instead of standard Java equality.  Using expression ids means that these
+ * sets will correctly test for membership, even when the AttributeReferences in question differ
+ * cosmetically (e.g., the names have different capitalizations).
+ *
+ * Note that we do not override equality on [[AttributeReference]] itself, as it would be
+ * confusing if `AttributeReference("a", ...) == AttributeReference("b", ...)` ever held. That
+ * tactic breaks tests and also makes transformations harder, since we always try to keep the
+ * old tree instead of the new one when a transformation turns out to be a no-op.
+ */
+class AttributeSet private (val baseSet: Set[AttributeEquals])
+  extends Traversable[Attribute] with Serializable {
+
+  /** Returns true if the members of this AttributeSet and other are the same. */
+  override def equals(other: Any) = other match {
+    case otherSet: AttributeSet => baseSet == otherSet.baseSet
+    case _ => false
+  }
+
+  override def hashCode(): Int = baseSet.hashCode()
+
+  /** Returns true if this set contains an Attribute with the same expression id as `elem` */
+  def contains(elem: NamedExpression): Boolean =
+    baseSet.contains(new AttributeEquals(elem.toAttribute))
+
+  /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */
+  def +(elem: Attribute): AttributeSet =  // scalastyle:ignore
+    new AttributeSet(baseSet + new AttributeEquals(elem))
+
+  /** Returns a new [[AttributeSet]] that does not contain `elem`. */
+  def -(elem: Attribute): AttributeSet =
+    new AttributeSet(baseSet - new AttributeEquals(elem))
+
+  /** Returns an iterator over all of the attributes in the set. */
+  def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator
+
+  /**
+   * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
+   * `other`.
+   */
+  def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+
+  /**
+   * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
+   * in `other`.
+   */
+  def --(other: Traversable[NamedExpression]) =
+    new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+
+  /**
+   * Returns a new [[AttributeSet]] that contains the union of the [[Attribute Attributes]]
+   * in `this` and `other`.
+   */
+  def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+
+  /**
+   * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] for which `f`
+   * evaluates to true.
+   */
+  override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))
+
+  /**
+   * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
+   * `this` and `other`.
+   */
+  def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
+
+  override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
+
+  // We must force toSeq to be strict: the default implementation would return a lazy [[Stream]]
+  // that captures all sorts of things in its closure.
+  override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+}
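AttributeSet never stores attributes directly; it wraps each one in AttributeEquals, whose equals/hashCode key on exprId alone. A minimal standalone sketch of that wrapper technique (all names here are illustrative, not Spark's):

    object AttributeSetSketch extends App {
      case class Attr(name: String, exprId: Long)

      // Shim that redefines equality for use inside a hash set.
      class AttrEquals(val a: Attr) {
        override def hashCode(): Int = a.exprId.hashCode()
        override def equals(other: Any): Boolean = other match {
          case o: AttrEquals => a.exprId == o.a.exprId
          case _             => false
        }
      }

      val set = Set(new AttrEquals(Attr("userId", 1)))
      println(set.contains(new AttrEquals(Attr("USERID", 1))))  // true: same exprId
      println(set.contains(new AttrEquals(Attr("userId", 2))))  // false: same name only
    }
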
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 0913f158887807715a69cf6ea7e13d71a584b19a..54c6baf1af3bf1bd896023438597802ba18f4bda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 
   type EvaluatedType = Any
 
-  override def references = Set.empty
-
   override def toString = s"input[$ordinal]"
 
   override def eval(input: Row): Any = input(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba62dabe3dd6a8938ad6385dbc1ef6ae4ce8eaad..70507e7ee2be8c1fc5f24e5044a7e5ce389b215d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def foldable: Boolean = false
   def nullable: Boolean
-  def references: Set[Attribute]
+  def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
 
   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
 
   override def foldable = left.foldable && right.foldable
 
-  override def references = left.references ++ right.references
-
   override def toString = s"($left $symbol $right)"
 }
 
@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
 abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
   self: Product =>
 
-  override def references = child.references
+
 }
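Most of the deletions in the remaining files fall out of this one change: `references` is now a concrete member of Expression, derived from `children`, so every boilerplate override of the form `child.references` or `children.flatMap(_.references).toSet` can simply be dropped, and only genuine leaves such as AttributeReference still need to say anything. A simplified sketch of the pattern (plain Set[String] standing in for AttributeSet):

    object DefaultReferencesSketch extends App {
      trait Expr {
        def children: Seq[Expr]
        // Default: the union of the children's references.
        def references: Set[String] = children.flatMap(_.references).toSet
      }
      case class Ref(name: String) extends Expr {
        def children = Nil
        override def references = Set(name)  // the one genuine leaf override
      }
      case class Add(left: Expr, right: Expr) extends Expr {
        def children = Seq(left, right)
      }

      println(Add(Ref("a"), Ref("b")).references)  // Set(a, b)
    }
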
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index 38f836f0a1a0e0162b63d4ca2841d43dca72c53f..851db95b9177e16994dab40adb5e4bb84d703dd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
 case object Rand extends LeafExpression {
   override def dataType = DoubleType
   override def nullable = false
-  override def references = Set.empty
 
   private[this] lazy val rand = new Random
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 95633dd0c98707db02aec08239c044ed2b8ca152..63ac2a608b6ff808291e60601c6f819a4929a94d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   type EvaluatedType = Any
 
-  def references = children.flatMap(_.references).toSet
   def nullable = true
 
   /** This method has been generated by this script
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d2b7685e7306570e2228db23f1070a6dfb5c5d2c..d00b2ac09745c1c8f6b43247d7dcac117ee1b840 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
 case class SortOrder(child: Expression, direction: SortDirection) extends Expression 
     with trees.UnaryNode[Expression] {
 
-  override def references = child.references
   override def dataType = child.dataType
   override def nullable = child.nullable
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index eb8898900d6a50b58fa91abb6cdb90741e827c51..1eb55715794a7f00c58787bf88e995780bf4587c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
   type EvaluatedType = DynamicRow
 
   def nullable = false
-  def references = children.toSet
+
   def dataType = DynamicType
 
   override def eval(input: Row): DynamicRow = input match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 613b87ca98d97acfdceac2a5fc38304aa4543389..dbc0c2965a805cb69de0baefc4104527ed315dbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -78,7 +78,7 @@ abstract class AggregateFunction
 
   /** Base should return the generic aggregate expression that this function is computing */
   val base: AggregateExpression
-  override def references = base.references
+
   override def nullable = base.nullable
   override def dataType = base.dataType
 
@@ -89,7 +89,7 @@ abstract class AggregateFunction
 }
 
 case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
 }
 
 case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
 }
 
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
   def this() = this(null)
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
   def this() = this(null)
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = ArrayType(expressions.head.dataType)
   override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
   def this() = this(null)
 
   override def children = inputSet :: Nil
-  override def references = inputSet.references
   override def nullable = false
   override def dataType = LongType
   override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(
 
 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
 
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
 
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
   extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = LongType
   override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
 }
 
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = DoubleType
   override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 }
 
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-  override def references = child.references
+
   override def nullable = false
   override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-  override def references = child.references
   override def nullable = true
   override def dataType = child.dataType
   override def toString = s"FIRST($child)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 5f8b6ae10f0c488ddaf54f92eaa0bbfcca45605e..aae86a3628be156407de733f0c135cb25a6993c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -95,8 +95,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
 
   override def children = left :: right :: Nil
 
-  override def references = left.references ++ right.references
-
   override def dataType = left.dataType
 
   override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c1154eb81c319072d64ee285dab43468ac606c76..dafd745ec96c6b70f3ee28a620dc355a67ff5cbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
   /** `Null` is returned for invalid ordinals. */
   override def nullable = true
   override def foldable = child.foldable && ordinal.foldable
-  override def references = children.flatMap(_.references).toSet
+
   def dataType = child.dataType match {
     case ArrayType(dt, _) => dt
     case MapType(_, vt, _) => vt
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e99c5b452d183ec3331087b1f391d571276b3e20..9c865254e0be924e47e8228c934384c2bb648805 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -47,8 +47,6 @@ abstract class Generator extends Expression {
 
   override def nullable = false
 
-  override def references = children.flatMap(_.references).toSet
-
   /**
    * Should be overridden by specific generators.  Called only once for each instance to ensure
    * that rule application does not change the output schema of a generator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e15e16d6333656745be254f906a43502c075f10c..a8c2396d62632cfe802858c674827fa49ba1864b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
 
   override def foldable = true
   def nullable = value == null
-  def references = Set.empty
+
 
   override def toString = if (value != null) value.toString else "null"
 
@@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf
 
   val dataType = Literal(value).dataType
 
-  def references = Set.empty
-
   def update(expression: Expression, input: Row) = {
     value = expression.eval(input)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 02d04762629f55a99a0a19018b13efd38f1f2759..7c4b9d4847e2613c098781b677fb571e6b755d5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
 
   def toAttribute = this
   def newInstance: Attribute
-  override def references = Set(this)
+
 }
 
 /**
@@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String)
 
   override def dataType = child.dataType
   override def nullable = child.nullable
-  override def references = child.references
+
 
   override def toAttribute = {
     if (resolved) {
@@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
   extends Attribute with trees.LeafNode[Expression] {
 
+  override def references = AttributeSet(this :: Nil)
+
   override def equals(other: Any) = other match {
     case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
     case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index e88c5d4fa178a3f48bf4a69bc2e5e8d4debee7c6..086d0a3e073e5f659fb42992d5fd3e05695ad264 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
   def nullable = !children.exists(!_.nullable)
 
-  def references = children.flatMap(_.references).toSet
   // Coalesce is foldable if all children are foldable.
   override def foldable = !children.exists(!_.foldable)
 
@@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 }
 
 case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
-  def references = child.references
   override def foldable = child.foldable
   def nullable = false
 
@@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
 }
 
 case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
-  def references = child.references
   override def foldable = child.foldable
   def nullable = false
   override def toString = s"IS NOT NULL $child"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5976b0ddf3e038af493f267c08d42e52f3faf6d8..1313ccd120c1f10f65d27157da450c1557bf61ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
  */
 case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   def children = value +: list
-  def references = children.flatMap(_.references).toSet
+
   def nullable = true // TODO: Figure out correct nullability semantics of IN.
   override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
 
@@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 
   def children = predicate :: trueValue :: falseValue :: Nil
   override def nullable = trueValue.nullable || falseValue.nullable
-  def references = children.flatMap(_.references).toSet
+
   override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
   def dataType = {
     if (!resolved) {
@@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 case class CaseWhen(branches: Seq[Expression]) extends Expression {
   type EvaluatedType = Any
   def children = branches
-  def references = children.flatMap(_.references).toSet
+
   def dataType = {
     if (!resolved) {
       throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index e6c570b47bee221628123b2f199ee7ac3b25b528..3d4c4a8853c123cc73aa09c2637950be10a02cbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet
 case class NewSet(elementType: DataType) extends LeafExpression {
   type EvaluatedType = Any
 
-  def references = Set.empty
-
   def nullable = false
 
   // We are currently only using these Expressions internally for aggregation.  However, if we ever
@@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   def nullable = set.nullable
 
   def dataType = set.dataType
-
-  def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
-
   def eval(input: Row): Any = {
     val itemEval = item.eval(input)
     val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 97fc3a3b14b88ac6dc33699ee176f0a31ef4eb69..c2a3a5ca3ca8b191956e1aabf311b4ecb026ad5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
     if (str.dataType == BinaryType) str.dataType else StringType
   }
 
-  def references = children.flatMap(_.references).toSet
-
   override def children = str :: pos :: len :: Nil
 
   @inline
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5f86d6047cb9c6f3f0376a2a0615f9972041e00e..ddd4b3755d6293ade54bb3838ca0b26f26ca1a3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,8 +65,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Eliminate unneeded attributes from either side of a Join.
     case Project(projectList, Join(left, right, joinType, condition)) =>
       // Collect the list of all references required either above or to evaluate the condition.
-      val allReferences: Set[Attribute] =
-        projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
+      val allReferences: AttributeSet =
+        AttributeSet(projectList.flatMap(_.references.iterator)) ++
+          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
       /** Applies a projection only when the child is producing unnecessary attributes */
       def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
@@ -76,8 +78,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Eliminate unneeded attributes from right side of a LeftSemiJoin.
     case Join(left, right, LeftSemi, condition) =>
       // Collect the list of all references required to evaluate the condition.
-      val allReferences: Set[Attribute] =
-        condition.map(_.references).getOrElse(Set.empty)
+      val allReferences: AttributeSet =
+        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
@@ -104,7 +106,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) =
+  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
       Project(allReferences.filter(c.outputSet.contains).toSeq, c)
     } else {
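`prunedChild` inserts a Project only when the child emits attributes that nothing above needs; the `filter(c.outputSet.contains)` intersection matters because `allReferences` spans both sides of the join. A sketch of the decision, with strings standing in for attributes:

    object PruneSketch extends App {
      val childOutput   = Set("a", "b", "c")
      val allReferences = Set("a", "b", "x")  // "x" is produced by the other join side

      val kept = allReferences.filter(childOutput.contains)
      if ((childOutput -- kept).nonEmpty)
        println(s"Project(${kept.mkString(", ")}, child)")  // "c" is pruned away
      else
        println("child is already minimal; no Project needed")
    }
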
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0988b0c6d990c68465cc7d17f76940f2f00cc708..1e177e28f80b31d149f1935c03c80ea7ce100a29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType}
 
@@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   /**
    * Returns the set of attributes that are output by this node.
    */
-  def outputSet: Set[Attribute] = output.toSet
+  def outputSet: AttributeSet = AttributeSet(output)
 
   /**
    * Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 278569f0cb14a5ce190c3a2730a6d0d04de7d77d..8616ac45b0e957f494a096445b1aab3c2ea292ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -45,17 +45,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product
   )
 
-  /**
-   * Returns the set of attributes that are referenced by this node
-   * during evaluation.
-   */
-  def references: Set[Attribute]
-
   /**
    * Returns the set of attributes that this node takes as
    * input from its children.
    */
-  lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+  lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output))
 
   /**
    * Returns true if this expression and all its children have been resolved to a specific schema
@@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
 
   override lazy val statistics: Statistics =
     throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
-
-  // Leaf nodes by definition cannot reference any input attributes.
-  override def references = Set.empty
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index d3f9d0fb9323711e5002bd0e0d81d9f837999c21..4460c86ed90264b215a4dc53e87525db4c177b5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -30,6 +30,4 @@ case class ScriptTransformation(
     input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  def references = input.flatMap(_.references).toSet
-}
+    child: LogicalPlan) extends UnaryNode
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3cb407217c4c3884c1e2e3c7e93e0c8df418f1f3..4adfb189372d6d9b1aef02d3fc4b323face4db84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._
 
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   def output = projectList.map(_.toAttribute)
-  def references = projectList.flatMap(_.references).toSet
 }
 
 /**
@@ -59,14 +58,10 @@ case class Generate(
 
   override def output =
     if (join) child.output ++ generatorOutput else generatorOutput
-
-  override def references =
-    if (join) child.outputSet else generator.references
 }
 
 case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = condition.references
 }
 
 case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override lazy val resolved =
     childrenResolved &&
     !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
-
-  override def references = Set.empty
 }
 
 case class Join(
@@ -86,8 +79,6 @@ case class Join(
   joinType: JoinType,
   condition: Option[Expression]) extends BinaryNode {
 
-  override def references = condition.map(_.references).getOrElse(Set.empty)
-
   override def output = {
     joinType match {
       case LeftSemi =>
@@ -106,8 +97,6 @@ case class Join(
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   def output = left.output
-
-  def references = Set.empty
 }
 
 case class InsertIntoTable(
@@ -118,7 +107,6 @@ case class InsertIntoTable(
   extends LogicalPlan {
   // The table being inserted into is a child for the purposes of transformations.
   override def children = table :: child :: Nil
-  override def references = Set.empty
   override def output = child.output
 
   override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
@@ -130,20 +118,17 @@ case class InsertIntoCreatedTable(
     databaseName: Option[String],
     tableName: String,
     child: LogicalPlan) extends UnaryNode {
-  override def references = Set.empty
   override def output = child.output
 }
 
 case class WriteToFile(
     path: String,
     child: LogicalPlan) extends UnaryNode {
-  override def references = Set.empty
   override def output = child.output
 }
 
 case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = order.flatMap(_.references).toSet
 }
 
 case class Aggregate(
@@ -152,19 +137,20 @@ case class Aggregate(
     child: LogicalPlan)
   extends UnaryNode {
 
+  /** The set of all AttributeReferences required for this aggregation. */
+  def references =
+    AttributeSet(
+      groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references))
+
   override def output = aggregateExpressions.map(_.toAttribute)
-  override def references =
-    (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
 }
 
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = limitExpr.references
 }
 
 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output = child.output.map(_.withQualifiers(alias :: Nil))
-  override def references = Set.empty
 }
 
 /**
@@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
         a.qualifiers)
     case other => other
   }
-
-  override def references = Set.empty
 }
 
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
     extends UnaryNode {
 
   override def output = child.output
-  override def references = Set.empty
 }
 
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def output = child.output
-  override def references = child.outputSet
 }
 
 case object NoRelation extends LeafNode {
@@ -213,5 +195,4 @@ case object NoRelation extends LeafNode {
 
 case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   override def output = left.output
-  override def references = Set.empty
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 7146fbd540f29c594ef2f589795f0dbafa02c157..72b0c5c8e7a2641b807e5913085e8cba0c461ca8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode {
 
 case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
     extends RedistributeData {
-
-  def references = sortExpressions.flatMap(_.references).toSet
 }
 
 case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
     extends RedistributeData {
-
-  def references = partitionExpressions.flatMap(_.references).toSet
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4bb022cf238af5190c28a9c76815bcb29c8b5e90..ccb0df113c0632478ef2096057c45e68bd4568f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 
+  // TODO: This is not really valid...
   def clustering = ordering.map(_.child).toSet
 }
 
@@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   with Partitioning {
 
   override def children = expressions
-  override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
 
@@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
   with Partitioning {
 
   override def children = ordering
-  override def references = ordering.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6344874538d675e78ec94dafee098965f05cadb7..296202543e2ca97d98bd8e947079c6f0f266e381 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType}
 
 case class Dummy(optKey: Option[Expression]) extends Expression {
   def children = optKey.toSeq
-  def references = Set.empty[Attribute]
   def nullable = true
   def dataType = NullType
   override lazy val resolved = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8a9f4deb6a19ef88357ab70fb09d8a14ebc82ce8..6f0eed3f63c41b6f0848b73a8548cdfbdc4033ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
         prunePushedDownFilters: Seq[Expression] => Seq[Expression],
         scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {
 
-      val projectSet = projectList.flatMap(_.references).toSet
-      val filterSet = filterPredicates.flatMap(_.references).toSet
+      val projectSet = AttributeSet(projectList.flatMap(_.references))
+      val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
       val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
 
       // Right now we still use a projection even if the only evaluation is applying an alias
@@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
       // TODO: Decouple final output schema from expression evaluation so this copy can be
       // avoided safely.
 
-      if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) {
+      if (AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
+          filterSet.subsetOf(projectSet)) {
         // When it is possible to just use column pruning to get the right projection and
         // when the columns of this projection are enough to evaluate all filter conditions,
         // just do a scan followed by a filter, with no extra project.
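With `projectSet` now an AttributeSet, the old fast-path test `projectList.toSet == projectSet` compared a plain Scala set of NamedExpressions against an AttributeSet, keyed on object equality; normalizing the left side through `toAttribute` puts both sides in the same domain, compared by exprId. Sketched with an invented `Attr` type:

    object FastPathSketch extends App {
      case class Attr(name: String, exprId: Long)

      def sameColumns(xs: Set[Attr], ys: Set[Attr]) =
        xs.map(_.exprId) == ys.map(_.exprId)

      val projectList = Set(Attr("A", 1))  // as written in the query
      val projectSet  = Set(Attr("a", 1))  // as resolved elsewhere

      println(projectList == projectSet)             // false: object equality sees two names
      println(sameColumns(projectList, projectSet))  // true: one underlying column
    }
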
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index e63b4903041f6f8941c37e6548392c9f546e674a..24e88eea3189e148803c097bfefaed3676b426ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation(
 
   override def children = Seq.empty
 
-  override def references = Set.empty
-
   override def newInstance() = {
     new InMemoryRelation(
       output.map(_.newInstance),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21cbbc9772a00a3da24c828eb5ddbffeab7d586b..7d33ea5b021e26c17659afb1b0cc33147d4f7e5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
   extends LogicalPlan with MultiInstanceRelation {
 
   def output = alreadyPlanned.output
-  override def references = Set.empty
   override def children = Nil
 
-  override final def newInstance: this.type = {
+  override final def newInstance(): this.type = {
     SparkLogicalPlan(
       alreadyPlanned match {
         case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f31df051824d72eeb8b1127322130daa38f7ed46..5b896c55b73938aa18325a1c316509b462e73eef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -58,8 +58,6 @@ package object debug {
   }
 
   private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
-    def references = Set.empty
-
     def output = child.output
 
     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index b92091b560b1c4166751b5ee7417029ec5d35832..aef6ebf86b1ebd08a9c04e210ef75bec1d1d8839 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -49,7 +49,6 @@ private[spark] case class PythonUDF(
   override def toString = s"PythonUDF#$name(${children.mkString(",")})"
 
   def nullable: Boolean = true
-  def references: Set[Attribute] = children.flatMap(_.references).toSet
 
   override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.")
 }
@@ -113,7 +112,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
 case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
   val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)()
 
-  def references = Set.empty
   def output = child.output :+ resultAttribute
 }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 389ace726d205d734859ca4f06d23b446c15047d..10fa8314c9156e01c23224e085fb5e7bf4d26a85 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -79,9 +79,9 @@ private[hive] trait HiveStrategies {
              hiveContext.convertMetastoreParquet =>
 
         // Filter out all predicates that only deal with partition keys
-        val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+        val partitionKeySet = AttributeSet(relation.partitionKeys)
         val (pruningPredicates, otherPredicates) = predicates.partition {
-          _.references.map(_.exprId).subsetOf(partitionKeyIds)
+          _.references.subsetOf(partitionKeySet)
         }
 
         // We are going to throw the predicates and projection back at the whole optimization
@@ -176,9 +176,9 @@ private[hive] trait HiveStrategies {
       case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
-        val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+        val partitionKeySet = AttributeSet(relation.partitionKeys)
         val (pruningPredicates, otherPredicates) = predicates.partition {
-          _.references.map(_.exprId).subsetOf(partitionKeyIds)
+          _.references.subsetOf(partitionKeySet)
         }
 
         pruneFilterProject(
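A predicate may be handed to Hive's partition pruning only if every attribute it references is a partition key, which is exactly what `references.subsetOf(partitionKeySet)` now expresses directly. A sketch of the split, with invented predicates:

    object PartitionPruningSketch extends App {
      case class Pred(references: Set[String])

      val partitionKeySet = Set("ds", "hour")
      val predicates = Seq(
        Pred(Set("ds")),            // prunable: only partition keys
        Pred(Set("ds", "userId")),  // touches a data column: not prunable
        Pred(Set("hour")))          // prunable

      val (pruningPredicates, otherPredicates) =
        predicates.partition(_.references.subsetOf(partitionKeySet))
      println(s"pruning: $pruningPredicates")
      println(s"other:   $otherPredicates")
    }
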
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index c6497a15efa0c0d4ddf61e510bf7db1db0a37dfe..7d1ad53d8bdb3eadb0c54316391b3748c290bffb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
   type EvaluatedType = Any
 
   def nullable = true
-  def references = children.flatMap(_.references).toSet
 
   lazy val function = createFunction[UDFType]()
 
@@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf(
 
   def nullable: Boolean = true
 
-  def references: Set[Attribute] = children.map(_.references).flatten.toSet
-
   override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
 
   def newInstance() = new HiveUdafFunction(functionClassName, children, this)
@@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf(
     children: Seq[Expression])
   extends Generator with HiveInspectors with HiveFunctionFactory {
 
-  override def references = children.flatMap(_.references).toSet
-
   @transient
   protected lazy val function: GenericUDTF = createFunction()
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 6b3ffd1c0ffe2b395271dd061a5f476caf88021c..b6be6bc1bfefed9f85bde48acc52da15402c4a8d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 
-case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
 case class Nested(a: Int, B: Int)
+case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
 
 /**
  * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest {
       .registerTempTable("caseSensitivityTest")
 
     sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+
+    // Self-joins of the same table must also resolve case-insensitively.
+    sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a")
+      .collect()
+
+    // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a")
+
   }
 
   test("nested repeated resolution") {