From dc54b81b62c4728658a7e1c1cdab8cab90db2f64 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 24 Apr 2024 12:34:35 -0500
Subject: [PATCH] Start implementing new scheduling language

---
 juno_scheduler/src/lib.rs | 128 ++++++++++++++++++++++++--------------
 1 file changed, 81 insertions(+), 47 deletions(-)

diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 490b138b..20a4202e 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -3,6 +3,7 @@ extern crate hercules_ir;
 use std::collections::{HashSet, HashMap};
 use std::fs::File;
 use std::io::Read;
+use std::ops::{Index, IndexMut};
 
 use lrlex::DefaultLexerTypes;
 use lrpar::NonStreamingLexer;
@@ -74,90 +75,119 @@ pub fn schedule(module : &Module, info : FunctionMap, schedule : String)
     }
 }
 
+#[derive(Clone)]
+struct DeviceSchedules {
+    cpu : Vec<Schedule>,
+    gpu : Vec<Schedule>,
+}
+
+impl DeviceSchedules {
+    fn new() -> Self {
+        DeviceSchedules { cpu : vec![], gpu : vec![] }
+    }
+}
+
+impl Index<Device> for DeviceSchedules {
+    type Output = Vec<Schedule>;
+
+    fn index(&self, device : Device) -> &Self::Output {
+        match device {
+            Device::CPU => &self.cpu,
+            Device::GPU => &self.gpu,
+        }
+    }
+}
+
+impl IndexMut<Device> for DeviceSchedules {
+    fn index_mut(&mut self, device : Device) -> &mut Self::Output {
+        match device {
+            Device::CPU => &mut self.cpu,
+            Device::GPU => &mut self.gpu,
+        }
+    }
+}
+
 // a plan that tracks additional information useful while we construct the
 // schedule
 struct TempPlan {
-    schedules  : Vec<Vec<Schedule>>,
+    schedules  : Vec<DeviceSchedules>,
     // we track both the partition each node is in and what labeled caused us
     // to assign that partition
     partitions : Vec<(usize, PartitionNumber)>,
-    partition_devices : Vec<Device>,
+    partition_devices : Vec<Vec<Device>>,
 }
 type PartitionNumber = usize;
 
-impl Into<Plan> for TempPlan {
-    fn into(self) -> Plan {
-        let TempPlan { schedules, partitions, partition_devices } = self;
-        let num_partitions = partition_devices.len();
-
-        Plan {
-            schedules : schedules,
-            partitions : partitions.into_iter().map(|(_, p)| PartitionID::new(p))
-                                   .collect::<Vec<_>>(),
-            partition_devices : partition_devices,
-            num_partitions : num_partitions,
-        }
-    }
-}
-
-
-fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser::Inst>,
+fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser::Partition>,
                      lexer : &dyn NonStreamingLexer<DefaultLexerTypes<u32>>)
     -> Result<HashMap<FunctionID, Plan>, String> {
     let mut res : HashMap<FunctionID, TempPlan> = HashMap::new();
-    // we initialize every node in every function as not having any schedule
-    // and being in the default partition which is a cpu partition (a result of
-    // label 0)
+
+    // We initialize every node in every function as not having any schedule
+    // and being in the default partition which is a CPU-only partition
+    // (a result of label 0)
     for (_, (_, _, func_insts)) in info.iter() {
         for (_, func_id, _, _) in func_insts.iter() {
             let num_nodes = module.functions[func_id.idx()].nodes.len();
             res.insert(*func_id,
                        TempPlan {
-                           schedules : vec![vec![]; num_nodes],
+                           schedules : vec![DeviceSchedules::new(); num_nodes],
                            partitions : vec![(0, 0); num_nodes],
-                           partition_devices : vec![Device::CPU] });
+                           partition_devices : vec![vec![Device::CPU]] });
         }
     }
 
+    // Construct a map from function names to function numbers
     let mut function_names : HashMap<String, usize> = HashMap::new();
     for (num, (_, nm, _)) in info.iter() {
         function_names.insert(nm.clone(), *num);
     }
+    // Make the map immutable
     let function_names = function_names;
 
-    for parser::Inst { span : _, base, commands } in schedule {
-        let parser::Func { span : _, name : func_name, args : func_args }
-        = match &base { 
-            parser::Base::Function { span : _, func } => func,
-            parser::Base::Label    { span : _, func, .. } => func,
-        };
-        let name = lexer.span_str(*func_name).to_string();
+    for parser::Partition { span : _, func, labels, directs } in schedule {
+        // Identify the function we are partitioning/scheduling
+        let parser::Func { span : _, name : func_name, args : func_args } = func;
+        let name = lexer.span_str(func_name).to_string();
         let func_num = match function_names.get(&name) {
             Some(num) => num,
             None => { return Err(format!("Function {} is undefined", name)); },
         };
 
-        if func_args.is_some() { 
+        if func_args.is_some() {
             todo!("Scheduling particular typed variants not supported")
         }
 
+        // Identify label information
         let (label_map, _, func_inst) = info.get(func_num).unwrap();
-        let label_num =
-            match &base {
-                parser::Base::Function { .. } => 0,
-                parser::Base::Label { span : _, func : _, label } => {
-                    let label_name = lexer.span_str(*label).to_string();
-                    match label_map.get(&label_name) {
-                        Some(num) => *num,
-                        None => { 
-                            return Err(format!("Label {} undefined in {}", 
-                                               label_name, name));
-                        },
-                    }
-                },
-            };
+        let get_label_num = |label_span| {
+            let label_name = lexer.span_str(label_span).to_string();
+            match label_map.get(&label_name) {
+                Some(num) => Ok(*num),
+                None => { Err(format!("Label {} undefined in {}", label_name, name)) },
+            }
+        };
+        let label_nums = labels.into_iter().map(get_label_num)
+                               .collect::<HashSet<_>>();
 
+        // Process the partitioning and scheduling directives for each instance
+        // of the function
         for (_, func_id, label_info, node_labels) in func_inst {
+            // Setup the new partition
+            let func_info = res.get_mut(func_id).unwrap();
+            let partition_num = func_info.partition_devices.len();
+            let mut partition_devices = vec![];
+
+            // Need some sort of recursive process directives function which
+            // can take information about what devices, or conditions, etc.
+            for directive in directs.iter() {
+                match directive {
+                    parser::Directive::OnDevice { span, device, directs } => {},
+                    parser::Directive::IfExpr { span, cond, directs } => {},
+                    parser::Directive::Command { span, label, commands } => {},
+                }
+            }
+            /*
             for parser::Command { span : _, name : command_name, 
                                   args : command_args } in commands.iter() {
                 if command_args.len() != 0 { todo!("Command arguments not supported") }
@@ -203,8 +233,12 @@ fn generate_schedule(module : &Module, info : FunctionMap, schedule : Vec<parser
                     return Err(format!("Command {} undefined", command));
                 }
             }
+            */
+            
+            func_info.partition_devices.push(partition_devices);
         }
     }
 
-    Ok(res.into_iter().map(|(f, p)| (f, p.into())).collect::<HashMap<_, _>>())
+    Err(format!("Scheduler not implemented"))
+    //Ok(res.into_iter().map(|(f, p)| (f, p.into())).collect::<HashMap<_, _>>())
 }
-- 
GitLab