From d96608674f6c2ff3abb13c65d80c1a3872206710 Mon Sep 17 00:00:00 2001
From: scwf <wangfei1@huawei.com>
Date: Thu, 16 Apr 2015 17:35:51 -0700
Subject: [PATCH] [SQL][Minor] Fix foreachUp of treenode

`foreachUp` should runs the given function recursively on [[children]] then on this node(just like transformUp). The current implementation does not follow this.

This will leads to checkanalysis do not check from bottom of logical tree.

Author: scwf <wangfei1@huawei.com>
Author: Fei Wang <wangfei1@huawei.com>

Closes #5518 from scwf/patch-1 and squashes the following commits:

18e28b2 [scwf] added a test case
1ccbfa8 [Fei Wang] fix foreachUp
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala   |  2 +-
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala     | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a2df51e598..97502ed3af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -85,7 +85,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    * @param f the function to be applied to each node in the tree.
    */
   def foreachUp(f: BaseType => Unit): Unit = {
-    children.foreach(_.foreach(f))
+    children.foreach(_.foreachUp(f))
     f(this)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4eb8708335..6b393327cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -117,5 +117,17 @@ class TreeNodeSuite extends FunSuite {
     assert(transformed.origin.startPosition.isDefined)
   }
 
+  test("foreach up") {
+    val actual = new ArrayBuffer[String]()
+    val expected = Seq("1", "2", "3", "4", "-", "*", "+")
+    val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+    expression foreachUp {
+      case b: BinaryExpression => actual.append(b.symbol);
+      case l: Literal => actual.append(l.toString);
+    }
+
+    assert(expected === actual)
+  }
+
 
 }
-- 
GitLab