Skip to content
Snippets Groups Projects
Commit 50544f5e authored by Xavier Routh's avatar Xavier Routh
Browse files

2nd version of backend

parent ce2af3ca
No related branches found
No related tags found
No related merge requests found
Pipeline #202337 failed
......@@ -153,7 +153,7 @@ pub fn grape_codegen<Writer: std::fmt::Write>(
backing_allocation: &FunctionBackingAllocation,
w: &mut Writer,
) -> Result<(), Error> {
let ctx = GRAPEContext::<24, 24, 1> {
let ctx = GRAPEContext::<20, 20, 1> {
module_name,
function,
types,
......@@ -214,7 +214,7 @@ where
// Schedule the function,
// while any node is not, and used slices is less than 4 < schedule live_outs.
let chip = self.schedule_slice(live_outs)?;
let chip = self.schedule_slice_toplevel()?;
// chip.0.pretty_print();
// println!("chip: {:?}", chip);
......@@ -286,6 +286,163 @@ where
.unwrap()
}
fn schedule_slice_toplevel(&self) -> Result<(SliceDesc<H, W>, HashMap<NodeID, usize>), Error> {
let mut functional_units = [[FunctionalUnit {
op_type: FuOp::Default,
}; W]; H];
let mut switchboxes = [[Switchbox {
output_wires: (100, 100),
}; W]; H - 1];
let config = SliceDesc { functional_units, switchboxes};
let mut prev_mapping = HashMap::new();
// Push paramter nodes!
for (col, node) in self.function.nodes.iter().enumerate().filter_map(|(node_id, node)|if node.is_parameter() { Some(NodeID::new(node_id))}else {None}).enumerate() {
prev_mapping.insert(node, col);
}
return self.schedule_row_recursive(0, prev_mapping, config);
}
fn schedule_row_recursive(
&self,
row: usize,
prev_mapping: HashMap<NodeID, usize>, // A node mapping for the previous row
mut config: SliceDesc<H, W>,
) -> Result<(SliceDesc<H, W>, HashMap<NodeID, usize>), Error> {
println!("entry @ row {:?}", row);
println!("prev_mapping: {:?}", prev_mapping);
if row == H {
// all uses of any return ndoes are in the prev mapping.
println!("final row");
let returns = self.function.nodes.iter().enumerate().filter_map(|(idx, node)| if node.is_return() { Some(NodeID::new(idx))} else { None});
let uses: Vec<NodeID> = returns.flat_map(|return_node| get_uses(&self.function.nodes[return_node.idx()]).as_ref().iter()
.filter(|node| !self.function.nodes[node.idx()].is_control())
.cloned().collect::<Vec<_>>()).collect();
println!("return_uses: {:?}", uses.clone());
if uses.iter().any(|node| !prev_mapping.contains_key(node)) {
return Err(Error)
} else {
return Ok((config, prev_mapping));
}
}
let constants: Vec<NodeID> = self.function.nodes.iter().enumerate().filter_map(|(id, data)| {
if data.is_constant() {
Some(NodeID::new(id))
} else { None }}
).collect();
// PRUNE constants that can't possibly be used
// PRUNE nodes in prev that don't have any users
let live_nodes: Vec<NodeID> = prev_mapping.iter().filter_map(|(id, _)| if self.def_use_map.get_users(*id).iter().count() > 0 { Some(*id)
} else {None}
)
.collect();
//
let users = prev_mapping.iter().flat_map(|(id, _)| self.def_use_map.get_users(*id).iter().cloned());
// Correctness:
// PRUNE nodes that don't have all their uses computed.
let users: Vec<NodeID> = users.filter(|id| get_uses(&self.function.nodes[id.idx()]).as_ref().iter()
.filter(|node| !self.function.nodes[node.idx()].is_control()) // We only care abt data nodes.
.all(|u| prev_mapping.contains_key(u))).collect();
// Heurestic (Speed): PRIORITIZE nodes that don't have all their users scheduled.
// we should first choose the set of computations of we want to do, and then the set of values
// that we want to keep live
// Q: Why not FORCE?, A: we can rematerialize, and sometimes that is correct in the optimal mapping
// We can also do this by sorting live_nodes by number of users (that aren't scheduled.)
// Do we have to keep track of everything we scheduled then? Maybe.
// Put users first, to prioritize new computation.
let mut choices: Vec<NodeID> = users.into_iter().chain(constants.into_iter()).chain(live_nodes.into_iter())
.filter(|choice| !self.function.nodes[choice.idx()].is_return())
.collect();
fn dedup<T: Eq + std::hash::Hash + Clone>(vec: &mut Vec<T>) {
let mut seen = HashSet::new();
vec.retain(|item| seen.insert(item.clone()));
}
dedup(&mut choices);
// Remove duplicates without changing order.
println!("num choices: {:?}", choices.clone().len());
println!("choices: {:?}", choices.clone());
// Order by node id, and start iterating permutations (NODES choose 8) for the next row.
for i in [0] {
let mut next_mapping = HashMap::new();
// Heurestic (Correctness): Just assume the first 8 are good choices.
for (col, node) in choices.iter().take(W).enumerate() {
// If the node was in the previous mapping and we chose to live, schedule it as a pass through
if prev_mapping.contains_key(node) {
next_mapping.insert(*node, col);
// Collect inputs
if row != 0 {
config.switchboxes[row-1][col].output_wires = (prev_mapping[node], 0);
}
config.functional_units[row][col].op_type = FuOp::PassA;
} else if self.function.nodes[node.idx()].is_constant() {
panic!()
} else {
// compute it
next_mapping.insert(*node, col);
// Collect inputs
let inputs: Vec<usize> = get_uses(&self.function.nodes[node.idx()]).as_ref().iter().map(|node| prev_mapping[node]).collect();
if row != 0 {
config.switchboxes[row-1][col].output_wires = (*(inputs.get(0).unwrap_or(&0)), *(inputs.get(1).unwrap_or(&0)));
}
let op_type = match self.function.nodes[node.idx()] {
Node::Binary { left, right, op } => match op {
BinaryOperator::Add => FuOp::Add,
BinaryOperator::Sub => todo!(),
BinaryOperator::Mul => FuOp::Mult,
BinaryOperator::Div => todo!(),
BinaryOperator::Rem => todo!(),
BinaryOperator::LT => todo!(),
BinaryOperator::LTE => todo!(),
BinaryOperator::GT => todo!(),
BinaryOperator::GTE => todo!(),
BinaryOperator::EQ => todo!(),
BinaryOperator::NE => todo!(),
BinaryOperator::Or => todo!(),
BinaryOperator::And => todo!(),
BinaryOperator::Xor => todo!(),
BinaryOperator::LSh => todo!(),
BinaryOperator::RSh => todo!(),
}
_ => todo!()
};
config.functional_units[row][col].op_type = op_type; // FIXME:
}
}
let schedule_attempt = self.schedule_row_recursive(row + 1, next_mapping, config);
if schedule_attempt.is_ok() {
return schedule_attempt;
}
}
return Err(Error)
}
fn schedule_slice(
&self,
mut live_outs: HashMap<NodeID, usize>,
......@@ -305,55 +462,52 @@ where
for row in (0..H).rev() {
println!("row: {:?}", row);
next_live_outs.clear();
let mut blarghs: HashMap<NodeID, usize> = HashMap::new();
// Try to compute live not generated
for u in &live_not_computed.clone() {
println!("in live not computed node {:?}", u);
let mut users = self.def_use_map.get_users(*u).iter().filter(|node| !self.function.nodes[node.idx()].is_control());
let u_node_data = &self.function.nodes[u.idx()];
let mut col = 0;
let mut computed_nodes: HashSet<NodeID> = HashSet::new();
// Schedule not computed nodes.
for not_computed_node in &live_not_computed {
println!("in live not computed node {:?}", not_computed_node);
let mut users = self.def_use_map.get_users(*not_computed_node).iter().filter(|node| !self.function.nodes[node.idx()].is_control());
let col;
if !users.all(|user| live_outs.contains_key(user)) {
// Can't schedule this yet, schedule a passthrough instead
col = Self::get_free_col(&next_live_outs);
next_live_outs.insert(*u, col);
blarghs.insert(*u, col);
live_not_computed.insert(*u);
next_live_outs.insert(*not_computed_node, col);
// Schedule Pass Through
functional_units[row][col] = FunctionalUnit {
op_type: FuOp::PassA,
};
println!("schedule node {:?} pass through @ {:?}", u, col);
println!("schedule node {:?} pass through @ {:?}", not_computed_node, col);
} else {
// Schedule Computation
col = Self::get_free_col(&next_live_outs);
next_live_outs.insert(*u, col);
blarghs.insert(*u, col);
next_live_outs.insert(*not_computed_node, col);
computed_nodes.insert(not_computed_node.clone());
functional_units[row][col] = FunctionalUnit { op_type: FuOp::Add };
println!("schedule node {:?} computation @ {:?}", u, col);
live_not_computed.remove(&u);
println!("schedule node {:?} computation @ {:?}", not_computed_node, col);
// live_not_computed.remove(&not_computed_node);
}
let live_out_col = live_outs[&not_computed_node];
let live_out_col = live_outs[u];
// Propogate to itself
switchboxes[row][live_out_col] = Switchbox {
output_wires: (col, col),
};
continue;
}
for (live_out, live_out_col) in &live_outs {
// If this node is dead, don't schedule
if (live_not_computed.contains(live_out)) {
if live_not_computed.contains(live_out) {
continue;
}
......@@ -379,14 +533,18 @@ where
let mut ctr = 0;
println!("uses: {:?}", uses);
// Try to schedule things live_outs uses.
for u in uses {
let users = self.def_use_map.get_users(u);
let u_node_data = &self.function.nodes[u.idx()];
if live_outs.contains_key(&u) && !blarghs.contains_key(&u){
if next_live_outs.contains_key(&u) {
// reuse
if row != (H - 1) {
idx_vec[ctr] = next_live_outs[&u];
}
println!("reuse value {:?}", u);
} else if live_outs.contains_key(&u) {
// Keep it live.
let col = Self::get_free_col(&next_live_outs);
......@@ -401,13 +559,7 @@ where
if row != (H - 1) {
idx_vec[ctr] = col;
}
} else if next_live_outs.contains_key(&u) {
// reuse
if row != (H - 1) {
idx_vec[ctr] = next_live_outs[&u];
}
println!("reuse value {:?}", u);
} else if (u_node_data.is_constant()
} else if (u_node_data.is_constant()
|| u_node_data.is_dynamic_constant()
|| u_node_data.is_parameter())
{
......@@ -471,6 +623,11 @@ where
}
live_outs = next_live_outs.clone();
for node in computed_nodes {
live_not_computed.remove(&node);
}
}
Ok((
......@@ -503,16 +660,16 @@ where
writeln!(writer,"node [shape=square]");
for (row_num, row) in slice.functional_units.iter().enumerate() {
for (col_num, fu) in row.iter().enumerate() {
if fu.op_type != FuOp::Default {
//if fu.op_type != FuOp::Default {
writeln!(writer,"n_{}_{} [label=\"{:?}\"]", row_num, col_num, fu.op_type);
}
//}
}
}
for (row_num, row) in slice.switchboxes.iter().enumerate() {
for (col_num, sb) in row.iter().enumerate() {
if sb.output_wires.0 != 100 {
//if sb.output_wires.0 != 100 {
writeln!(writer,
"n_{}_{} -> n_{}_{}",
row_num,
......@@ -520,8 +677,8 @@ where
row_num + 1,
col_num
);
}
if sb.output_wires.1 != 100 {
//}
//if sb.output_wires.1 != 100 {
writeln!(writer,
"n_{}_{} -> n_{}_{}",
row_num,
......@@ -529,26 +686,26 @@ where
row_num + 1,
col_num
);
}
//}
}
}
writeln!(writer,"edge [weight=1000 style=invis]");
for (row_num, row) in slice.functional_units.iter().enumerate() {
for col_num in 0..row.len() {
print!("n_{}_{}", col_num, row_num);
write!(writer, "n_{}_{}", col_num, row_num);
if col_num + 1 < row.len() {
print!(" -> ");
write!(writer, " -> ");
}
}
writeln!(writer,"");
}
for (row_num, row) in slice.functional_units.iter().enumerate() {
print!("rank=same {{");
write!(writer, "rank=same {{");
for col_num in 0..row.len() {
print!("n_{}_{}", row_num, col_num);
write!(writer, "n_{}_{}", row_num, col_num);
if col_num + 1 < row.len() {
print!(" -> ");
write!(writer, " -> ");
}
}
writeln!(writer,"}}");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment