diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index 490b138b61a993ab81757fca2d211d88d1c2bbf5..20a4202ef27bb99916806be9a203b2d4e8e8e2ee 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -3,6 +3,7 @@ extern crate hercules_ir; use std::collections::{HashSet, HashMap}; use std::fs::File; use std::io::Read; +use std::ops::{Index, IndexMut}; use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; @@ -74,90 +75,119 @@ pub fn schedule(module : &Module, info : FunctionMap, schedule : String) } } +#[derive(Clone)] +struct DeviceSchedules { + cpu : Vec<Schedule>, + gpu : Vec<Schedule>, +} + +impl DeviceSchedules { + fn new() -> Self { + DeviceSchedules { cpu : vec![], gpu : vec![] } + } +} + +impl Index<Device> for DeviceSchedules { + type Output = Vec<Schedule>; + + fn index(&self, device : Device) -> &Self::Output { + match device { + Device::CPU => &self.cpu, + Device::GPU => &self.gpu, + } + } +} + +impl IndexMut<Device> for DeviceSchedules { + fn index_mut(&mut self, device : Device) -> &mut Self::Output { + match device { + Device::CPU => &mut self.cpu, + Device::GPU => &mut self.gpu, + } + } +} + // a plan that tracks additional information useful while we construct the // schedule struct TempPlan { - schedules : Vec<Vec<Schedule>>, + schedules : Vec<DeviceSchedules>, // we track both the partition each node is in and what labeled caused us // to assign that partition partitions : Vec<(usize, PartitionNumber)>, - partition_devices : Vec<Device>, + partition_devices : Vec<Vec<Device>>, } type PartitionNumber = usize; -impl Into<Plan> for TempPlan { - fn into(self) -> Plan { - let TempPlan { schedules, partitions, partition_devices } = self; - let num_partitions = partition_devices.len(); - - Plan { - schedules : schedules, - partitions : partitions.into_iter().map(|(_, p)| PartitionID::new(p)) - .collect::<Vec<_>>(), - partition_devices : partition_devices, - num_partitions : num_partitions, - } - } -} - - -fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser::Inst>, +fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser::Partition>, lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>) -> Result<HashMap<FunctionID, Plan>, String> { let mut res : HashMap<FunctionID, TempPlan> = HashMap::new(); - // we initialize every node in every function as not having any schedule - // and being in the default partition which is a cpu partition (a result of - // label 0) + + // We initialize every node in every function as not having any schedule + // and being in the default partition which is a CPU-only partition + // (a result of label 0) for (_, (_, _, func_insts)) in info.iter() { for (_, func_id, _, _) in func_insts.iter() { let num_nodes = module.functions[func_id.idx()].nodes.len(); res.insert(*func_id, TempPlan { - schedules : vec![vec![]; num_nodes], + schedules : vec![DeviceSchedules::new(); num_nodes], partitions : vec![(0, 0); num_nodes], - partition_devices : vec![Device::CPU] }); + partition_devices : vec![vec![Device::CPU]] }); } } + // Construct a map from function names to function numbers let mut function_names : HashMap<String, usize> = HashMap::new(); for (num, (_, nm, _)) in info.iter() { function_names.insert(nm.clone(), *num); } + // Make the map immutable let function_names = function_names; - for parser::Inst { span : _, base, commands } in schedule { - let parser::Func { span : _, name : func_name, args : func_args } - = match &base { - parser::Base::Function { span : _, func } => func, - parser::Base::Label { span : _, func, .. } => func, - }; - let name = lexer.span_str(*func_name).to_string(); + for parser::Partition { span : _, func, labels, directs } in schedule { + // Identify the function we are partitioning/scheduling + let parser::Func { span : _, name : func_name, args : func_args } = func; + let name = lexer.span_str(func_name).to_string(); let func_num = match function_names.get(&name) { Some(num) => num, None => { return Err(format!("Function {} is undefined", name)); }, }; - if func_args.is_some() { + if func_args.is_some() { todo!("Scheduling particular typed variants not supported") } + // Identify label information let (label_map, _, func_inst) = info.get(func_num).unwrap(); - let label_num = - match &base { - parser::Base::Function { .. } => 0, - parser::Base::Label { span : _, func : _, label } => { - let label_name = lexer.span_str(*label).to_string(); - match label_map.get(&label_name) { - Some(num) => *num, - None => { - return Err(format!("Label {} undefined in {}", - label_name, name)); - }, - } - }, - }; + let get_label_num = |label_span| { + let label_name = lexer.span_str(label_span).to_string(); + match label_map.get(&label_name) { + Some(num) => Ok(*num), + None => { Err(format!("Label {} undefined in {}", label_name, name)) }, + } + }; + let label_nums = labels.into_iter().map(get_label_num) + .collect::<HashSet<_>>(); + // Process the partitioning and scheduling directives for each instance + // of the function for (_, func_id, label_info, node_labels) in func_inst { + // Setup the new partition + let func_info = res.get_mut(func_id).unwrap(); + let partition_num = func_info.partition_devices.len(); + let mut partition_devices = vec![]; + + // Need some sort of recursive process directives function which + // can take information about what devices, or conditions, etc. + for directive in directs.iter() { + match directive { + parser::Directive::OnDevice { span, device, directs } => {}, + parser::Directive::IfExpr { span, cond, directs } => {}, + parser::Directive::Command { span, label, commands } => {}, + } + } + /* for parser::Command { span : _, name : command_name, args : command_args } in commands.iter() { if command_args.len() != 0 { todo!("Command arguments not supported") } @@ -203,8 +233,12 @@ fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser return Err(format!("Command {} undefined", command)); } } + */ + + func_info.partition_devices.push(partition_devices); } } - Ok(res.into_iter().map(|(f, p)| (f, p.into())).collect::<HashMap<_, _>>()) + Err(format!("Scheduler not implemented")) + //Ok(res.into_iter().map(|(f, p)| (f, p.into())).collect::<HashMap<_, _>>()) }