From 55dfd5dcdbf3a9bfddb2108c8325bda3100eb33d Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Mon, 7 Apr 2014 18:39:18 -0700
Subject: [PATCH] Removed the default eval implementation from Expression, and
 added a bunch of override's in classes I touched.

It is more robust to not provide a default implementation for Expression's.

Author: Reynold Xin <rxin@apache.org>

Closes #350 from rxin/eval-default and squashes the following commits:

0a83b8f [Reynold Xin] Removed the default eval implementation from Expression, and added a bunch of override's in classes I touched.
---
 .../sql/catalyst/analysis/unresolved.scala    | 52 ++++++++++++-------
 .../sql/catalyst/expressions/Expression.scala |  3 +-
 .../sql/catalyst/expressions/SortOrder.scala  | 11 +++-
 .../sql/catalyst/expressions/aggregates.scala |  8 +++
 .../expressions/namedExpressions.scala        | 21 +++++---
 .../plans/physical/partitioning.scala         | 32 ++++++++----
 .../ExpressionEvaluationSuite.scala           |  5 +-
 .../optimizer/ConstantFoldingSuite.scala      |  2 +-
 8 files changed, 89 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 41e9bcef3c..d629172a74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.{errors, trees}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
 import org.apache.spark.sql.catalyst.trees.TreeNode
 
@@ -36,7 +37,7 @@ case class UnresolvedRelation(
     databaseName: Option[String],
     tableName: String,
     alias: Option[String] = None) extends BaseRelation {
-  def output = Nil
+  override def output = Nil
   override lazy val resolved = false
 }
 
@@ -44,26 +45,33 @@ case class UnresolvedRelation(
  * Holds the name of an attribute that has yet to be resolved.
  */
 case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
-  def exprId = throw new UnresolvedException(this, "exprId")
-  def dataType = throw new UnresolvedException(this, "dataType")
-  def nullable = throw new UnresolvedException(this, "nullable")
-  def qualifiers = throw new UnresolvedException(this, "qualifiers")
+  override def exprId = throw new UnresolvedException(this, "exprId")
+  override def dataType = throw new UnresolvedException(this, "dataType")
+  override def nullable = throw new UnresolvedException(this, "nullable")
+  override def qualifiers = throw new UnresolvedException(this, "qualifiers")
   override lazy val resolved = false
 
-  def newInstance = this
-  def withQualifiers(newQualifiers: Seq[String]) = this
+  override def newInstance = this
+  override def withQualifiers(newQualifiers: Seq[String]) = this
+
+  // Unresolved attributes are transient at compile time and don't get evaluated during execution.
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 
   override def toString: String = s"'$name"
 }
 
 case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
-  def exprId = throw new UnresolvedException(this, "exprId")
-  def dataType = throw new UnresolvedException(this, "dataType")
+  override def dataType = throw new UnresolvedException(this, "dataType")
   override def foldable = throw new UnresolvedException(this, "foldable")
-  def nullable = throw new UnresolvedException(this, "nullable")
-  def qualifiers = throw new UnresolvedException(this, "qualifiers")
-  def references = children.flatMap(_.references).toSet
+  override def nullable = throw new UnresolvedException(this, "nullable")
+  override def references = children.flatMap(_.references).toSet
   override lazy val resolved = false
+
+  // Unresolved functions are transient at compile time and don't get evaluated during execution.
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
   override def toString = s"'$name(${children.mkString(",")})"
 }
 
