Skip to content
Snippets Groups Projects
Commit 7cad03ef authored by Ryan Ziegler's avatar Ryan Ziegler
Browse files

reassociation go brr

parent d462a7e9
No related branches found
No related tags found
No related merge requests found
Pipeline #202567 failed
......@@ -1494,6 +1494,17 @@ impl Node {
}
}
pub fn try_get_binary_op(&self) -> Option<BinaryOperator> {
match self {
Node::Binary {
left: _,
right: _,
op,
} => Some(*op),
_ => None,
}
}
pub fn try_ternary(&self, bop: TernaryOperator) -> Option<(NodeID, NodeID, NodeID)> {
if let Node::Ternary {
first,
......
......@@ -3,35 +3,164 @@ use hercules_ir::ir::*;
use crate::*;
fn last<T: Clone>(c: &Vec<T>) -> Option<T> {
if c.len() > 0 {
c.get(c.len() - 1).cloned()
} else {
None
}
}
/*
* Top level function to run dead code elimination.
* Returns the longest addition or multiplication chain starting at the given
* start node.
*/
pub fn reassociate(editor: &mut FunctionEditor) {
panic!("Reassociate was called");
// Create worklist (starts as all nodes).
let mut worklist: Vec<NodeID> = editor.node_ids().collect();
while let Some(work) = worklist.pop() {
// If a node on the worklist is a start node, it is either *the* start
// node (which we shouldn't delete), or is a gravestone for an already
// deleted node earlier in the worklist. If a node is a return node, it
// shouldn't be removed.
if editor.func().nodes[work.idx()].is_start() || editor.func().nodes[work.idx()].is_return()
{
continue;
}
fn longest_chain_from_node(edit: &mut FunctionEdit, start: NodeID) -> Vec<NodeID> {
let supported_operators = vec![BinaryOperator::Add, BinaryOperator::Mul];
let operator_chains = supported_operators
.iter()
.map(|chain_operator| {
let chain_operator = *chain_operator;
let mut chain = vec![];
let mut current = start;
// If a node on the worklist has 0 users, delete it. Add its uses onto
// the worklist.
if editor.get_users(work).len() == 0 {
let uses = get_uses(&editor.func().nodes[work.idx()]);
let success = editor.edit(|edit| edit.delete_node(work));
if success {
// If the edit was performed, then its uses may now be dead.
for u in uses.as_ref() {
worklist.push(*u);
while let Some((lhs, rhs)) = edit.get_node(current).try_binary(chain_operator) // must be matching operator
&& (last(&chain)
.map(|prev| {
// either LHS or RHS must be from chain, and the other side
// cannot be the same operator. This ensures we do not
// try to convert reduction trees into chains.
(prev == lhs && !edit.get_node(rhs).try_binary(chain_operator).is_some())
^ (prev == rhs && !edit.get_node(lhs).try_binary(chain_operator).is_some())
})
.unwrap_or_else(|| true))
{
chain.push(current);
let uses = edit.get_users(current).collect::<Vec<_>>();
if uses.len() == 1 {
// A node in a chain can only be used by one thing,
// otherwise it would break the "chain" structure
current = uses.get(0).cloned().unwrap()
} else {
break;
}
}
}
chain
})
.collect::<Vec<_>>();
return operator_chains
.iter()
.max_by_key(|v| v.len())
.cloned()
.unwrap();
}
/*
* Given a chain of operations which is assumed to be associative,
* this will attempt to convert the chain into a reduction tree.
* Returns the NodeID who
*/
fn build_reduction_tree(
edit: &mut FunctionEdit,
inputs: Vec<NodeID>,
op: BinaryOperator,
) -> NodeID {
if inputs.len() == 2 {
edit.add_node(Node::Binary {
left: inputs[0],
right: inputs[1],
op,
})
} else {
let left_inputs = inputs[0..inputs.len() / 2].to_vec();
let right_inputs = inputs[inputs.len() / 2..].to_vec();
let left = build_reduction_tree(edit, left_inputs, op);
let right = build_reduction_tree(edit, right_inputs, op);
edit.add_node(Node::Binary { left, right, op })
}
}
fn reassociate_chain<'a, 'b>(
mut edit: FunctionEdit<'a, 'b>,
chain: &Vec<NodeID>,
) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>> {
// TODO: handle chains that are not length power of two
// to do this we pick largest n such that 2^n <= len(chain)
// and only consider the first 2^n elements of chain
let chain = chain.clone();
let op = edit
.get_node(
chain
.get(0)
.cloned()
.expect("Empty chain found during reassociating!"),
)
.try_get_binary_op()
.expect("Chain begins with a non-binary node");
// 1. Get the inputs to the chain. Remember that a chain is formed
// by operating on several values together in series, so each node
// has two inputs, one being the previous intermediate result,
// and one the new value we want to include in our computation.
let inputs = chain
.iter()
.map(|node_id| match get_uses(edit.get_node(*node_id)) {
NodeUses::Two([lhs, rhs]) => {
if edit.get_node(lhs).try_binary(op).is_none() {
lhs
} else if edit.get_node(rhs).try_binary(op).is_none() {
rhs
} else {
panic!("Node in chain has no valid inputs for reassociation");
}
}
_ => panic!(
"Unexpected node found in chain: {:?}",
edit.get_node(*node_id)
),
})
.collect::<Vec<_>>();
// 2. Add the reduction tree over the inputs. At this point, we expect
// to have both the chain and the tree in our IR still.
let new_output = build_reduction_tree(&mut edit, inputs, op);
// 3. Use the reduction tree output now instead.
let old_output = last(&chain).expect("Empty chain");
edit = edit.replace_all_uses(old_output, new_output)?;
// 4. Delete the now unnecessary nodes in the chain
for node_id in chain.iter() {
edit = edit.delete_node(*node_id)?;
}
Ok(edit)
}
/*
* Top level function that reassociates chains of binops
* into reduction trees.
*/
pub fn reassociate(editor: &mut FunctionEditor) {
// A chain of length n begins with some node 0 such that
// for nodes 1...n node i only uses node i-1 and (any node that is not 0..i-1).
// A chain is maximal if it cannot be extended.
// We want to collapse each maximal chain into a reduction tree.
let num_nodes = editor.func().nodes.len();
editor.edit(|mut edit| {
let longest_chain = (0..num_nodes)
.map(|idx| longest_chain_from_node(&mut edit, NodeID::new(idx)))
.max_by_key(|chain| chain.len())
.expect("Function has no nodes!");
if longest_chain.len() >= 3 {
reassociate_chain(edit, &longest_chain)
} else {
Ok(edit)
}
});
}
......@@ -36,6 +36,7 @@ xdot[true](*);
// reassociate go brr
reassociate(*);
xdot[true](*);
gvn(*);
phi-elim(*);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment