diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c18d7858f0a43e80b79736cb711ecdae06880289..4a9524074132e9bb0f196df07e95d66050188219 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) val resolved = unresolved.flatMap(child.resolveChildren) - val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) val missingInProject = requiredAttributes -- p.output if (missingInProject.nonEmpty) { @@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve).toSet - val missingInAggs = resolved -- a.outputSet + val resolved = unresolved.flatMap(groupingRelation.resolve) + val missingInAggs = resolved.filterNot(a.outputSet.contains) logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") if (missingInAggs.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a0e25775da6ddef3a11e9353ebd62701fec996b0..a2c61c65487cb458f3f5d3f8a3effb59cd54370a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override def dataType = throw new UnresolvedException(this, "dataType") override def foldable = throw new UnresolvedException(this, "foldable") override def nullable = throw new UnresolvedException(this, "nullable") - override def references = children.flatMap(_.references).toSet override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..c3a08bbdb6bc7d438987edbee9b68dfd2a3cf256 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +protected class AttributeEquals(val a: Attribute) { + override def hashCode() = a.exprId.hashCode() + override def equals(other: Any) = other match { + case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId + case otherAttribute => false + } +} + +object AttributeSet { + /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */ + def apply(baseSet: Seq[Attribute]) = { + new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) + } +} + +/** + * A Set designed to hold [[AttributeReference]] objects, that performs equality checking using + * expression id instead of standard java equality. Using expression id means that these + * sets will correctly test for membership, even when the AttributeReferences in question differ + * cosmetically (e.g., the names have different capitalizations). + * + * Note that we do not override equality for Attribute references as it is really weird when + * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests, + * and also makes doing transformations hard (we always try to keep older trees instead of new ones + * when the transformation was a no-op). + */ +class AttributeSet private (val baseSet: Set[AttributeEquals]) + extends Traversable[Attribute] with Serializable { + + /** Returns true if `other` is an [[AttributeSet]] with exactly the same members as this set. 
*/ + override def equals(other: Any) = other match { + case otherSet: AttributeSet => baseSet == otherSet.baseSet + case _ => false + } + + /** Hashes the underlying expression ids, so sets that are equal also share a hash code. */ + override def hashCode: Int = baseSet.hashCode() + + /** Returns true if this set contains an Attribute with the same expression id as `elem` */ + def contains(elem: NamedExpression): Boolean = + baseSet.contains(new AttributeEquals(elem.toAttribute)) + + /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */ + def +(elem: Attribute): AttributeSet = // scalastyle:ignore + new AttributeSet(baseSet + new AttributeEquals(elem)) + + /** Returns a new [[AttributeSet]] that does not contain `elem`. */ + def -(elem: Attribute): AttributeSet = + new AttributeSet(baseSet - new AttributeEquals(elem)) + + /** Returns an iterator containing all of the attributes in the set. */ + def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator + + /** + * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in + * `other`. + */ + def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + + /** + * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found + * in `other`. + */ + def --(other: Traversable[NamedExpression]) = + new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + + /** + * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found + * in `other`. + */ + def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + + /** + * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] where `f` evaluates + * to true. 
+ */ + override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + + /** + * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in + * `this` and `other`. 
+ */ + def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + + override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) + + // We must force toSeq to be strict, otherwise we end up with a [[Stream]] that captures all + // sorts of things in its closure. + override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 0913f158887807715a69cf6ea7e13d71a584b19a..54c6baf1af3bf1bd896023438597802ba18f4bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def references = Set.empty - override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ba62dabe3dd6a8938ad6385dbc1ef6ae4ce8eaad..70507e7ee2be8c1fc5f24e5044a7e5ce389b215d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] { */ def foldable: Boolean = false def nullable: Boolean - def references: Set[Attribute] + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): EvaluatedType @@ -230,8 +230,6 @@ 
abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def foldable = left.foldable && right.foldable - override def references = left.references ++ right.references - override def toString = s"($left $symbol $right)" } @@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - override def references = child.references + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 38f836f0a1a0e0162b63d4ca2841d43dca72c53f..851db95b9177e16994dab40adb5e4bb84d703dd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType case object Rand extends LeafExpression { override def dataType = DoubleType override def nullable = false - override def references = Set.empty private[this] lazy val rand = new Random diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 95633dd0c98707db02aec08239c044ed2b8ca152..63ac2a608b6ff808291e60601c6f819a4929a94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def references = children.flatMap(_.references).toSet def nullable = true /** This method has been generated by this script diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d2b7685e7306570e2228db23f1070a6dfb5c5d2c..d00b2ac09745c1c8f6b43247d7dcac117ee1b840 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -31,7 +31,6 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def references = child.references override def dataType = child.dataType override def nullable = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index eb8898900d6a50b58fa91abb6cdb90741e827c51..1eb55715794a7f00c58787bf88e995780bf4587c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow def nullable = false - def references = children.toSet + def dataType = DynamicType override def eval(input: Row): DynamicRow = input match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 613b87ca98d97acfdceac2a5fc38304aa4543389..dbc0c2965a805cb69de0baefc4104527ed315dbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -78,7 +78,7 @@ abstract 
class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def references = base.references + override def nullable = base.nullable override def dataType = base.dataType @@ -89,7 +89,7 @@ abstract class AggregateFunction } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MIN($child)" @@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MAX($child)" @@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"COUNT($child)" @@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate def this() = this(null) override def children = expressions - override def references = expressions.flatMap(_.references).toSet + override def nullable = false override def dataType = LongType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" @@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress def this() = this(null) override def children = expressions - override def references = expressions.flatMap(_.references).toSet override def nullable = false override def dataType = ArrayType(expressions.head.dataType) override def toString = 
s"AddToHashSet(${expressions.mkString(",")})" @@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression def this() = this(null) override def children = inputSet :: Nil - override def references = inputSet.references override def nullable = false override def dataType = LongType override def toString = s"CombineAndCount($inputSet)" @@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = DoubleType override def toString = s"AVG($child)" @@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate 
with trees.UnaryN } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM($child)" @@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" @@ -331,7 +329,6 @@ case class SumDistinct(child: Expression) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references override def nullable = true override def dataType = child.dataType override def toString = s"FIRST($child)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5f8b6ae10f0c488ddaf54f92eaa0bbfcca45605e..aae86a3628be156407de733f0c135cb25a6993c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -95,8 +95,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def children = left :: right :: Nil - override def references = left.references ++ right.references - override def dataType = left.dataType override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index c1154eb81c319072d64ee285dab43468ac606c76..dafd745ec96c6b70f3ee28a620dc355a67ff5cbf 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { /** `Null` is returned for invalid ordinals. */ override def nullable = true override def foldable = child.foldable && ordinal.foldable - override def references = children.flatMap(_.references).toSet + def dataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e99c5b452d183ec3331087b1f391d571276b3e20..9c865254e0be924e47e8228c934384c2bb648805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -47,8 +47,6 @@ abstract class Generator extends Expression { override def nullable = false - override def references = children.flatMap(_.references).toSet - /** * Should be overridden by specific generators. Called only once for each instance to ensure * that rule application does not change the output schema of a generator. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e15e16d6333656745be254f906a43502c075f10c..a8c2396d62632cfe802858c674827fa49ba1864b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def foldable = true def nullable = value == null - def references = Set.empty + override def toString = if (value != null) value.toString else "null" @@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf val dataType = Literal(value).dataType - def references = Set.empty - def update(expression: Expression, input: Row) = { value = expression.eval(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 02d04762629f55a99a0a19018b13efd38f1f2759..7c4b9d4847e2613c098781b677fb571e6b755d5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def toAttribute = this def newInstance: Attribute - override def references = Set(this) + } /** @@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable - override def references = child.references + override def toAttribute = { if (resolved) { @@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = 
NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { + override def references = AttributeSet(this :: Nil) + override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e88c5d4fa178a3f48bf4a69bc2e5e8d4debee7c6..086d0a3e073e5f659fb42992d5fd3e05695ad264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ def nullable = !children.exists(!_.nullable) - def references = children.flatMap(_.references).toSet // Coalesce is foldable if all children are foldable. 
override def foldable = !children.exists(!_.foldable) @@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false @@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false override def toString = s"IS NOT NULL $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5976b0ddf3e038af493f267c08d42e52f3faf6d8..1313ccd120c1f10f65d27157da450c1557bf61ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { def children = value +: list - def references = children.flatMap(_.references).toSet + def nullable = true // TODO: Figure out correct nullability semantics of IN. 
override def toString = s"$value IN ${list.mkString("(", ",", ")")}" @@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi def children = predicate :: trueValue :: falseValue :: Nil override def nullable = trueValue.nullable || falseValue.nullable - def references = children.flatMap(_.references).toSet + override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType def dataType = { if (!resolved) { @@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any def children = branches - def references = children.flatMap(_.references).toSet + def dataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index e6c570b47bee221628123b2f199ee7ac3b25b528..3d4c4a8853c123cc73aa09c2637950be10a02cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def references = Set.empty - def nullable = false // We are currently only using these Expressions internally for aggregation. 
However, if we ever @@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { def nullable = set.nullable def dataType = set.dataType - - def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet - def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 97fc3a3b14b88ac6dc33699ee176f0a31ef4eb69..c2a3a5ca3ca8b191956e1aabf311b4ecb026ad5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends if (str.dataType == BinaryType) str.dataType else StringType } - def references = children.flatMap(_.references).toSet - override def children = str :: pos :: len :: Nil @inline diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5f86d6047cb9c6f3f0376a2a0615f9972041e00e..ddd4b3755d6293ade54bb3838ca0b26f26ca1a3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -65,8 +65,10 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. 
- val allReferences: Set[Attribute] = - projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + AttributeSet( + projectList.flatMap(_.references.iterator)) ++ + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) @@ -76,8 +78,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate unneeded attributes from right side of a LeftSemiJoin. case Join(left, right, LeftSemi, condition) => // Collect the list of all references required to evaluate the condition. - val allReferences: Set[Attribute] = - condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) Join(left, prunedChild(right, allReferences), LeftSemi, condition) @@ -104,7 +106,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0988b0c6d990c68465cc7d17f76940f2f00cc708..1e177e28f80b31d149f1935c03c80ea7ce100a29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} @@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Returns the set of attributes that are output by this node. */ - def outputSet: Set[Attribute] = output.toSet + def outputSet: AttributeSet = AttributeSet(output) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 278569f0cb14a5ce190c3a2730a6d0d04de7d77d..8616ac45b0e957f494a096445b1aab3c2ea292ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -45,17 +45,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product ) - /** - * Returns the set of attributes that are referenced by this node - * during evaluation. - */ - def references: Set[Attribute] - /** * Returns the set of attributes that this node takes as * input from its children. */ - lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet + lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) /** * Returns true if this expression and all its children have been resolved to a specific schema @@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { override lazy val statistics: Statistics = throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") - - // Leaf nodes by definition cannot reference any input attributes. 
- override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index d3f9d0fb9323711e5002bd0e0d81d9f837999c21..4460c86ed90264b215a4dc53e87525db4c177b5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -30,6 +30,4 @@ case class ScriptTransformation( input: Seq[Expression], script: String, output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - def references = input.flatMap(_.references).toSet -} + child: LogicalPlan) extends UnaryNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3cb407217c4c3884c1e2e3c7e93e0c8df418f1f3..4adfb189372d6d9b1aef02d3fc4b323face4db84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { def output = projectList.map(_.toAttribute) - def references = projectList.flatMap(_.references).toSet } /** @@ -59,14 +58,10 @@ case class Generate( override def output = if (join) child.output ++ generatorOutput else generatorOutput - - override def references = - if (join) child.outputSet else generator.references } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = condition.references } case class Union(left: LogicalPlan, right: 
LogicalPlan) extends BinaryNode { @@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - - override def references = Set.empty } case class Join( @@ -86,8 +79,6 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def references = condition.map(_.references).getOrElse(Set.empty) - override def output = { joinType match { case LeftSemi => @@ -106,8 +97,6 @@ case class Join( case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { def output = left.output - - def references = Set.empty } case class InsertIntoTable( @@ -118,7 +107,6 @@ case class InsertIntoTable( extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. override def children = table :: child :: Nil - override def references = Set.empty override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { @@ -130,20 +118,17 @@ case class InsertIntoCreatedTable( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = order.flatMap(_.references).toSet } case class Aggregate( @@ -152,19 +137,20 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { + /** The set of all AttributeReferences required for this aggregation. 
*/ + def references = + AttributeSet( + groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) + override def output = aggregateExpressions.map(_.toAttribute) - override def references = - (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output = child.output.map(_.withQualifiers(alias :: Nil)) - override def references = Set.empty } /** @@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { a.qualifiers) case other => other } - - override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = child.outputSet } case object NoRelation extends LeafNode { @@ -213,5 +195,4 @@ case object NoRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output = left.output - override def references = Set.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 7146fbd540f29c594ef2f589795f0dbafa02c157..72b0c5c8e7a2641b807e5913085e8cba0c461ca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode { case class 
SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData { - - def references = sortExpressions.flatMap(_.references).toSet } case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData { - - def references = partitionExpressions.flatMap(_.references).toSet } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4bb022cf238af5190c28a9c76815bcb29c8b5e90..ccb0df113c0632478ef2096057c45e68bd4568f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + // TODO: This is not really valid... 
def clustering = ordering.map(_.child).toSet } @@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) with Partitioning { override def children = expressions - override def references = expressions.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType @@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) with Partitioning { override def children = ordering - override def references = ordering.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6344874538d675e78ec94dafee098965f05cadb7..296202543e2ca97d98bd8e947079c6f0f266e381 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq - def references = Set.empty[Attribute] def nullable = true def dataType = NullType override lazy val resolved = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8a9f4deb6a19ef88357ab70fb09d8a14ebc82ce8..6f0eed3f63c41b6f0848b73a8548cdfbdc4033ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext) prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - val projectSet = projectList.flatMap(_.references).toSet - 
val filterSet = filterPredicates.flatMap(_.references).toSet + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) // Right now we still use a projection even if the only evaluation is applying an alias @@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: Decouple final output schema from expression evaluation so this copy can be // avoided safely. - if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) { + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index e63b4903041f6f8941c37e6548392c9f546e674a..24e88eea3189e148803c097bfefaed3676b426ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation( override def children = Seq.empty - override def references = Set.empty - override def newInstance() = { new InMemoryRelation( output.map(_.newInstance), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21cbbc9772a00a3da24c828eb5ddbffeab7d586b..7d33ea5b021e26c17659afb1b0cc33147d4f7e5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output - override def references = Set.empty override def children = Nil - override final def newInstance: this.type = { + override final def newInstance(): this.type = { SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index f31df051824d72eeb8b1127322130daa38f7ed46..5b896c55b73938aa18325a1c316509b462e73eef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -58,8 +58,6 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def references = Set.empty - def output = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b92091b560b1c4166751b5ee7417029ec5d35832..aef6ebf86b1ebd08a9c04e210ef75bec1d1d8839 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -49,7 +49,6 @@ private[spark] case class PythonUDF( override def toString = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - def references: Set[Attribute] = children.flatMap(_.references).toSet override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") } @@ -113,7 +112,6 @@ private[spark] object 
ExtractPythonUdfs extends Rule[LogicalPlan] { case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() - def references = Set.empty def output = child.output :+ resultAttribute } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 389ace726d205d734859ca4f06d23b446c15047d..10fa8314c9156e01c23224e085fb5e7bf4d26a85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -79,9 +79,9 @@ private[hive] trait HiveStrategies { hiveContext.convertMetastoreParquet => // Filter out all predicates that only deal with partition keys - val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionsKeys = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionsKeys) } // We are going to throw the predicates and projection back at the whole optimization @@ -176,9 +176,9 @@ private[hive] trait HiveStrategies { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. 
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionKeyIds = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionKeyIds) } pruneFilterProject( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index c6497a15efa0c0d4ddf61e510bf7db1db0a37dfe..7d1ad53d8bdb3eadb0c54316391b3748c290bffb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type EvaluatedType = Any def nullable = true - def references = children.flatMap(_.references).toSet lazy val function = createFunction[UDFType]() @@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" def newInstance() = new HiveUdafFunction(functionClassName, children, this) @@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf( children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { - override def references = children.flatMap(_.references).toSet - @transient protected lazy val function: GenericUDTF = createFunction() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 6b3ffd1c0ffe2b395271dd061a5f476caf88021c..b6be6bc1bfefed9f85bde48acc52da15402c4a8d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) case class Nested(a: Int, B: Int) +case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest { .registerTempTable("caseSensitivityTest") sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + + println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution) + + sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect() + + // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a") + } test("nested repeated resolution") {