diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index ecc7eecfeb20c96969c50ebe8ac6c436c5f47a1a..0cea209ff7d902cb6131e04bc9cef8c3f9cc1f7b 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1046,3 +1046,162 @@ define_id_type!(NodeID); define_id_type!(TypeID); define_id_type!(ConstantID); define_id_type!(DynamicConstantID); + +// Debug printing for modules + +use std::fmt::Display; +use std::fmt::Formatter; + +impl Display for Module { + fn fmt(&self, f : &mut Formatter<'_>) -> std::fmt::Result { + for func in self.functions.iter() { + func.ir_fmt(f, self)?; + write!(f, "\n")?; + } + Ok(()) + } +} + +trait IRDisplay { + fn ir_fmt(&self, f : &mut Formatter<'_>, module : &Module) -> std::fmt::Result; +} + +impl IRDisplay for Function { + fn ir_fmt(&self, f : &mut Formatter<'_>, module : &Module) -> std::fmt::Result { + write!(f, "fn {}<{}>(", self.name, self.num_dynamic_constants)?; + + for (idx, typ) in self.param_types.iter().enumerate() { + write!(f, "arg_{} : ", idx)?; + module.write_type(*typ, f)?; + if idx + 1 < self.param_types.len() { write!(f, ", ")?; } + } + + write!(f, ") -> ")?; + module.write_type(self.return_type, f)?; + + write!(f, "\n")?; + + for (idx, node) in self.nodes.iter().enumerate() { + write!(f, "\tvar_{} = ", idx)?; + node.ir_fmt(f, module)?; + write!(f, "\n")?; + } + + Ok(()) + } +} + +impl IRDisplay for Node { + fn ir_fmt(&self, f : &mut Formatter<'_>, module : &Module) -> std::fmt::Result { + match self { + Node::Start => { write!(f, "start") }, + Node::Region { preds } => { + write!(f, "region(")?; + for (idx, pred) in preds.iter().enumerate() { + write!(f, "var_{}", pred.0)?; + if idx + 1 < preds.len() { + write!(f, ", ")?; + } + } + write!(f, ")") + }, + Node::If { control, cond } => { + write!(f, "if(var_{}, var_{})", control.0, cond.0) + }, + Node::Match { control, sum } => { + write!(f, "match(var_{}, var_{})", control.0, sum.0) + }, + Node::Fork { control, factor } => { + write!(f, "fork(var_{}, ", control.0)?; + module.write_dynamic_constant(*factor, f)?; + write!(f, ")") + }, + Node::Join { control } => { + write!(f, "join(var_{})", control.0) + }, + Node::Phi { control, data } => { + write!(f, "phi(var_{}", control.0)?; + for val in data.iter() { + write!(f, ", var_{}", val.0)?; + } + write!(f, ")") + }, + Node::ThreadID { control } => { + write!(f, "thread_id(var_{})", control.0) + }, + Node::Reduce { control, init, reduct } => { + write!(f, "reduce(var_{}, var_{}, var_{})", control.0, init.0, reduct.0) + }, + Node::Return { control, data } => { + write!(f, "return(var_{}, var_{})", control.0, data.0) + }, + Node::Parameter { index } => { + write!(f, "arg_{}", index) + }, + Node::Constant { id } => { + write!(f, "constant(")?; + module.write_constant(*id, f)?; + write!(f, ")") + }, + Node::DynamicConstant { id } => { + write!(f, "dynamic_constant(")?; + module.write_dynamic_constant(*id, f)?; + write!(f, ")") + }, + Node::Unary { input, op } => { + write!(f, "{}(arg_{})", op.lower_case_name(), input.0) + }, + Node::Binary { left, right, op } => { + write!(f, "{}(arg_{}, arg_{})", op.lower_case_name(), left.0, right.0) + }, + Node::Call { function, dynamic_constants, args } => { + write!(f, "call<")?; + for (idx, dyn_const) in dynamic_constants.iter().enumerate() { + module.write_dynamic_constant(*dyn_const, f)?; + if idx + 1 < dynamic_constants.len() { + write!(f, ", ")?; + } + } + write!(f, ">({}", module.functions[function.0 as usize].name)?; + for arg in args.iter() { + write!(f, ", var_{}", arg.0)?; + } + write!(f, ")") + }, + Node::Read { collect, indices } => { + write!(f, "read(var_{}", collect.0)?; + for idx in indices.iter() { + write!(f, ", ")?; + idx.ir_fmt(f, module)?; + } + write!(f, ")") + }, + Node::Write { collect, data, indices } => { + write!(f, "write(var_{}, var_{}", collect.0, data.0)?; + for idx in indices.iter() { + write!(f, ", ")?; + idx.ir_fmt(f, module)?; + } + write!(f, ")") + }, + } + } +} + +impl IRDisplay for Index { + fn ir_fmt(&self, f : &mut Formatter<'_>, _module : &Module) -> std::fmt::Result { + match self { + Index::Field(idx) => write!(f, "field({})", idx), + Index::Variant(idx) => write!(f, "variant({})", idx), + Index::Control(idx) => write!(f, "control({})", idx), + Index::Position(indices) => { + write!(f, "position(")?; + for (i, idx) in indices.iter().enumerate() { + write!(f, "var_{}", idx.0)?; + if i + 1 < indices.len() { write!(f, ", ")?; } + } + write!(f, ")") + }, + } + } +}