diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs
index 5ee3ff6173a94ea4070d8bac2695e7870ae9721f..8cc91f922b58d6db4bd0b189e8b3d506bfdb6e01 100644
--- a/hercules_ir/src/schedule.rs
+++ b/hercules_ir/src/schedule.rs
@@ -11,6 +11,18 @@ use crate::*;
 #[derive(Debug, Clone)]
 pub enum Schedule {
     ParallelReduce,
+    Vectorize,
+}
+
+/*
+ * The authoritative enumeration of supported devices. Technically, a device
+ * refers to a specific backend, so different "devices" may refer to the same
+ * "kind" of hardware.
+ */
+#[derive(Debug, Clone)]
+pub enum Device {
+    CPU,
+    GPU,
 }
 
 /*
@@ -22,6 +34,7 @@ pub enum Schedule {
 pub struct Plan {
     pub schedules: Vec<Vec<Schedule>>,
     pub partitions: Vec<PartitionID>,
+    pub partition_devices: Vec<Device>,
     pub num_partitions: usize,
 }
 
@@ -57,6 +70,7 @@ pub fn default_plan(
     let mut plan = Plan {
         schedules: vec![vec![]; function.nodes.len()],
         partitions: vec![PartitionID::new(0); function.nodes.len()],
+        partition_devices: vec![Device::CPU; 1],
         num_partitions: 0,
     };
 
@@ -66,6 +80,9 @@ pub fn default_plan(
     // Infer a partitioning.
     partition_out_forks(function, reverse_postorder, fork_join_map, bbs, &mut plan);
 
+    // Place fork partitions on the GPU.
+    place_fork_partitions_on_gpu(function, &mut plan);
+
     plan
 }
 
@@ -251,4 +268,17 @@ pub fn partition_out_forks(
     for id in (0..function.nodes.len()).map(NodeID::new) {
         plan.partitions[id.idx()] = representative_to_partition_ids[&representatives[id.idx()]];
     }
+
+    plan.partition_devices = vec![Device::CPU; plan.num_partitions];
+}
+
+/*
+ * Set the device for all partitions containing a fork to the GPU.
+ */
+pub fn place_fork_partitions_on_gpu(function: &Function, plan: &mut Plan) {
+    for idx in 0..function.nodes.len() {
+        if function.nodes[idx].is_fork() {
+            plan.partition_devices[plan.partitions[idx].idx()] = Device::GPU;
+        }
+    }
 }
diff --git a/hercules_tools/hercules_dot/src/dot.rs b/hercules_tools/hercules_dot/src/dot.rs
index 0671fe26ae82136a207fbc4a5c617cf9132711c8..751bd7c8e94846b5b41d703b1abab01f08266c61 100644
--- a/hercules_tools/hercules_dot/src/dot.rs
+++ b/hercules_tools/hercules_dot/src/dot.rs
@@ -35,7 +35,11 @@ pub fn write_dot<W: Write>(
     // Step 1: draw IR graph itself. This includes all IR nodes and all edges
     // between IR nodes.
     for partition_idx in 0..plan.num_partitions {
-        write_partition_header(function_id, partition_idx, module, w)?;
+        let partition_color = match plan.partition_devices[partition_idx] {
+            Device::CPU => "lightblue",
+            Device::GPU => "darkseagreen",
+        };
+        write_partition_header(function_id, partition_idx, module, partition_color, w)?;
         for node_id in &partition_to_node_map[partition_idx] {
             let node = &function.nodes[node_id.idx()];
             let dst_ty = &module.types[typing[function_id.idx()][node_id.idx()].idx()];
@@ -173,13 +177,14 @@ fn write_partition_header<W: Write>(
     function_id: FunctionID,
     partition_idx: usize,
     module: &Module,
+    color: &str,
     w: &mut W,
 ) -> std::fmt::Result {
     let function = &module.functions[function_id.idx()];
     write!(w, "subgraph {}_{} {{\n", function.name, partition_idx)?;
     write!(w, "label=\"\"\n")?;
     write!(w, "style=rounded\n")?;
-    write!(w, "bgcolor=ivory3\n")?;
+    write!(w, "bgcolor={}\n", color)?;
     write!(w, "cluster=true\n")?;
     Ok(())
 }