Skip to content
Snippets Groups Projects
Commit 133a9c9c authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'ir_dev' into 'main'

Finish core IR definition, parsing, and write to .dot graph

See merge request !1
parents 35231789 96105694
No related branches found
No related tags found
1 merge request!1Finish core IR definition, parsing, and write to .dot graph
......@@ -32,6 +32,8 @@ The IR of the Hercules compiler is similar to the sea of nodes IR presented in "
A key design consideration of Hercules IR is the absence of a concept of memory. A downside of this approach is that any language targeting Hercules IR must also be very restrictive regarding memory - in practice, this means tightly controlling or eliminating first-class references. The upside is that the compiler has complete freedom to lay out data however it likes in memory when performing code generation. This includes deciding which data resides in which address spaces, which is a necessary ability for a compiler striving to have fine-grained control over what operations are computed on what devices.
In addition to not having a generalized memory, Hercules IR has no functionality for calling functions with side-effects, or doing IO. In other words, Hercules is a pure IR (it's not functional, as functions aren't first class values). This may be changed in the future - we could support effectful programs by giving call operators a control input and output edge. However, at least for now, we need to work with the simplest IR possible.
### Optimizations
TODO: @rarbore2
......
......@@ -6,26 +6,29 @@ pub fn write_dot<W: std::fmt::Write>(module: &Module, w: &mut W) -> std::fmt::Re
write!(w, "digraph \"Module\" {{\n")?;
write!(w, "compound=true\n")?;
for i in 0..module.functions.len() {
write_function(i, module, &module.constants, w)?;
write_function(i, module, w)?;
}
write!(w, "}}\n")?;
Ok(())
}
fn write_function<W: std::fmt::Write>(
i: usize,
module: &Module,
constants: &Vec<Constant>,
w: &mut W,
) -> std::fmt::Result {
fn write_function<W: std::fmt::Write>(i: usize, module: &Module, w: &mut W) -> std::fmt::Result {
write!(w, "subgraph {} {{\n", module.functions[i].name)?;
write!(w, "label=\"{}\"\n", module.functions[i].name)?;
if module.functions[i].num_dynamic_constants > 0 {
write!(
w,
"label=\"{}<{}>\"\n",
module.functions[i].name, module.functions[i].num_dynamic_constants
)?;
} else {
write!(w, "label=\"{}\"\n", module.functions[i].name)?;
}
write!(w, "bgcolor=ivory4\n")?;
write!(w, "cluster=true\n")?;
let mut visited = HashMap::default();
let function = &module.functions[i];
for j in 0..function.nodes.len() {
visited = write_node(i, j, module, constants, visited, w)?.1;
visited = write_node(i, j, module, visited, w)?.1;
}
write!(w, "}}\n")?;
Ok(())
......@@ -35,7 +38,6 @@ fn write_node<W: std::fmt::Write>(
i: usize,
j: usize,
module: &Module,
constants: &Vec<Constant>,
mut visited: HashMap<NodeID, String>,
w: &mut W,
) -> Result<(String, HashMap<NodeID, String>), std::fmt::Error> {
......@@ -51,32 +53,119 @@ fn write_node<W: std::fmt::Write>(
write!(w, "{} [label=\"start\"];\n", name)?;
visited
}
Node::Region { preds } => {
write!(w, "{} [label=\"region\"];\n", name)?;
for (idx, pred) in preds.iter().enumerate() {
let (pred_name, tmp_visited) = write_node(i, pred.idx(), module, visited, w)?;
visited = tmp_visited;
write!(
w,
"{} -> {} [label=\"pred {}\", style=\"dashed\"];\n",
pred_name, name, idx
)?;
}
visited
}
Node::If { control, cond } => {
write!(w, "{} [label=\"if\"];\n", name)?;
let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?;
let (cond_name, visited) = write_node(i, cond.idx(), module, visited, w)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
write!(w, "{} -> {} [label=\"cond\"];\n", cond_name, name)?;
visited
}
Node::Fork { control, factor } => {
write!(
w,
"{} [label=\"fork<{:?}>\"];\n",
name,
module.dynamic_constants[factor.idx()]
)?;
let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
visited
}
Node::Join { control, data } => {
write!(w, "{} [label=\"join\"];\n", name,)?;
let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?;
let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?;
visited
}
Node::Phi { control, data } => {
write!(w, "{} [label=\"phi\"];\n", name)?;
let (control_name, mut visited) = write_node(i, control.idx(), module, visited, w)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
for (idx, data) in data.iter().enumerate() {
let (data_name, tmp_visited) = write_node(i, data.idx(), module, visited, w)?;
visited = tmp_visited;
write!(w, "{} -> {} [label=\"data {}\"];\n", data_name, name, idx)?;
}
visited
}
Node::Return { control, value } => {
let (control_name, visited) =
write_node(i, control.idx(), module, constants, visited, w)?;
let (value_name, visited) =
write_node(i, value.idx(), module, constants, visited, w)?;
let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?;
let (value_name, visited) = write_node(i, value.idx(), module, visited, w)?;
write!(w, "{} [label=\"return\"];\n", name)?;
write!(w, "{} -> {} [style=\"dashed\"];\n", control_name, name)?;
write!(w, "{} -> {};\n", value_name, name)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
write!(w, "{} -> {} [label=\"value\"];\n", value_name, name)?;
visited
}
Node::Parameter { index } => {
write!(w, "{} [label=\"param #{}\"];\n", name, index)?;
write!(w, "{} [label=\"param #{}\"];\n", name, index + 1)?;
visited
}
Node::Constant { id } => {
write!(w, "{} [label=\"{:?}\"];\n", name, constants[id.idx()])?;
write!(
w,
"{} [label=\"{:?}\"];\n",
name,
module.constants[id.idx()]
)?;
visited
}
Node::DynamicConstant { id } => {
write!(
w,
"{} [label=\"dynamic_constant({:?})\"];\n",
name,
module.dynamic_constants[id.idx()]
)?;
visited
}
Node::Add { left, right } => {
let (left_name, visited) =
write_node(i, left.idx(), module, constants, visited, w)?;
let (right_name, visited) =
write_node(i, right.idx(), module, constants, visited, w)?;
write!(w, "{} [label=\"add\"];\n", name)?;
write!(w, "{} -> {};\n", left_name, name)?;
write!(w, "{} -> {};\n", right_name, name)?;
Node::Unary { input, op } => {
write!(w, "{} [label=\"{}\"];\n", name, get_string_uop_kind(*op))?;
let (input_name, visited) = write_node(i, input.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"input\"];\n", input_name, name)?;
visited
}
Node::Binary { left, right, op } => {
write!(w, "{} [label=\"{}\"];\n", name, get_string_bop_kind(*op))?;
let (left_name, visited) = write_node(i, left.idx(), module, visited, w)?;
let (right_name, visited) = write_node(i, right.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"left\"];\n", left_name, name)?;
write!(w, "{} -> {} [label=\"right\"];\n", right_name, name)?;
visited
}
Node::Call {
......@@ -84,28 +173,90 @@ fn write_node<W: std::fmt::Write>(
dynamic_constants,
args,
} => {
for arg in args.iter() {
let (arg_name, tmp_visited) =
write_node(i, arg.idx(), module, constants, visited, w)?;
write!(w, "{} [label=\"call<", name,)?;
for (idx, id) in dynamic_constants.iter().enumerate() {
let dc = &module.dynamic_constants[id.idx()];
if idx == 0 {
write!(w, "{:?}", dc)?;
} else {
write!(w, ", {:?}", dc)?;
}
}
write!(w, ">({})\"];\n", module.functions[function.idx()].name)?;
for (idx, arg) in args.iter().enumerate() {
let (arg_name, tmp_visited) = write_node(i, arg.idx(), module, visited, w)?;
visited = tmp_visited;
write!(w, "{} -> {};\n", arg_name, name)?;
write!(w, "{} -> {} [label=\"arg {}\"];\n", arg_name, name, idx)?;
}
write!(
w,
"{} [label=\"call({})\"];\n",
"{} -> start_{}_0 [label=\"call\", lhead={}];\n",
name,
function.idx(),
module.functions[function.idx()].name
)?;
visited
}
Node::ReadProd { prod, index } => {
write!(w, "{} [label=\"read_prod({})\"];\n", name, index)?;
let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?;
visited
}
Node::WriteProd { prod, data, index } => {
write!(w, "{} [label=\"write_prod({})\"];\n", name, index)?;
let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?;
let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?;
write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?;
visited
}
Node::ReadArray { array, index } => {
write!(w, "{} [label=\"read_array\"];\n", name)?;
let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?;
let (index_name, visited) = write_node(i, index.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"index\"];\n", index_name, name)?;
visited
}
Node::WriteArray { array, data, index } => {
write!(w, "{} [label=\"write_array\"];\n", name)?;
let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?;
let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?;
let (index_name, visited) = write_node(i, index.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"index\"];\n", index_name, name)?;
visited
}
Node::Match { control, sum } => {
write!(w, "{} [label=\"match\"];\n", name)?;
let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?;
write!(
w,
"{} -> {} [label=\"control\", style=\"dashed\"];\n",
control_name, name
)?;
let (sum_name, visited) = write_node(i, sum.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"sum\"];\n", sum_name, name)?;
visited
}
Node::BuildSum {
data,
sum_ty,
variant,
} => {
write!(
w,
"{} -> start_{}_0 [lhead={}];\n",
"{} [label=\"build_sum({:?}, {})\"];\n",
name,
function.idx(),
module.functions[function.idx()].name
module.types[sum_ty.idx()],
variant
)?;
let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?;
write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?;
visited
}
_ => todo!(),
};
Ok((visited.get(&id).unwrap().clone(), visited))
}
......@@ -125,7 +276,7 @@ fn get_string_node_kind(node: &Node) -> &'static str {
} => "fork",
Node::Join {
control: _,
factor: _,
data: _,
} => "join",
Node::Phi {
control: _,
......@@ -138,14 +289,57 @@ fn get_string_node_kind(node: &Node) -> &'static str {
Node::Parameter { index: _ } => "parameter",
Node::DynamicConstant { id: _ } => "dynamic_constant",
Node::Constant { id: _ } => "constant",
Node::Add { left: _, right: _ } => "add",
Node::Sub { left: _, right: _ } => "sub",
Node::Mul { left: _, right: _ } => "mul",
Node::Div { left: _, right: _ } => "div",
Node::Unary { input: _, op } => get_string_uop_kind(*op),
Node::Binary {
left: _,
right: _,
op,
} => get_string_bop_kind(*op),
Node::Call {
function: _,
dynamic_constants: _,
args: _,
} => "call",
Node::ReadProd { prod: _, index: _ } => "read_prod",
Node::WriteProd {
prod: _,
data: _,
index: _,
} => "write_prod ",
Node::ReadArray { array: _, index: _ } => "read_array",
Node::WriteArray {
array: _,
data: _,
index: _,
} => "write_array",
Node::Match { control: _, sum: _ } => "match",
Node::BuildSum {
data: _,
sum_ty: _,
variant: _,
} => "build_sum",
}
}
fn get_string_uop_kind(uop: UnaryOperator) -> &'static str {
match uop {
UnaryOperator::Not => "not",
UnaryOperator::Neg => "neg",
}
}
fn get_string_bop_kind(bop: BinaryOperator) -> &'static str {
match bop {
BinaryOperator::Add => "add",
BinaryOperator::Sub => "sub",
BinaryOperator::Mul => "mul",
BinaryOperator::Div => "div",
BinaryOperator::Rem => "rem",
BinaryOperator::LT => "lt",
BinaryOperator::LTE => "lte",
BinaryOperator::GT => "gt",
BinaryOperator::GTE => "gte",
BinaryOperator::EQ => "eq",
BinaryOperator::NE => "ne",
}
}
extern crate ordered_float;
/*
* A module is a list of functions. Functions contain types, constants, and
* dynamic constants, which are interned at the module level. Thus, if one
* wants to run an intraprocedural pass in parallel, it is advised to first
* destruct the module, then reconstruct it once finished.
*/
#[derive(Debug, Clone)]
pub struct Module {
pub functions: Vec<Function>,
......@@ -8,6 +14,14 @@ pub struct Module {
pub dynamic_constants: Vec<DynamicConstant>,
}
/*
* A function has a name, a list of types for its parameters, a single return
* type, a list of nodes in its sea-of-nodes style IR, and a number of dynamic
* constants. When calling a function, arguments matching the parameter types
* are required, as well as the correct number of dynamic constants. All
* dynamic constants are 64-bit unsigned integers (usize / u64), so it is
* sufficient to merely store how many of them the function takes as arguments.
*/
#[derive(Debug, Clone)]
pub struct Function {
pub name: String,
......@@ -17,9 +31,22 @@ pub struct Function {
pub num_dynamic_constants: u32,
}
/*
* Hercules IR has a fairly standard type system, with the exception of the
* control type. Hercules IR is based off of the sea-of-nodes IR, the main
* feature of which being a merged control and data flow graph. Thus, control
* is a type of value, just like any other type. However, the type system is
* very restrictive over what can be done with control values. A novel addition
* in Hercules IR is that a control type is parameterized by a list of thread
* spawning factors. This is the mechanism in Hercules IR for representing
* parallelism. Summation types are an IR equivalent of Rust's enum types.
* These are lowered into tagged unions during scheduling. Array types are one-
* dimensional. Multi-dimensional arrays are represented by nesting array types.
* An array extent is represented with a dynamic constant.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
Control(DynamicConstantID),
Control(Box<[DynamicConstantID]>),
Integer8,
Integer16,
Integer32,
......@@ -32,9 +59,17 @@ pub enum Type {
Float64,
Product(Box<[TypeID]>),
Summation(Box<[TypeID]>),
Array(TypeID, Box<[DynamicConstantID]>),
Array(TypeID, DynamicConstantID),
}
/*
* Constants are pretty standard in Hercules IR. Float constants use the
* ordered_float crate so that constants can be keys in maps (used for
* interning constants during IR construction). Product, summation, and array
* constants all contain their own type. This is only strictly necessary for
* summation types, but provides a nice mechanism for sanity checking for
* product and array types as well.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Constant {
Integer8(i8),
......@@ -52,12 +87,34 @@ pub enum Constant {
Array(TypeID, Box<[ConstantID]>),
}
/*
* Dynamic constants are unsigned 64-bit integers passed to a Hercules function
* at runtime using the Hercules runtime API. They cannot be the result of
* computations in Hercules IR. For a single execution of a Hercules function,
* dynamic constants are constant throughout execution. This provides a
* mechanism by which Hercules functions can operate on arrays with variable
* length, while not needing Hercules functions to perform dynamic memory
* allocation - by providing dynamic constants to the runtime API, the runtime
* can allocate memory as necessary.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DynamicConstant {
// NOTE(review): presumably a value that is already known when the IR is
// built - confirm against the parser.
Constant(usize),
// NOTE(review): presumably the index of one of the enclosing function's
// dynamic constant parameters (cf. Function::num_dynamic_constants) -
// confirm against the parser.
Parameter(usize),
}
/*
* Hercules IR is a combination of a possibly cyclic control flow graph, and
* many acyclic data flow graphs. Each node represents some operation on input
* values (including control), and produces some output value. Operations that
* conceptually produce multiple outputs (such as an if node) produce a product
* type instead. For example, the if node produces prod(control(N),
* control(N)), where the first control token represents the false branch, and
* the second control token represents the true branch. Another example is the
* fork node, which produces prod(control(N*k), u64), where the u64 is the
* thread ID. Functions are devoid of side effects, so call nodes don't take as
* input or output control tokens. There is also no global memory - use arrays.
*/
#[derive(Debug, Clone)]
pub enum Node {
Start,
......@@ -74,7 +131,7 @@ pub enum Node {
},
Join {
control: NodeID,
factor: DynamicConstantID,
data: NodeID,
},
Phi {
control: NodeID,
......@@ -93,29 +150,73 @@ pub enum Node {
DynamicConstant {
id: DynamicConstantID,
},
Add {
left: NodeID,
right: NodeID,
},
Sub {
left: NodeID,
right: NodeID,
},
Mul {
left: NodeID,
right: NodeID,
Unary {
input: NodeID,
op: UnaryOperator,
},
Div {
Binary {
left: NodeID,
right: NodeID,
op: BinaryOperator,
},
Call {
function: FunctionID,
dynamic_constants: Box<[DynamicConstantID]>,
args: Box<[NodeID]>,
},
ReadProd {
prod: NodeID,
index: usize,
},
WriteProd {
prod: NodeID,
data: NodeID,
index: usize,
},
ReadArray {
array: NodeID,
index: NodeID,
},
WriteArray {
array: NodeID,
data: NodeID,
index: NodeID,
},
Match {
control: NodeID,
sum: NodeID,
},
BuildSum {
data: NodeID,
sum_ty: TypeID,
variant: usize,
},
}
/*
* Operator kinds carried in the `op` field of Node::Unary, alongside the
* single `input` node. The type is Copy, so it can be passed by value.
* NOTE(review): whether Not is logical or bitwise is not visible here -
* confirm against the type checker.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
Not,
Neg,
}
/*
* Operator kinds carried in the `op` field of Node::Binary, alongside the
* `left` and `right` input nodes. The type is Copy, so it can be passed by
* value. Judging by the names, Add through Rem are arithmetic operators and
* LT through NE are comparisons - TODO confirm the exact semantics (e.g.
* signedness, result type of comparisons) against the type checker.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
Add,
Sub,
Mul,
Div,
Rem,
LT,
LTE,
GT,
GTE,
EQ,
NE,
}
/*
* Rust things to make newtyped IDs usable.
*/
/*
* Newtype over u32 that names a function. Node::Call stores one in its
* `function` field, and the .dot writer uses `function.idx()` to index
* `module.functions`.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionID(u32);
......
This diff is collapsed.
fn matmul<3>(a: array(array(f32, #1), #0), b: array(array(f32, #2), #1)) -> array(array(f32, #2), #0)
i = fork(start, #0)
i_ctrl = read_prod(i, 0)
i_idx = read_prod(i, 1)
k = fork(i_ctrl, #2)
k_ctrl = read_prod(k, 0)
k_idx = read_prod(k, 1)
zero_idx = constant(u64, 0)
one_idx = constant(u64, 1)
zero_val = constant(f32, 0)
loop = region(k_ctrl, if_true)
j = phi(loop, zero_idx, j_inc)
sum = phi(loop, zero_val, sum_inc)
j_inc = add(j, one_idx)
fval1 = read_array(a, i_idx)
fval2 = read_array(b, j)
val1 = read_array(fval1, j)
val2 = read_array(fval2, k_idx)
mul = mul(val1, val2)
sum_inc = add(sum, mul)
j_size = dynamic_constant(#1)
less = lt(j_inc, j_size)
if = if(loop, less)
if_false = read_prod(if, 0)
if_true = read_prod(if, 1)
k_join = join(if_false, sum_inc)
k_join_ctrl = read_prod(k_join, 0)
k_join_data = read_prod(k_join, 1)
i_join = join(k_join_ctrl, k_join_data)
i_join_ctrl = read_prod(i_join, 0)
i_join_data = read_prod(i_join, 1)
r = return(i_join_ctrl, i_join_data)
fn myfunc(x: i32) -> i32
y = call(add, x, x)
y = call<5>(add, x, x)
r = return(start, y)
fn add(x: i32, y: i32) -> i32
fn add<1>(x: i32, y: i32) -> i32
c = constant(i8, 5)
r = return(start, w)
dc = dynamic_constant(#0)
r = return(start, s)
w = add(z, c)
z = add(x, y)
s = add(w, dc)
z = add(x, y)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment