diff --git a/Cargo.lock b/Cargo.lock index 9f469cc5f4f4134b6ed4ab8645310bf881ff0f4e..6642aef7745bf91996cb283de2c20c34953b3e90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1354,6 +1354,17 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_grape_conv" +version = "0.1.0" +dependencies = [ + "async-std", + "grape_sim", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_grape_reduction" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index c07fdae2a69d6a106442e63981f9216a9ff4ccd1..fd37ab44b5fe189caa1611962a040e7c836b6bfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,4 @@ [workspace] -resolver = "2" members = [ "hercules_cg", "hercules_ir", @@ -38,6 +37,8 @@ members = [ "juno_samples/simple3", "juno_samples/grape", "juno_samples/grape_reduction", + "juno_samples/grape_conv", "juno_scheduler", "juno_utils", ] +resolver = "2" diff --git a/hercules_cg/src/grape.rs b/hercules_cg/src/grape.rs index 6dcf7ce0372c143315b27052eb27a2c070669c37..a7093f777a034be0a89044360fa04c169e14b6ea 100644 --- a/hercules_cg/src/grape.rs +++ b/hercules_cg/src/grape.rs @@ -309,21 +309,10 @@ where return Err(()); }; - let node = &self.function.nodes[pos[0].idx()]; - - let Node::Constant { id } = node else { - return Err(()); - }; - - let Constant::UnsignedInteger64(constant) = self.constants[id.idx()] else { - panic!(); - return Err(()); - }; - inputs.insert(*user); } } - _ => todo!(), + _ => {} }; } // inputs @@ -369,6 +358,32 @@ where let chip_inputs = chip_inputs.unwrap(); + let mut visited = HashSet::new(); + let mut stack: Vec<NodeID> = chip_inputs.iter().cloned().collect(); + + while let Some(node) = stack.pop() { + if visited.insert(node) { + for &neighbor in get_uses(&self.function.nodes[node.idx()]).as_ref() { + stack.push(neighbor); + } + } + } + + let mut host_nodes = HashSet::new(); + + for node in visited { + if chip_inputs.contains(&node) { + host_nodes.insert(node); + continue; 
} else { + host_nodes.insert(node); + } + } + + assert!(chip_inputs.len() <= 16); + + println!("chip_inputs: {:?}", chip_inputs); + let chip = self.schedule_slice_toplevel(chip_inputs); // .expect("blargh"); let chip = if let Err((a, b, _)) = chip { @@ -385,6 +400,8 @@ where chip.unwrap() }; + // todo!(); + let (chip, output_mapping, input_mapping) = chip; // chip.0.pretty_print(); // println!("chip: {:?}", chip); @@ -428,6 +445,7 @@ where // todo!(); // Assemble to bitstream. + // todo!(); self.codegen_launch_rust_function( binary_string, @@ -435,6 +453,8 @@ where input_mapping, output_mapping, returns, + self.bbs, + host_nodes, w, )?; @@ -483,6 +503,8 @@ where input_mapping: Vec<NodeID>, output_mapping: HashMap<NodeID, usize>, returns: Vec<NodeID>, + bbs: &BasicBlocks, + host_nodes: HashSet<NodeID>, w: &mut Writer, ) -> Result<(), Error> { // Write bitstream; @@ -587,36 +609,50 @@ where // Iterate parameter mapping println!("input_mapping: {:?}", input_mapping); // panic!(); - for (idx, node) in input_mapping.iter().enumerate() { - let rhs = match &self.function.nodes[node.idx()] { - Node::Parameter { index } => { - format!("p{}", index) - } - Node::Read { collect, indices } => { - // Figure out what parameter the collection is, if any. IT HAS TO BE ONE - //Do the read in this function. 
- - // FIXME: just asssume its p0 for now - let Index::Position(i) = &indices[0] else { - panic!(); - }; - let Node::Constant { id } = self.function.nodes[i[0].idx()] else { - panic!(); - }; - let Constant::UnsignedInteger64(v) = self.constants[id.idx()] else { - panic!() - }; - - format!("*((p0.0 as *const i16).add({}))", v) + for block in self.bbs.1.iter() { + for id in block { + if host_nodes.contains(id) { + write!(w, "let ")?; + self.codegen_data_node(*id, w)?; + write!(w, "\n")?; } - Node::Start => "".to_owned(), - _ => panic!(), - }; - // skip start node + } + } + + for (idx, node) in input_mapping.iter().enumerate() { if node.idx() != 0 { + let rhs = self.get_value(*node); - write!(w, "input[{idx}] = {};\n", rhs); + + write!(w, "input[{idx}] = {};\n", rhs)?; } + + // let rhs = match &self.function.nodes[node.idx()] { + // Node::Parameter { index } => { + // format!("p{}", index) + // } + // Node::Read { collect, indices } => { + // // Figure out what parameter the collection is, if any. IT HAS TO BE ONE + // //Do the read in this function. 
+ + // // FIXME: just asssume its p0 for now + // let Index::Position(i) = &indices[0] else { + // panic!(); + // }; + + // let Node::Constant { id } = self.function.nodes[i[0].idx()] else { + // panic!(); + // }; + // let Constant::UnsignedInteger64(v) = self.constants[id.idx()] else { + // panic!() + // }; + + // format!("*((p0.0 as *const i16).add({}))", v) + // } + // Node::Start => "".to_owned(), + // _ => panic!(), + // }; + // skip start node } write!(w, "grape_sim::send_data(input.as_ptr());\n")?; write!(w, "await_valid();\n")?; @@ -674,6 +710,413 @@ where output } + fn get_value(&self, id: NodeID) -> String { + format!("node_{}", id.idx()) + } + + fn codegen_data_node<Writer: Write>(&self, id: NodeID, w: &mut Writer) -> Result<(), Error> { + let func = &self.function; + match func.nodes[id.idx()] { + Node::Parameter { index } => write!(w, "{} = p{};", self.get_value(id), index)?, + Node::Constant { id: cons_id } => { + write!(w, "{} = ", self.get_value(id))?; + match self.constants[cons_id.idx()] { + Constant::Boolean(val) => write!(w, "{}", val)?, + Constant::Integer8(val) => write!(w, "{}i8", val)?, + Constant::Integer16(val) => write!(w, "{}i16", val)?, + Constant::Integer32(val) => write!(w, "{}i32", val)?, + Constant::Integer64(val) => write!(w, "{}i64", val)?, + Constant::UnsignedInteger8(val) => write!(w, "{}u8", val)?, + Constant::UnsignedInteger16(val) => write!(w, "{}u16", val)?, + Constant::UnsignedInteger32(val) => write!(w, "{}u32", val)?, + Constant::UnsignedInteger64(val) => write!(w, "{}u64", val)?, + Constant::Float32(val) => { + if val == f32::INFINITY { + write!(w, "f32::INFINITY")? + } else if val == f32::NEG_INFINITY { + write!(w, "f32::NEG_INFINITY")? + } else { + write!(w, "{}f32", val)? + } + } + Constant::Float64(val) => { + if val == f64::INFINITY { + write!(w, "f64::INFINITY")? + } else if val == f64::NEG_INFINITY { + write!(w, "f64::NEG_INFINITY")? + } else { + write!(w, "{}f64", val)? 
+ } + } + Constant::Product(ty, _) + | Constant::Summation(ty, _, _) + | Constant::Array(ty) => { + todo!() + } + } + write!(w, ";")?; + } + Node::DynamicConstant { id: dc_id } => { + todo!() + } + Node::Unary { op, input } => { + match op { + UnaryOperator::Not => { + write!(w, "{} = !{};", self.get_value(id), self.get_value(input))? + } + UnaryOperator::Neg => { + write!(w, "{} = -{};", self.get_value(id), self.get_value(input))? + } + UnaryOperator::Cast(ty) => write!( + w, + "{} = {} as {};", + self.get_value(id), + self.get_value(input), + self.get_type(ty) + )?, + }; + } + Node::Binary { op, left, right } => { + let op = match op { + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Rem => "%", + BinaryOperator::LT => "<", + BinaryOperator::LTE => "<=", + BinaryOperator::GT => ">", + BinaryOperator::GTE => ">=", + BinaryOperator::EQ => "==", + BinaryOperator::NE => "!=", + BinaryOperator::Or => "|", + BinaryOperator::And => "&", + BinaryOperator::Xor => "^", + BinaryOperator::LSh => "<<", + BinaryOperator::RSh => ">>", + }; + + write!( + w, + "{} = {} {} {};", + self.get_value(id), + self.get_value(left), + op, + self.get_value(right) + )?; + } + Node::Ternary { + op, + first, + second, + third, + } => { + match op { + TernaryOperator::Select => write!( + w, + "{} = if {} {{{}}} else {{{}}};", + self.get_value(id), + self.get_value(first), + self.get_value(second), + self.get_value(third), + )?, + }; + } + Node::Read { + collect, + ref indices, + } => { + let collect_ty = self.typing[collect.idx()]; + let self_ty = self.typing[id.idx()]; + let offset = self.codegen_index_math(collect_ty, indices)?; + if self.types[self_ty.idx()].is_primitive() { + write!( + w, + "{} = ({}.byte_add({} as usize).0 as *mut {}).read();", + self.get_value(id), + self.get_value(collect), + offset, + self.get_type(self_ty) + )?; + } else { + write!( + w, + "{} = {}.byte_add({} as usize);", + 
self.get_value(id), + self.get_value(collect), + offset, + )?; + } + } + Node::Write { + collect, + data, + ref indices, + } => { + todo!() + } + _ => panic!( + "PANIC: Can't lower {:?} in {}.", + func.nodes[id.idx()], + func.name + ), + } + Ok(()) + } + + fn codegen_type_size(&self, ty: TypeID) -> String { + match self.types[ty.idx()] { + Type::Control | Type::MultiReturn(_) => panic!(), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { + "1".to_string() + } + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => "2".to_string(), + Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(), + Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(), + Type::Product(ref fields) => { + let fields_align = fields + .into_iter() + .map(|id| get_type_alignment(&self.types, *id)); + let fields: Vec<String> = fields + .into_iter() + .map(|id| self.codegen_type_size(*id)) + .collect(); + + // Build a Rust expression that rounds up to the alignment of the + // next field and then adds the size of that field. At the end, round + // up to the alignment of the whole struct. + let mut acc_size = "0".to_string(); + for (field_align, field) in zip(fields_align, fields) { + acc_size = format!( + "(({} + {}) & !{})", + acc_size, + field_align - 1, + field_align - 1 + ); + acc_size = format!("({} + {})", acc_size, field); + } + let total_align = get_type_alignment(&self.types, ty); + format!( + "(({} + {}) & !{})", + acc_size, + total_align - 1, + total_align - 1 + ) + } + Type::Summation(ref variants) => { + let variants = variants.into_iter().map(|id| self.codegen_type_size(*id)); + + // The size of a summation is the size of the largest field, + // plus 1 byte and alignment for the discriminant. + let mut acc_size = "0".to_string(); + for variant in variants { + acc_size = format!("::core::cmp::max({}, {})", acc_size, variant); + } + + // No alignment is necessary before the 1 byte discriminant. 
+ let total_align = get_type_alignment(&self.types, ty); + format!( + "(({} + 1 + {}) & !{})", + acc_size, + total_align - 1, + total_align - 1 + ) + } + Type::Array(elem, ref bounds) => { + // The size of an array is the size of the element multiplied by + // the dynamic constant bounds. + let mut acc_size = self.codegen_type_size(elem); + let elem_align = get_type_alignment(&self.types, elem); + acc_size = format!( + "(({} + {}) & !{})", + acc_size, + elem_align - 1, + elem_align - 1 + ); + for dc in bounds { + acc_size = format!("{} * ", acc_size); + self.codegen_dynamic_constant(*dc, &mut acc_size).unwrap(); + } + format!("({})", acc_size) + } + } + } + + /* + * Lower dynamic constant in Hercules IR into a Rust expression. + */ + fn codegen_dynamic_constant<Writer: Write>( + &self, + id: DynamicConstantID, + w: &mut Writer, + ) -> Result<(), Error> { + match &self.dynamic_constants[id.idx()] { + DynamicConstant::Constant(val) => write!(w, "{}", val)?, + DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?, + DynamicConstant::Add(xs) => { + write!(w, "(")?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "+")?; + self.codegen_dynamic_constant(*x, w)?; + } + write!(w, ")")?; + } + DynamicConstant::Sub(left, right) => { + write!(w, "(")?; + self.codegen_dynamic_constant(*left, w)?; + write!(w, "-")?; + self.codegen_dynamic_constant(*right, w)?; + write!(w, ")")?; + } + DynamicConstant::Mul(xs) => { + write!(w, "(")?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "*")?; + self.codegen_dynamic_constant(*x, w)?; + } + write!(w, ")")?; + } + DynamicConstant::Div(left, right) => { + write!(w, "(")?; + self.codegen_dynamic_constant(*left, w)?; + write!(w, "/")?; + self.codegen_dynamic_constant(*right, w)?; + write!(w, ")")?; + } + DynamicConstant::Rem(left, right) => { + write!(w, "(")?; + self.codegen_dynamic_constant(*left, w)?; + 
write!(w, "%")?; + self.codegen_dynamic_constant(*right, w)?; + write!(w, ")")?; + } + DynamicConstant::Min(xs) => { + let mut xs = xs.iter().peekable(); + + // Track the number of parentheses we open that need to be closed later + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + // For the last element, we just print it + self.codegen_dynamic_constant(*x, w)?; + } else { + // Otherwise, we create a new call to min and print the element as the + // first argument + write!(w, "::core::cmp::min(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } + } + DynamicConstant::Max(xs) => { + let mut xs = xs.iter().peekable(); + + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + self.codegen_dynamic_constant(*x, w)?; + } else { + write!(w, "::core::cmp::max(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } + } + } + Ok(()) + } + + /* + * Emit logic to index into a collection. + */ + fn codegen_index_math( + &self, + mut collect_ty: TypeID, + indices: &[Index], + ) -> Result<String, Error> { + let mut acc_offset = "0".to_string(); + for index in indices { + match index { + Index::Field(idx) => { + let Type::Product(ref fields) = self.types[collect_ty.idx()] else { + panic!() + }; + + // Get the offset of the field at index `idx` by calculating + // the product's size up to field `idx`, then offsetting the + // base pointer by that amount. 
+ for field in &fields[..*idx] { + let field_align = get_type_alignment(&self.types, *field); + let field = self.codegen_type_size(*field); + acc_offset = format!( + "((({} + {}) & !{}) + {})", + acc_offset, + field_align - 1, + field_align - 1, + field + ); + } + let last_align = get_type_alignment(&self.types, fields[*idx]); + acc_offset = format!( + "(({} + {}) & !{})", + acc_offset, + last_align - 1, + last_align - 1 + ); + collect_ty = fields[*idx]; + } + Index::Variant(idx) => { + // The tag of a summation is at the end of the summation, so + // the variant pointer is just the base pointer. Do nothing. + let Type::Summation(ref variants) = self.types[collect_ty.idx()] else { + panic!() + }; + collect_ty = variants[*idx]; + } + Index::Position(ref pos) => { + let Type::Array(elem, ref dims) = self.types[collect_ty.idx()] else { + panic!() + }; + + // The offset of the position into an array is: + // + // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... + let elem_size = self.codegen_type_size(elem); + let elem_align = get_type_alignment(&self.types, elem); + let aligned_elem_size = format!( + "(({} + {}) & !{})", + elem_size, + elem_align - 1, + elem_align - 1 + ); + for (p, s) in zip(pos, dims) { + let p = self.get_value(*p); + acc_offset = format!("{} * ", acc_offset); + self.codegen_dynamic_constant(*s, &mut acc_offset)?; + acc_offset = format!("({} + {})", acc_offset, p); + } + + // Convert offset in # elements -> # bytes. + acc_offset = format!("({} * {})", acc_offset, aligned_elem_size); + collect_ty = elem; + } + } + } + Ok(acc_offset) + } + fn assemble_slice(slice: SliceDesc<H, W>) -> Vec<bool> { // First row of FUs let mut bits = vec![]; @@ -1179,15 +1622,6 @@ where // live_outs, // )) // } - - /* - * Lower data nodes in Hercules IR into LLVM instructions. 
- */ - fn codegen_data_node(&self, id: NodeID) -> Result<(), Error> { - todo!(); - - Ok(()) - } } fn write_xdot<const H: usize, const W: usize, T: Write>(slice: &SliceDesc<H, W>, writer: &mut T) diff --git a/juno_samples/grape_conv/Cargo.toml b/juno_samples/grape_conv/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..64e578b2d42444f1252225e3160eadf3ca4b0d97 --- /dev/null +++ b/juno_samples/grape_conv/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "juno_grape_conv" +version = "0.1.0" +authors = ["Xavier Routh <xrouth2@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_grape_conv" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] +grape = [] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +grape_sim = { path = "../../grape_sim" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git a/juno_samples/grape_conv/build.rs b/juno_samples/grape_conv/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..021af808a4356834ebfad4c0ce9cbf85bd716de2 --- /dev/null +++ b/juno_samples/grape_conv/build.rs @@ -0,0 +1,35 @@ +use juno_build::JunoCompiler; + +fn main() { + #[cfg(not(any(feature = "cuda", feature = "grape")))] + { + JunoCompiler::new() + .file_in_src("conv.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); + } + #[cfg(feature = "grape")] + { + JunoCompiler::new() + .file_in_src("conv.jn") + .unwrap() + .schedule_in_src("grape.sch") + .unwrap() + .build() + .unwrap(); + } + + #[cfg(feature = "cuda")] + { + JunoCompiler::new() + .file_in_src("simple.jn") + .unwrap() + .schedule_in_src("gpu.sch") + .unwrap() + .build() + .unwrap(); + } +} diff --git a/juno_samples/grape_conv/src/conv.jn b/juno_samples/grape_conv/src/conv.jn new file mode 100644 index 0000000000000000000000000000000000000000..2b67072ac6f8ec04d79d83c723bb8adc8cfdcced --- 
/dev/null +++ b/juno_samples/grape_conv/src/conv.jn @@ -0,0 +1,44 @@ +fn conv1d<n: usize, k: usize>(a : i16[n], kernel: i16[k]) -> i16[n] { + let res : i16[n]; + + for i = 0 to n { + let window_left = i as i64 - (k as i64) / 2; + let window_right = i as i64 + (k as i64) / 2; + if window_left < 0 { + window_left = 0; + } + if window_right >= n as i64 { + window_right = n as i64 - 1; + } + let acc: i16 = 0; + for j = 0 to (window_right - window_left + 1) as u64 { + let inc_j = j + window_left as u64; + acc += a[inc_j] * kernel[j]; + } + res[i] = acc; + } + + return res; +} + +fn wrapper(a0, a1, a2, a3, k0, k1, k2 : i16) -> i16, i16, i16, i16 { + let a : i16[4]; + let k : i16[3]; + let r : i16[4]; + a[0] = a0; + a[1] = a1; + a[2] = a2; + a[3] = a3; + k[0] = k0; + k[1] = k1; + k[2] = k2; + r = conv1d::<4, 3>(a, k); + return r[0], r[1], r[2], r[3]; +} + +#[entry] +fn entry(a: i16[4], k: i16[3]) -> i16, i16, i16, i16 { + let c, d, e, f = wrapper(a[0], a[1], a[2], a[3], k[0], k[1], k[2]); + return c, d, e, f; +} + diff --git a/juno_samples/grape_conv/src/cpu.sch b/juno_samples/grape_conv/src/cpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..7934b277d2ccc3daa385f95182ad6555dd4040f1 --- /dev/null +++ b/juno_samples/grape_conv/src/cpu.sch @@ -0,0 +1,18 @@ +gvn(*); +phi-elim(*); +dce(*); + + +ip-sroa(*); +sroa(*); +dce(*); +gvn(*); +phi-elim(*); +dce(*); + +infer-schedules(*); + + +gcm(*); +dce(*); +gcm(*); diff --git a/juno_samples/grape_conv/src/grape.sch b/juno_samples/grape_conv/src/grape.sch new file mode 100644 index 0000000000000000000000000000000000000000..3aef5b9af52356456d080dcdf5ce0bca1d2263ee --- /dev/null +++ b/juno_samples/grape_conv/src/grape.sch @@ -0,0 +1,30 @@ +gvn(*); +phi-elim(*); +ccp(*); +simplify-cfg(*); +dce(*); + +inline(wrapper); +delete-uncalled(*); +fixpoint stop after 10 { + forkify(*); + fork-guard-elim(*); + fork-unroll(*); + predication(*); + gvn(*); + phi-elim(*); + ccp(*); + simplify-cfg(*); + dce(*); + 
lift-dc-math(*); +} +a2p(*); +sroa(*); +gvn(*); +phi-elim(*); +ccp(*); +simplify-cfg(*); +dce(*); +grape(wrapper); +xdot[true](*); +gcm(*); diff --git a/juno_samples/grape_conv/src/main.rs b/juno_samples/grape_conv/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..bf509d5d118324ef0982cbd1a5ad6cf847c34c6f --- /dev/null +++ b/juno_samples/grape_conv/src/main.rs @@ -0,0 +1,72 @@ +#![feature(concat_idents)] + +#[cfg(feature = "cuda")] +use hercules_rt::CUDABox; + +use hercules_rt::{runner, HerculesCPURef}; + +juno_build::juno!("conv"); + +#[cfg(feature = "grape")] +use grape_sim::*; + +fn conv1d<const N: usize, const K: usize>(a: &[i16; N], kernel: &[i16; K]) -> [i16; N] { + let mut res = [0i16; N]; + + for i in 0..N { + let mut window_left = i as i64 - (K as i64) / 2; + let mut window_right = i as i64 + (K as i64) / 2; + + if window_left < 0 { + window_left = 0; + } + if window_right >= N as i64 { + window_right = N as i64 - 1; + } + + let mut acc: i16 = 0; + for j in 0..(window_right - window_left + 1) as usize { + let inc_j = j + window_left as usize; + acc += a[inc_j] * kernel[j]; + } + res[i] = acc; + } + + res +} + +fn main() { + async_std::task::block_on(async { + let a: Box<[i16]> = Box::new([1, 2, 3, 4]); + let b: Box<[i16]> = Box::new([5, 6, 7]); + + let result: [i16; 4] = conv1d::<4, 3>( + a.as_ref().try_into().unwrap(), + b.as_ref().try_into().unwrap(), + ); + + #[cfg(not(feature = "cuda"))] + { + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let mut r = runner!(entry); + let c = r.run(a, b).await; + print!("{:?}", c); + + assert_eq!(c, (result[0], result[1], result[2], result[3])); + } + #[cfg(feature = "cuda")] + { + let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a)); + let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); + let mut r = runner!(simple); + let c = r.run(8, a.get_ref(), b.get_ref()).await; + assert_eq!(c, 120); + } + }); +} + +#[test] +fn simple3_test() 
{ + main(); +} diff --git a/juno_samples/grape_reduction/src/grape.sch b/juno_samples/grape_reduction/src/grape.sch index 248254316f8209ed3a4800c070636892e19ac07a..d098a3327e7790ccbea13a38e3abe7d1f2c77ef0 100644 --- a/juno_samples/grape_reduction/src/grape.sch +++ b/juno_samples/grape_reduction/src/grape.sch @@ -8,38 +8,66 @@ inline(fake_entry); delete-uncalled(*); forkify(*); +fork-tile[4, 0, false, true](*); +let a = fork-split(*); +print[a._1_fake_entry.fj1](); +let inner = outline(a._1_fake_entry.fj1); +rename["inner"](inner); fork-guard-elim(*); dce(*); +xdot[true](*); + +fixpoint stop after 4 { + forkify(a._1_fake_entry.fj1); + fork-guard-elim(a._1_fake_entry.fj1); + fork-unroll(a._1_fake_entry.fj1); + predication(a._1_fake_entry.fj1); + gvn(a._1_fake_entry.fj1); + phi-elim(a._1_fake_entry.fj1); + ccp(a._1_fake_entry.fj1); + simplify-cfg(a._1_fake_entry.fj1); + dce(a._1_fake_entry.fj1); + lift-dc-math(a._1_fake_entry.fj1); +} -fixpoint stop after 10 { - forkify(*); - fork-guard-elim(*); - fork-unroll(*); - predication(*); - gvn(*); - phi-elim(*); - ccp(*); - simplify-cfg(*); - dce(*); - lift-dc-math(*); +fixpoint stop after 4 { + forkify(inner); + fork-guard-elim(inner); + fork-unroll(inner); + predication(inner); + gvn(inner); + phi-elim(inner); + ccp(inner); + simplify-cfg(inner); + dce(inner); + lift-dc-math(inner); } + + +xdot[true](*); + +fork-unroll(a._1_fake_entry.fj0); + + // xdot[true](*); + +reassociate(inner); + + + + a2p(*); sroa(*); xdot[true](*); -// reassociate go brr -reassociate(*); -xdot[true](*); - gvn(*); phi-elim(*); ccp(*); simplify-cfg(*); dce(*); -grape(fake_entry); +grape(inner); xdot[true](*); gcm(*); xdot[true](*); diff --git a/juno_samples/grape_reduction/src/simple.jn b/juno_samples/grape_reduction/src/simple.jn index 19b5d43c6a65f25e829ec0c01b553032d2e024cc..df0d0a0c0ed41ed5dd718f5b0f9f5e997bad63d1 100644 --- a/juno_samples/grape_reduction/src/simple.jn +++ b/juno_samples/grape_reduction/src/simple.jn @@ -8,24 +8,10 @@ fn 
reduce<n: usize>(a : i16[n]) -> i16 { return acc; } -fn reduce2<n: usize, k: usize>(a : i16[n, k]) -> i16 { - let acc: i16 = 0; - - for j = 0 to k { - @inner for i = 0 to n { - acc += a[i, j]; - } - } - - return acc; -} - fn fake_entry(a: i16[16]) -> i16 { @this { let r = reduce::<16>(a); - // let r2 = reduce2::<16, 2>(b); - return r; } }