diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index 8fe978c5c9554fa7d0fd42f480ff724dcdc9cb36..874e75e739b0f05f72f0f8cfa4c6ae6540ed9c6f 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -180,6 +180,13 @@ pub fn outline( editor.edit(|mut edit| { // Step 2: assemble the outlined function. let u32_ty = edit.add_type(Type::UnsignedInteger32); + let return_types: Box<[_]> = return_idx_to_inside_id + .iter() + .map(|id| typing[id.idx()]) + .chain(callee_succ_return_idx.map(|_| u32_ty)) + .collect(); + let single_return = return_types.len() == 1; + let mut outlined = Function { name: format!( "{}_{}", @@ -191,13 +198,11 @@ pub fn outline( .map(|id| typing[id.idx()]) .chain(callee_pred_param_idx.map(|_| u32_ty)) .collect(), - return_type: edit.add_type(Type::Product( - return_idx_to_inside_id - .iter() - .map(|id| typing[id.idx()]) - .chain(callee_succ_return_idx.map(|_| u32_ty)) - .collect(), - )), + return_type: if single_return { + return_types[0] + } else { + edit.add_type(Type::Product(return_types)) + }, num_dynamic_constants: edit.get_num_dynamic_constant_params(), entry: false, nodes: vec![], @@ -393,18 +398,24 @@ pub fn outline( data_ids.push(cons_node_id); } - // Build the return product. - let mut construct_id = NodeID::new(outlined.nodes.len()); - outlined.nodes.push(Node::Constant { id: cons_id }); - for (idx, data) in data_ids.into_iter().enumerate() { - let write = Node::Write { - collect: construct_id, - data: data, - indices: Box::new([Index::Field(idx)]), - }; - construct_id = NodeID::new(outlined.nodes.len()); - outlined.nodes.push(write); - } + // Build the return value + let construct_id = if single_return { + assert!(data_ids.len() == 1); + data_ids.pop().unwrap() + } else { + let mut construct_id = NodeID::new(outlined.nodes.len()); + outlined.nodes.push(Node::Constant { id: cons_id }); + for (idx, data) in data_ids.into_iter().enumerate() { + let write = Node::Write { + collect: construct_id, + data: data, + indices: Box::new([Index::Field(idx)]), + }; + construct_id = NodeID::new(outlined.nodes.len()); + outlined.nodes.push(write); + } + construct_id + }; // Return the return product. outlined.nodes.push(Node::Return { @@ -505,16 +516,20 @@ pub fn outline( }; // Create the read nodes from the call node to get the outputs of the - // outlined function. - let output_reads: Vec<_> = (0..return_idx_to_inside_id.len()) - .map(|idx| { - let read = Node::Read { - collect: call_id, - indices: Box::new([Index::Field(idx)]), - }; - edit.add_node(read) - }) - .collect(); + // outlined function (if there are multiple returned values) + let output_reads: Vec<_> = if single_return { + vec![call_id] + } else { + (0..return_idx_to_inside_id.len()) + .map(|idx| { + let read = Node::Read { + collect: call_id, + indices: Box::new([Index::Field(idx)]), + }; + edit.add_node(read) + }) + .collect() + }; let indicator_read = callee_succ_return_idx.map(|idx| { let read = Node::Read { collect: call_id, diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index cce1426440bb9b4535dc82d70d377dcc422ab8bd..6d6503b09dddefec63cf271f5a2caa1a1169d59c 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -48,7 +48,12 @@ use crate::*; * actually tracking each source and use of a product and verifying that all of * the nodes involved are mutable. */ -pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>, allow_sroa_arrays: bool) { +pub fn sroa( + editor: &mut FunctionEditor, + reverse_postorder: &Vec<NodeID>, + types: &Vec<TypeID>, + allow_sroa_arrays: bool, +) { let mut types: HashMap<NodeID, TypeID> = types .iter() .enumerate() @@ -226,13 +231,13 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // that information to the node map for the rest of SROA (this produces some reads // that mix types of indices, since we only read leaves but that's okay since those // reads are not handled by SROA) - let indices = - if can_sroa_type(editor, types[collect]) { - indices.chunk_by(|i, j| i.is_field() == j.is_field()) - .collect::<Vec<_>>() - } else { - vec![indices.as_ref()] - }; + let indices = if can_sroa_type(editor, types[collect]) { + indices + .chunk_by(|i, j| i.is_field() == j.is_field()) + .collect::<Vec<_>>() + } else { + vec![indices.as_ref()] + }; let (field_reads, non_fields_produce_prod) = { if indices.len() == 0 { @@ -726,7 +731,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { match &*editor.get_type(typ) { Type::Array(_, _) => true, - Type::Product(ts) | Type::Summation(ts) => ts.iter().any(|t| type_contains_array(editor, *t)), + Type::Product(ts) | Type::Summation(ts) => { + ts.iter().any(|t| type_contains_array(editor, *t)) + } _ => false, } } diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 2ff9fa9f6b5dd07d30b71eb130aa3b4f019f28e9..a3197041da5a47b50f36ce81f0e36647083e3c17 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -540,25 +540,31 @@ impl CodeGenerator<'_> { block = after_call_region; // Read each of the "inout values" and perform the SSA update - let inouts_index = self.builder.builder.create_field_index(1); + let has_inouts = !inouts.is_empty(); + // TODO: We should omit unit returns, if we do so the + 1 below is not needed for (idx, var) in inouts.into_iter().enumerate() { - let index = self.builder.builder.create_field_index(idx); + let index = self.builder.builder.create_field_index(idx + 1); let mut read = self.builder.allocate_node(); let read_id = read.id(); - read.build_read(call_id, vec![inouts_index.clone(), index].into()); + read.build_read(call_id, vec![index].into()); self.builder.add_node(read); ssa.write_variable(var, block, read_id); } // Read the "actual return" value and return it - let value_index = self.builder.builder.create_field_index(0); - let mut read = self.builder.allocate_node(); - let read_id = read.id(); - read.build_read(call_id, vec![value_index].into()); - self.builder.add_node(read); + let result = if !has_inouts { + call_id + } else { + let value_index = self.builder.builder.create_field_index(0); + let mut read = self.builder.allocate_node(); + let read_id = read.id(); + read.build_read(call_id, vec![value_index].into()); + self.builder.add_node(read); + read_id + }; - (read_id, block) + (result, block) } Expr::Intrinsic { id, diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index b8d04035efb17be31a6964e72198b48bbd80513b..1059229a8f44b5363fae945f0640cd041d140496 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -808,8 +808,14 @@ fn analyze_program( // Compute the proper type accounting for the inouts (which become returns) let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - let inout_tuple = types.new_tuple(inout_types); - let pure_return_type = types.new_tuple(vec![return_type, inout_tuple]); + let mut return_types = vec![return_type]; + return_types.extend(inout_types); + // TODO: Ideally we would omit unit returns + let pure_return_type = if return_types.len() == 1 { + return_types.pop().unwrap() + } else { + types.new_tuple(return_types) + }; // Finally, we have a properly built environment and we can // start processing the body @@ -4809,7 +4815,7 @@ fn process_expr( }; // Now, process the arguments to ensure they has the type needed by this - // constructor + // function let mut arg_vals: Vec<Either<Expr, usize>> = vec![]; let mut errors = LinkedList::new(); @@ -5009,19 +5015,21 @@ fn process_expr( } fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt { - let inout_types = inouts.iter().map(|e| e.get_type()).collect(); - let inout_type = types.new_tuple(inout_types); + let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - let inout_vals = Expr::Tuple { - vals: inouts.clone(), - typ: inout_type, - }; + let mut return_types = vec![expr.get_type()]; + return_types.extend(inout_types); - let expr_type = expr.get_type(); + let mut return_vals = vec![expr]; + return_vals.extend_from_slice(inouts); - let val = Expr::Tuple { - vals: vec![expr, inout_vals], - typ: types.new_tuple(vec![expr_type, inout_type]), + let val = if return_vals.len() == 1 { + return_vals.pop().unwrap() + } else { + Expr::Tuple { + vals: return_vals, + typ: types.new_tuple(return_types), + } }; Stmt::ReturnStmt { expr: val } diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs index 6ed10d4d0e3c247f578f29a6640cb561df970268..49dd9cf52abf9f020a43d1975b5ea59ef8a8706e 100644 --- a/juno_scheduler/src/default.rs +++ b/juno_scheduler/src/default.rs @@ -49,7 +49,6 @@ pub fn default_schedule() -> ScheduleStmt { DCE, Inline, DeleteUncalled, - InterproceduralSROA, SROA, PhiElim, DCE, @@ -96,7 +95,6 @@ pub fn default_schedule() -> ScheduleStmt { GVN, DCE, AutoOutline, - InterproceduralSROA, SROA, ReuseProducts, SimplifyCFG,