From 083bfc8f355b0e28405c44e016dcd87119575271 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Sun, 16 Feb 2025 11:06:09 -0600
Subject: [PATCH] Make frontend and outline not generate singleton products

---
 hercules_opt/src/outline.rs   | 73 +++++++++++++++++++++--------------
 hercules_opt/src/sroa.rs      | 25 +++++++-----
 juno_frontend/src/codegen.rs  | 24 +++++++-----
 juno_frontend/src/semant.rs   | 34 +++++++++-------
 juno_scheduler/src/default.rs |  2 -
 5 files changed, 96 insertions(+), 62 deletions(-)

diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs
index 8fe978c5..874e75e7 100644
--- a/hercules_opt/src/outline.rs
+++ b/hercules_opt/src/outline.rs
@@ -180,6 +180,13 @@ pub fn outline(
     editor.edit(|mut edit| {
         // Step 2: assemble the outlined function.
         let u32_ty = edit.add_type(Type::UnsignedInteger32);
+        let return_types: Box<[_]> = return_idx_to_inside_id
+            .iter()
+            .map(|id| typing[id.idx()])
+            .chain(callee_succ_return_idx.map(|_| u32_ty))
+            .collect();
+        let single_return = return_types.len() == 1;
+
         let mut outlined = Function {
             name: format!(
                 "{}_{}",
@@ -191,13 +198,11 @@ pub fn outline(
                 .map(|id| typing[id.idx()])
                 .chain(callee_pred_param_idx.map(|_| u32_ty))
                 .collect(),
-            return_type: edit.add_type(Type::Product(
-                return_idx_to_inside_id
-                    .iter()
-                    .map(|id| typing[id.idx()])
-                    .chain(callee_succ_return_idx.map(|_| u32_ty))
-                    .collect(),
-            )),
+            return_type: if single_return {
+                return_types[0]
+            } else {
+                edit.add_type(Type::Product(return_types))
+            },
             num_dynamic_constants: edit.get_num_dynamic_constant_params(),
             entry: false,
             nodes: vec![],
@@ -393,18 +398,24 @@ pub fn outline(
                 data_ids.push(cons_node_id);
             }
 
-            // Build the return product.
-            let mut construct_id = NodeID::new(outlined.nodes.len());
-            outlined.nodes.push(Node::Constant { id: cons_id });
-            for (idx, data) in data_ids.into_iter().enumerate() {
-                let write = Node::Write {
-                    collect: construct_id,
-                    data: data,
-                    indices: Box::new([Index::Field(idx)]),
-                };
-                construct_id = NodeID::new(outlined.nodes.len());
-                outlined.nodes.push(write);
-            }
+            // Build the return value
+            let construct_id = if single_return {
+                assert!(data_ids.len() == 1);
+                data_ids.pop().unwrap()
+            } else {
+                let mut construct_id = NodeID::new(outlined.nodes.len());
+                outlined.nodes.push(Node::Constant { id: cons_id });
+                for (idx, data) in data_ids.into_iter().enumerate() {
+                    let write = Node::Write {
+                        collect: construct_id,
+                        data: data,
+                        indices: Box::new([Index::Field(idx)]),
+                    };
+                    construct_id = NodeID::new(outlined.nodes.len());
+                    outlined.nodes.push(write);
+                }
+                construct_id
+            };
 
             // Return the return product.
             outlined.nodes.push(Node::Return {
@@ -505,16 +516,20 @@ pub fn outline(
         };
 
         // Create the read nodes from the call node to get the outputs of the
-        // outlined function.
-        let output_reads: Vec<_> = (0..return_idx_to_inside_id.len())
-            .map(|idx| {
-                let read = Node::Read {
-                    collect: call_id,
-                    indices: Box::new([Index::Field(idx)]),
-                };
-                edit.add_node(read)
-            })
-            .collect();
+        // outlined function (if there are multiple returned values)
+        let output_reads: Vec<_> = if single_return {
+            vec![call_id]
+        } else {
+            (0..return_idx_to_inside_id.len())
+                .map(|idx| {
+                    let read = Node::Read {
+                        collect: call_id,
+                        indices: Box::new([Index::Field(idx)]),
+                    };
+                    edit.add_node(read)
+                })
+                .collect()
+        };
         let indicator_read = callee_succ_return_idx.map(|idx| {
             let read = Node::Read {
                 collect: call_id,
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index cce14264..6d6503b0 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -48,7 +48,12 @@ use crate::*;
  * actually tracking each source and use of a product and verifying that all of
  * the nodes involved are mutable.
  */
-pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>, allow_sroa_arrays: bool) {
+pub fn sroa(
+    editor: &mut FunctionEditor,
+    reverse_postorder: &Vec<NodeID>,
+    types: &Vec<TypeID>,
+    allow_sroa_arrays: bool,
+) {
     let mut types: HashMap<NodeID, TypeID> = types
         .iter()
         .enumerate()
@@ -226,13 +231,13 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 // that information to the node map for the rest of SROA (this produces some reads
                 // that mix types of indices, since we only read leaves but that's okay since those
                 // reads are not handled by SROA)
-                let indices =
-                    if can_sroa_type(editor, types[collect]) {
-                        indices.chunk_by(|i, j| i.is_field() == j.is_field())
-                               .collect::<Vec<_>>()
-                    } else {
-                        vec![indices.as_ref()]
-                    };
+                let indices = if can_sroa_type(editor, types[collect]) {
+                    indices
+                        .chunk_by(|i, j| i.is_field() == j.is_field())
+                        .collect::<Vec<_>>()
+                } else {
+                    vec![indices.as_ref()]
+                };
 
                 let (field_reads, non_fields_produce_prod) = {
                     if indices.len() == 0 {
@@ -726,7 +731,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
 fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
     match &*editor.get_type(typ) {
         Type::Array(_, _) => true,
-        Type::Product(ts) | Type::Summation(ts) => ts.iter().any(|t| type_contains_array(editor, *t)),
+        Type::Product(ts) | Type::Summation(ts) => {
+            ts.iter().any(|t| type_contains_array(editor, *t))
+        }
         _ => false,
     }
 }
diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs
index 2ff9fa9f..a3197041 100644
--- a/juno_frontend/src/codegen.rs
+++ b/juno_frontend/src/codegen.rs
@@ -540,25 +540,31 @@ impl CodeGenerator<'_> {
                 block = after_call_region;
 
                 // Read each of the "inout values" and perform the SSA update
-                let inouts_index = self.builder.builder.create_field_index(1);
+                let has_inouts = !inouts.is_empty();
+                // TODO: We should omit unit returns, if we do so the + 1 below is not needed
                 for (idx, var) in inouts.into_iter().enumerate() {
-                    let index = self.builder.builder.create_field_index(idx);
+                    let index = self.builder.builder.create_field_index(idx + 1);
                     let mut read = self.builder.allocate_node();
                     let read_id = read.id();
-                    read.build_read(call_id, vec![inouts_index.clone(), index].into());
+                    read.build_read(call_id, vec![index].into());
                     self.builder.add_node(read);
 
                     ssa.write_variable(var, block, read_id);
                 }
 
                 // Read the "actual return" value and return it
-                let value_index = self.builder.builder.create_field_index(0);
-                let mut read = self.builder.allocate_node();
-                let read_id = read.id();
-                read.build_read(call_id, vec![value_index].into());
-                self.builder.add_node(read);
+                let result = if !has_inouts {
+                    call_id
+                } else {
+                    let value_index = self.builder.builder.create_field_index(0);
+                    let mut read = self.builder.allocate_node();
+                    let read_id = read.id();
+                    read.build_read(call_id, vec![value_index].into());
+                    self.builder.add_node(read);
+                    read_id
+                };
 
-                (read_id, block)
+                (result, block)
             }
             Expr::Intrinsic {
                 id,
diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index b8d04035..1059229a 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -808,8 +808,14 @@ fn analyze_program(
                 // Compute the proper type accounting for the inouts (which become returns)
                 let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
 
-                let inout_tuple = types.new_tuple(inout_types);
-                let pure_return_type = types.new_tuple(vec![return_type, inout_tuple]);
+                let mut return_types = vec![return_type];
+                return_types.extend(inout_types);
+                // TODO: Ideally we would omit unit returns
+                let pure_return_type = if return_types.len() == 1 {
+                    return_types.pop().unwrap()
+                } else {
+                    types.new_tuple(return_types)
+                };
 
                 // Finally, we have a properly built environment and we can
                 // start processing the body
@@ -4809,7 +4815,7 @@ fn process_expr(
                     };
 
                     // Now, process the arguments to ensure they has the type needed by this
-                    // constructor
+                    // function
                     let mut arg_vals: Vec<Either<Expr, usize>> = vec![];
                     let mut errors = LinkedList::new();
 
@@ -5009,19 +5015,21 @@ fn process_expr(
 }
 
 fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt {
-    let inout_types = inouts.iter().map(|e| e.get_type()).collect();
-    let inout_type = types.new_tuple(inout_types);
+    let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
 
-    let inout_vals = Expr::Tuple {
-        vals: inouts.clone(),
-        typ: inout_type,
-    };
+    let mut return_types = vec![expr.get_type()];
+    return_types.extend(inout_types);
 
-    let expr_type = expr.get_type();
+    let mut return_vals = vec![expr];
+    return_vals.extend_from_slice(inouts);
 
-    let val = Expr::Tuple {
-        vals: vec![expr, inout_vals],
-        typ: types.new_tuple(vec![expr_type, inout_type]),
+    let val = if return_vals.len() == 1 {
+        return_vals.pop().unwrap()
+    } else {
+        Expr::Tuple {
+            vals: return_vals,
+            typ: types.new_tuple(return_types),
+        }
     };
 
     Stmt::ReturnStmt { expr: val }
diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs
index 6ed10d4d..49dd9cf5 100644
--- a/juno_scheduler/src/default.rs
+++ b/juno_scheduler/src/default.rs
@@ -49,7 +49,6 @@ pub fn default_schedule() -> ScheduleStmt {
         DCE,
         Inline,
         DeleteUncalled,
-        InterproceduralSROA,
         SROA,
         PhiElim,
         DCE,
@@ -96,7 +95,6 @@ pub fn default_schedule() -> ScheduleStmt {
         GVN,
         DCE,
         AutoOutline,
-        InterproceduralSROA,
         SROA,
         ReuseProducts,
         SimplifyCFG,
-- 
GitLab