From ef7f352750d3a679af653f77ea6ae7f12a866264 Mon Sep 17 00:00:00 2001 From: Ryan Ziegler <rzig408@gmail.com> Date: Tue, 4 Feb 2025 22:52:08 -0500 Subject: [PATCH] current state --- Cargo.lock | 122 ++++++++++++++++++++++++++++++++++ Cargo.toml | 8 +++ hercules_opt/src/gcm.rs | 2 +- juno_scheduler/src/lib.rs | 1 + torch/samples/mlp/build.rs | 16 ++--- torch/samples/mlp/src/cpu.sch | 2 + torch/src/convert.rs | 15 ++++- 7 files changed, 154 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index af7902c6..4b11434b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,6 +310,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" version = "1.6.1" @@ -510,6 +519,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -556,6 +574,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deranged" version = "0.3.11" @@ -585,6 +613,16 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dot" version = "0.1.0" @@ -774,6 +812,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.21" @@ -2036,6 +2084,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2291,6 +2350,63 @@ dependencies = [ "winnow", ] +[[package]] +name = "torch" +version = "0.1.0" +dependencies = [ + "itertools 0.14.0", + "torch_macros", +] + +[[package]] +name = "torch_llama" +version = "0.1.0" +dependencies = [ + "hercules_rt", + "juno_build", + "torch", +] + +[[package]] +name = "torch_macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "torch_mlp" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "rand", + "sha2", + "torch", + "with_builtin_macros", +] + +[[package]] +name = "torch_mobilenet" +version = "0.1.0" +dependencies = [ + "hercules_rt", + "juno_build", + "torch", +] + +[[package]] +name = "torch_shufflenet" +version = "0.1.0" +dependencies = [ + "hercules_rt", + "juno_build", + "torch", +] + [[package]] name = "tracing" version = "0.1.41" @@ -2307,6 +2423,12 @@ version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unicode-ident" version = "1.0.15" diff --git a/Cargo.toml b/Cargo.toml index 890d7924..7090c90c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,4 +32,12 @@ members = [ "juno_samples/schedule_test", "juno_samples/edge_detection", "juno_samples/fork_join_tests", + + "torch", + "torch/samples/mlp", + "torch/samples/mobilenet", + "torch/samples/shufflenet", + + # Llama runs into dc id issue + "torch/samples/llama", ] diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index b2f9767c..1ce96115 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1170,7 +1170,7 @@ fn color_nodes( if let Some(device) = device { colors.insert(*id, device); } else { - assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not."); + assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not. Name: {:?}, id: {:?}", _editor.func().name, id); } } if bad_node.is_some() { diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index d4ab432a..8ef26f93 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -1,4 +1,5 @@ #![feature(exact_size_is_empty)] +#![feature(entry_insert)] use std::collections::{HashMap, HashSet}; use std::fs::File; diff --git a/torch/samples/mlp/build.rs b/torch/samples/mlp/build.rs index f8bbb8c3..c29750b2 100644 --- a/torch/samples/mlp/build.rs +++ b/torch/samples/mlp/build.rs @@ -50,12 +50,12 @@ fn compute_file_hash(path: &str) -> Result<String, std::io::Error> { } fn main() { - TorchCompiler::new().compile_py("mlp.py").unwrap(); - JunoCompiler::new() - .file_in_src("mlp.jn") - .unwrap() - .schedule_in_src("cpu.sch") - .unwrap() - .build() - .unwrap(); + // TorchCompiler::new().compile_py("mlp.py").unwrap(); + JunoCompiler::new() + .file_in_src("mlp.jn") + .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() + .build() + .unwrap(); } diff --git a/torch/samples/mlp/src/cpu.sch b/torch/samples/mlp/src/cpu.sch index db5d70c5..3febebbd 100644 --- a/torch/samples/mlp/src/cpu.sch +++ b/torch/samples/mlp/src/cpu.sch @@ -6,6 +6,8 @@ macro juno-setup!(X) { lift-dc-math(X); } macro codegen-prep!(X) { + dce(X); + xdot[true](X); gcm(X); dce(X); float-collections(X); diff --git a/torch/src/convert.rs b/torch/src/convert.rs index 6cb1da34..fbeb9589 100644 --- a/torch/src/convert.rs +++ b/torch/src/convert.rs @@ -4059,11 +4059,17 @@ where let inputs_str = inputs.join(", "); if operator.is_idable() { - write!(juno_file, "let {} = {};", outputs[0], inputs_str); + if !activations.contains_key(&outputs[0]) { + write!(juno_file, "let ").unwrap(); + } + write!(juno_file, "{} = {};\n", outputs[0], inputs_str).unwrap(); } else if outputs.len() == 1 { + if !activations.contains_key(&outputs[0]) { + write!(juno_file, "let ").unwrap(); + } write!( juno_file, - "let {} = {}{}({});\n", + "{} = {}{}({});\n", outputs[0], operator.function_name(), operator_call, @@ -4082,7 +4088,10 @@ where ) .unwrap(); for (i, output) in outputs.iter().enumerate() { - write!(juno_file, "let {} = {}.{};\n", output, outputs_str, i).unwrap(); + if !activations.contains_key(output) { + write!(juno_file, "let ").unwrap(); + } + write!(juno_file, "{} = {}.{};\n", output, outputs_str, i).unwrap(); } } } -- GitLab