diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f2e579afe833ae712a86302262617e0f2abcb85e..7089f079b6dde59586d114c1fbb118fcd3d3c6aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -50,8 +50,7 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(nameParts: Seq[String]) - extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute { def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") @@ -96,7 +95,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. */ -trait Star extends NamedExpression with trees.LeafNode[Expression] { +abstract class Star extends LeafExpression with NamedExpression { self: Product => override def name: String = throw new UnresolvedException(this, "name") @@ -151,7 +150,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends NamedExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with NamedExpression { override def name: String = throw new UnresolvedException(this, "name") @@ -210,8 +209,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) /** * Holds the expression that has yet to be aliased. */ -case class UnresolvedAlias(child: Expression) extends NamedExpression - with trees.UnaryNode[Expression] { +case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 3f0d7b803125f531b75a21472fb43431bd47bb8a..b09aea03318dabd6852caedb082f4d64b9f29a83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._ * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends NamedExpression with trees.LeafNode[Expression] { + extends LeafExpression with NamedExpression { override def toString: String = s"input[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a655cc8e48ae15ded74aee2c2e14d15125c769f9..f396bd08a823863e2fc1a6aeeda77c5843165660 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 8ab4ef060b68ce01be996fdd55eb2f9b7123724d..b8f7068c9e5e5dac2e5a85ed8757548f8fc69a78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -30,8 +30,10 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression - with trees.UnaryNode[Expression] { +case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { + + /** Sort order is not foldable because we don't have an eval for it. */ + override def foldable: Boolean = false override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index c0e17f97e9b3c514df9c78542121534cbdeee587..71c943dc79e9eda97f83b465f5c8a1ef16935226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -20,16 +20,20 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -abstract class AggregateExpression extends Expression { +trait AggregateExpression extends Expression { self: Product => + /** + * Aggregate expressions should not be foldable. + */ + override def foldable: Boolean = false + /** * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ @@ -60,7 +64,7 @@ case class SplitEvaluation( * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -abstract class PartialAggregate extends AggregateExpression { +trait PartialAggregate extends AggregateExpression { self: Product => /** @@ -74,7 +78,7 @@ abstract class PartialAggregate extends AggregateExpression { * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. */ abstract class AggregateFunction - extends AggregateExpression with Serializable with trees.LeafNode[Expression] { + extends LeafExpression with AggregateExpression with Serializable { self: Product => /** Base should return the generic aggregate expression that this function is computing */ @@ -91,7 +95,7 @@ abstract class AggregateFunction } } -case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -124,7 +128,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -157,7 +161,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -310,7 +314,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with AggregateExpression { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -340,7 +344,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with AggregateExpression { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -368,7 +372,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends PartialAggregate with trees.UnaryNode[Expression] { + extends UnaryExpression with PartialAggregate { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -386,7 +390,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate { override def prettyName: String = "avg" @@ -479,7 +483,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate { override def nullable: Boolean = true @@ -606,8 +610,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate { def this() = this(null) override def nullable: Boolean = true @@ -701,7 +704,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class First(child: Expression) extends UnaryExpression with PartialAggregate { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -729,7 +732,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b68d30a26abd8d1a0028468bb073f6799e234c7b..51dc77ee3fc5fdde6371a5b2316b3c8ef01b155f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -40,13 +40,14 @@ import org.apache.spark.sql.types._ * requested. The attributes produced by this function will be automatically copied anytime rules * result in changes to the Generator or its children. */ -abstract class Generator extends Expression { - self: Product => +trait Generator extends Expression { self: Product => // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. override def dataType: DataType = throw new UnsupportedOperationException + override def foldable: Boolean = false + override def nullable: Boolean = false /** @@ -99,8 +100,9 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ -case class Explode(child: Expression) - extends Generator with trees.UnaryNode[Expression] { +case class Explode(child: Expression) extends UnaryExpression with Generator { + + override def children: Seq[Expression] = child :: Nil override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { @@ -127,6 +129,4 @@ case class Explode(child: Expression) else inputMap.map { case (k, v) => InternalRow(k, v) } } } - - override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6181c60c0e45307453b3ae597dbe229d80a68658..8bf7a7ce4e6470e6927296a1ab348aa9f182a4ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -37,8 +37,13 @@ object NamedExpression { */ case class ExprId(id: Long) -abstract class NamedExpression extends Expression { - self: Product => +/** + * An [[Expression]] that is named. + */ +trait NamedExpression extends Expression { self: Product => + + /** We should never fold named expressions in order to not remove the alias. */ + override def foldable: Boolean = false def name: String def exprId: ExprId @@ -78,8 +83,7 @@ abstract class NamedExpression extends Expression { } } -abstract class Attribute extends NamedExpression { - self: Product => +abstract class Attribute extends LeafExpression with NamedExpression { self: Product => override def references: AttributeSet = AttributeSet(this) @@ -110,7 +114,7 @@ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil, val explicitMetadata: Option[Metadata] = None) - extends NamedExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) override lazy val resolved = @@ -172,7 +176,8 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { + val qualifiers: Seq[String] = Nil) + extends Attribute { /** * Returns true iff the expression id is the same for both attributes. @@ -242,7 +247,7 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class PrettyAttribute(name: String) extends Attribute { override def toString: String = name diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5d80214abf141c5a64fd08fac498ccf5893070ae..2f94b457f4cdcf2db9a383a082dd36033459b4d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -342,7 +342,7 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal.create(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. case In(Literal(v, _), list) if list.exists { @@ -361,7 +361,7 @@ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => - val hSet = list.map(e => e.eval(null)) + val hSet = list.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d7077a0ec907a0a9a78226141a6862f10b652888..adac37231cc4afcc8052bea380aa91732535a893 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -277,15 +276,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * A logical plan node with no children. */ -abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { +abstract class LeafNode extends LogicalPlan { self: Product => + + override def children: Seq[LogicalPlan] = Nil } /** * A logical plan node with single child. */ -abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { +abstract class UnaryNode extends LogicalPlan { self: Product => + + def child: LogicalPlan + + override def children: Seq[LogicalPlan] = child :: Nil } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 16844b2f4b680fea1c7732591587c887422931a6..0f95ca688a7a8939d919228baadad9f5573209eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -452,19 +452,3 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { s"$nodeName(${args.mkString(",")})" } } - - -/** - * A [[TreeNode]] with no children. - */ -trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children: Seq[BaseType] = Nil -} - -/** - * A [[TreeNode]] with a single [[child]]. - */ -trait UnaryNode[BaseType <: TreeNode[BaseType]] { - def child: BaseType - def children: Seq[BaseType] = child :: Nil -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 9dc7879fa4a1ae604d66d6edbb605e1eee947106..632f633d82a2eb6032a89627e1adea44612686de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import scala.collection.mutable.ArrayBuffer - object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() } @@ -238,12 +238,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { +private[sql] trait LeafNode extends SparkPlan { self: Product => + + override def children: Seq[SparkPlan] = Nil } -private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { +private[sql] trait UnaryNode extends SparkPlan { self: Product => + + def child: SparkPlan + + override def children: Seq[SparkPlan] = child :: Nil + override def outputPartitioning: Partitioning = child.outputPartitioning }