Skip to content
Snippets Groups Projects
Commit 0c1b2df0 authored by Michael Davies's avatar Michael Davies Committed by Michael Armbrust
Browse files

[SPARK-8077] [SQL] Optimization for TreeNodes with large numbers of children

For example, large IN clauses.

Large IN clauses are parsed very slowly. For example, the SQL below (10K items in the IN list) takes 45-50s to parse.

s"""SELECT * FROM Person WHERE ForeName IN ('${(1 to 10000).map("n" + _).mkString("','")}')"""

This is principally due to TreeNode, which repeatedly calls contains on children, where children in this case is a List that is 10K long. Each contains call is a linear scan of that List and is made once per argument, so in effect parsing large IN clauses is O(N^2).
Using a lazily initialised Set built from children for the contains checks reduces the parse time to around 2.5s.

Author: Michael Davies <Michael.BellDavies@gmail.com>

Closes #6673 from MickDavies/SPARK-8077 and squashes the following commits:

38cd425 [Michael Davies] SPARK-8077: Optimization for  TreeNodes with large numbers of children
d80103b [Michael Davies] SPARK-8077: Optimization for  TreeNodes with large numbers of children
e6be8be [Michael Davies] SPARK-8077: Optimization for  TreeNodes with large numbers of children
parent 50a0496a
No related branches found
No related tags found
No related merge requests found
......@@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
val input = children.flatMap(_.output)
productIterator.map {
// Children are checked using sameResult above.
case tn: TreeNode[_] if children contains tn => null
case tn: TreeNode[_] if containsChild(tn) => null
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
case s: Option[_] => s.map {
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
......
......@@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val origin: Origin = CurrentOrigin.get
/** Returns a Seq of the children of this node */
/**
* Returns a Seq of the children of this node.
* Children should not change. Immutability required for containsChild optimization
*/
def children: Seq[BaseType]
lazy val containsChild: Set[TreeNode[_]] = children.toSet
/**
* Faster version of equality which short-circuits when two treeNodes are the same instance.
* We don't just override Object.equals, as doing so prevents the scala compiler from
......@@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def mapChildren(f: BaseType => BaseType): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
if (newChild fastEquals arg) {
arg
......@@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
val newArgs = productIterator.map {
// Handle Seq[TreeNode] in TreeNode parameters.
case s: Seq[_] => s.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
......@@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case nonChild: AnyRef => nonChild
case null => null
}
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
......@@ -238,7 +243,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -246,7 +251,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} else {
arg
}
case Some(arg: TreeNode[_]) if children contains arg =>
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -257,7 +262,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -295,7 +300,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -303,7 +308,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} else {
arg
}
case Some(arg: TreeNode[_]) if children contains arg =>
case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -314,7 +319,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
if (!(newChild fastEquals arg)) {
changed = true
......@@ -383,7 +388,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
/** Returns a string representing the arguments to this node, minus any children */
def argString: String = productIterator.flatMap {
case tn: TreeNode[_] if children contains tn => Nil
case tn: TreeNode[_] if containsChild(tn) => Nil
case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil
case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil
case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment