diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index f3cd095d676035deca57390b0284a57902345d31..7ec75fb9d4de0f91ac457ba41d67777c4a53a068 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1494,6 +1494,17 @@ impl Node { } } + pub fn try_get_binary_op(&self) -> Option<BinaryOperator> { + match self { + Node::Binary { + left: _, + right: _, + op, + } => Some(*op), + _ => None, + } + } + pub fn try_ternary(&self, bop: TernaryOperator) -> Option<(NodeID, NodeID, NodeID)> { if let Node::Ternary { first, diff --git a/hercules_opt/src/reassociate.rs b/hercules_opt/src/reassociate.rs index f876c55a854d9d6ab1b183145e3f997f47be56d3..e667f6d59b4ee3e84a702a3771cb0c3949c864c2 100644 --- a/hercules_opt/src/reassociate.rs +++ b/hercules_opt/src/reassociate.rs @@ -3,35 +3,164 @@ use hercules_ir::ir::*; use crate::*; +fn last<T: Clone>(c: &Vec<T>) -> Option<T> { + if c.len() > 0 { + c.get(c.len() - 1).cloned() + } else { + None + } +} + /* - * Top level function to run dead code elimination. + * Returns the longest addition or multiplication chain starting at the given + * start node. */ -pub fn reassociate(editor: &mut FunctionEditor) { - panic!("Reassociate was called"); - // Create worklist (starts as all nodes). - let mut worklist: Vec<NodeID> = editor.node_ids().collect(); - - while let Some(work) = worklist.pop() { - // If a node on the worklist is a start node, it is either *the* start - // node (which we shouldn't delete), or is a gravestone for an already - // deleted node earlier in the worklist. If a node is a return node, it - // shouldn't be removed. - if editor.func().nodes[work.idx()].is_start() || editor.func().nodes[work.idx()].is_return() - { - continue; - } +fn longest_chain_from_node(edit: &mut FunctionEdit, start: NodeID) -> Vec<NodeID> { + let supported_operators = vec![BinaryOperator::Add, BinaryOperator::Mul]; + let operator_chains = supported_operators + .iter() + .map(|chain_operator| { + let chain_operator = *chain_operator; + let mut chain = vec![]; + let mut current = start; - // If a node on the worklist has 0 users, delete it. Add its uses onto - // the worklist. - if editor.get_users(work).len() == 0 { - let uses = get_uses(&editor.func().nodes[work.idx()]); - let success = editor.edit(|edit| edit.delete_node(work)); - if success { - // If the edit was performed, then its uses may now be dead. - for u in uses.as_ref() { - worklist.push(*u); + while let Some((lhs, rhs)) = edit.get_node(current).try_binary(chain_operator) // must be matching operator + && (last(&chain) + .map(|prev| { + // either LHS or RHS must be from chain, and the other side + // cannot be the same operator. This ensures we do not + // try to convert reduction trees into chains. + (prev == lhs && !edit.get_node(rhs).try_binary(chain_operator).is_some()) + ^ (prev == rhs && !edit.get_node(lhs).try_binary(chain_operator).is_some()) + }) + .unwrap_or_else(|| true)) + { + chain.push(current); + let uses = edit.get_users(current).collect::<Vec<_>>(); + if uses.len() == 1 { + // A node in a chain can only be used by one thing, + // otherwise it would break the "chain" structure + current = uses.get(0).cloned().unwrap() + } else { + break; } } - } + + chain + }) + .collect::<Vec<_>>(); + return operator_chains + .iter() + .max_by_key(|v| v.len()) + .cloned() + .unwrap(); +} + +/* + * Given a chain of operations which is assumed to be associative, + * this will attempt to convert the chain into a reduction tree. + * Returns the NodeID who + */ +fn build_reduction_tree( + edit: &mut FunctionEdit, + inputs: Vec<NodeID>, + op: BinaryOperator, +) -> NodeID { + if inputs.len() == 2 { + edit.add_node(Node::Binary { + left: inputs[0], + right: inputs[1], + op, + }) + } else { + let left_inputs = inputs[0..inputs.len() / 2].to_vec(); + let right_inputs = inputs[inputs.len() / 2..].to_vec(); + let left = build_reduction_tree(edit, left_inputs, op); + let right = build_reduction_tree(edit, right_inputs, op); + edit.add_node(Node::Binary { left, right, op }) } } + +fn reassociate_chain<'a, 'b>( + mut edit: FunctionEdit<'a, 'b>, + chain: &Vec<NodeID>, +) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>> { + // TODO: handle chains that are not length power of two + // to do this we pick largest n such that 2^n <= len(chain) + // and only consider the first 2^n elements of chain + let chain = chain.clone(); + + let op = edit + .get_node( + chain + .get(0) + .cloned() + .expect("Empty chain found during reassociating!"), + ) + .try_get_binary_op() + .expect("Chain begins with a non-binary node"); + + // 1. Get the inputs to the chain. Remember that a chain is formed + // by operating on several values together in series, so each node + // has two inputs, one being the previous intermediate result, + // and one the new value we want to include in our computation. + let inputs = chain + .iter() + .map(|node_id| match get_uses(edit.get_node(*node_id)) { + NodeUses::Two([lhs, rhs]) => { + if edit.get_node(lhs).try_binary(op).is_none() { + lhs + } else if edit.get_node(rhs).try_binary(op).is_none() { + rhs + } else { + panic!("Node in chain has no valid inputs for reassociation"); + } + } + _ => panic!( + "Unexpected node found in chain: {:?}", + edit.get_node(*node_id) + ), + }) + .collect::<Vec<_>>(); + + // 2. Add the reduction tree over the inputs. At this point, we expect + // to have both the chain and the tree in our IR still. + let new_output = build_reduction_tree(&mut edit, inputs, op); + + // 3. Use the reduction tree output now instead. + let old_output = last(&chain).expect("Empty chain"); + edit = edit.replace_all_uses(old_output, new_output)?; + + // 4. Delete the now unnecessary nodes in the chain + for node_id in chain.iter() { + edit = edit.delete_node(*node_id)?; + } + + Ok(edit) +} + +/* + * Top level function that reassociates chains of binops + * into reduction trees. + */ +pub fn reassociate(editor: &mut FunctionEditor) { + // A chain of length n begins with some node 0 such that + // for nodes 1...n node i only uses node i-1 and (any node that is not 0..i-1). + // A chain is maximal if it cannot be extended. + // We want to collapse each maximal chain into a reduction tree. + + let num_nodes = editor.func().nodes.len(); + + editor.edit(|mut edit| { + let longest_chain = (0..num_nodes) + .map(|idx| longest_chain_from_node(&mut edit, NodeID::new(idx))) + .max_by_key(|chain| chain.len()) + .expect("Function has no nodes!"); + + if longest_chain.len() >= 3 { + reassociate_chain(edit, &longest_chain) + } else { + Ok(edit) + } + }); +} diff --git a/juno_samples/grape_reduction/src/grape.sch b/juno_samples/grape_reduction/src/grape.sch index 3be751171c40343a0992bb4b7d01bf700261123d..ed51380ae0c075cd97d92fc98d71a878a76464d0 100644 --- a/juno_samples/grape_reduction/src/grape.sch +++ b/juno_samples/grape_reduction/src/grape.sch @@ -36,6 +36,7 @@ xdot[true](*); // reassociate go brr reassociate(*); +xdot[true](*); gvn(*); phi-elim(*);