diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a2df51e598a2b31fb972ba2d1edd7a527c86dc37..97502ed3afe72af3c829a1ba1741e9b894b834c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -85,7 +85,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    * @param f the function to be applied to each node in the tree.
    */
   def foreachUp(f: BaseType => Unit): Unit = {
-    children.foreach(_.foreach(f))
+    children.foreach(_.foreachUp(f))
     f(this)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4eb8708335dcf410920484827e0bdf22f445347d..6b393327cc97afc70588491b766225eb395c9b87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -117,5 +117,17 @@ class TreeNodeSuite extends FunSuite {
     assert(transformed.origin.startPosition.isDefined)
   }
 
+  test("foreach up") {
+    val actual = new ArrayBuffer[String]()
+    val expected = Seq("1", "2", "3", "4", "-", "*", "+")
+    val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
+    expression foreachUp {
+      case b: BinaryExpression => actual.append(b.symbol);
+      case l: Literal => actual.append(l.toString);
+    }
+
+    assert(expected === actual)
+  }
+
 
 }