@@ -79,15 +87,15 @@ case class Star(
     mapFunction: Attribute => Expression = identity[Attribute])
   extends Attribute with trees.LeafNode[Expression] {
 
-  def name = throw new UnresolvedException(this, "exprId")
-  def exprId = throw new UnresolvedException(this, "exprId")
-  def dataType = throw new UnresolvedException(this, "dataType")
-  def nullable = throw new UnresolvedException(this, "nullable")
-  def qualifiers = throw new UnresolvedException(this, "qualifiers")
+  override def name = throw new UnresolvedException(this, "exprId")
+  override def exprId = throw new UnresolvedException(this, "exprId")
+  override def dataType = throw new UnresolvedException(this, "dataType")
+  override def nullable = throw new UnresolvedException(this, "nullable")
+  override def qualifiers = throw new UnresolvedException(this, "qualifiers")
   override lazy val resolved = false
 
-  def newInstance = this
-  def withQualifiers(newQualifiers: Seq[String]) = this
+  override def newInstance = this
+  override def withQualifiers(newQualifiers: Seq[String]) = this
 
   def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
     val expandedAttributes: Seq[Attribute] = table match {
@@ -104,5 +112,9 @@ case class Star(
     mappedAttributes
   }
 
+  // Star gets expanded at runtime so we never evaluate a Star.
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
   override def toString = table.map(_ + ".").getOrElse("") + "*"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index f190bd0cca..8a1db8e796 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -50,8 +50,7 @@ abstract class Expression extends TreeNode[Expression] {
   def references: Set[Attribute]
 
   /** Returns the result of evaluating this expression on a given input Row */
-  def eval(input: Row = null): EvaluatedType =
-    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+  def eval(input: Row = null): EvaluatedType
 
   /**
    * Returns `true` if this expression and all its children have been resolved to a specific schema
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d5d93778f4..08b2f11d20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+
 abstract sealed class SortDirection
 case object Ascending extends SortDirection
 case object Descending extends SortDirection
@@ -26,7 +28,12 @@ case object Descending extends SortDirection
  * transformations over expression will descend into its child.
  */
 case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
-  def dataType = child.dataType
-  def nullable = child.nullable
+  override def dataType = child.dataType
+  override def nullable = child.nullable
+
+  // SortOrder itself is never evaluated.
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
   override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5edcea1427..b152f95f96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 
 abstract class AggregateExpression extends Expression {
   self: Product =>
@@ -28,6 +29,13 @@ abstract class AggregateExpression extends Expression {
    * of input rows/
    */
   def newInstance(): AggregateFunction
+
+  /**
+   * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are
+   * replaced with a physical aggregate operator at runtime.
+   */
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index eb4bc8e755..a8145c37c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.types._
 
 object NamedExpression {
@@ -58,9 +59,9 @@ abstract class Attribute extends NamedExpression {
 
   def withQualifiers(newQualifiers: Seq[String]): Attribute
 
-  def references = Set(this)
   def toAttribute = this
   def newInstance: Attribute
+  override def references = Set(this)
 }
 
 /**
@@ -77,15 +78,15 @@ case class Alias(child: Expression, name: String)
     (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
   extends NamedExpression with trees.UnaryNode[Expression] {
 
-  type EvaluatedType = Any
+  override type EvaluatedType = Any
 
   override def eval(input: Row) = child.eval(input)
 
-  def dataType = child.dataType
-  def nullable = child.nullable
-  def references = child.references
+  override def dataType = child.dataType
+  override def nullable = child.nullable
+  override def references = child.references
 
-  def toAttribute = {
+  override def toAttribute = {
     if (resolved) {
       AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
     } else {
@@ -127,7 +128,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     h
   }
 
-  def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+  override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
 
   /**
    * Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -143,7 +144,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
   /**
    * Returns a copy of this [[AttributeReference]] with new qualifiers.
    */
-  def withQualifiers(newQualifiers: Seq[String]) = {
+  override def withQualifiers(newQualifiers: Seq[String]) = {
     if (newQualifiers == qualifiers) {
       this
     } else {
@@ -151,5 +152,9 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
     }
   }
 
+  // Unresolved attributes are transient at compile time and don't get evaluated during execution.
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 8893744eb2..ffb3a92f8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder}
 import org.apache.spark.sql.catalyst.types.IntegerType
 
 /**
@@ -139,12 +140,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   extends Expression
   with Partitioning {
 
-  def children = expressions
-  def references = expressions.flatMap(_.references).toSet
-  def nullable = false
-  def dataType = IntegerType
+  override def children = expressions
+  override def references = expressions.flatMap(_.references).toSet
+  override def nullable = false
+  override def dataType = IntegerType
 
-  lazy val clusteringSet = expressions.toSet
+  private[this] lazy val clusteringSet = expressions.toSet
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
@@ -158,6 +159,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case h: HashPartitioning if h == this => true
     case _ => false
   }
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
 /**
@@ -168,17 +172,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  *    partition.
  *  - Each partition will have a `min` and `max` row, relative to the given ordering.  All rows
  *    that are in between `min` and `max` in this `ordering` will reside in this partition.
+ *
+ * This class extends expression primarily so that transformations over expression will descend
+ * into its child.
  */
 case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
   extends Expression
   with Partitioning {
 
-  def children = ordering
-  def references = ordering.flatMap(_.references).toSet
-  def nullable = false
-  def dataType = IntegerType
+  override def children = ordering
+  override def references = ordering.flatMap(_.references).toSet
+  override def nullable = false
+  override def dataType = IntegerType
 
-  lazy val clusteringSet = ordering.map(_.child).toSet
+  private[this] lazy val clusteringSet = ordering.map(_.child).toSet
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
@@ -195,4 +202,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case r: RangePartitioning if r == this => true
     case _ => false
   }
+
+  override def eval(input: Row): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 92987405aa..31be6c4ef1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -100,7 +100,10 @@ class ExpressionEvaluationSuite extends FunSuite {
     (null,  false, null) ::
     (null,  null,  null) :: Nil)
 
-  def booleanLogicTest(name: String, op: (Expression, Expression) => Expression,  truthTable: Seq[(Any, Any, Any)]) {
+  def booleanLogicTest(
+      name: String,
+      op: (Expression, Expression) => Expression,
+      truthTable: Seq[(Any, Any, Any)]) {
     test(s"3VL $name") {
       truthTable.foreach {
         case (l,r,answer) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 2ab14f48cc..20dfba8477 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.types.IntegerType
+import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}
 
 // For implicit conversions
 import org.apache.spark.sql.catalyst.dsl.plans._
-- 
GitLab