From ef7f352750d3a679af653f77ea6ae7f12a866264 Mon Sep 17 00:00:00 2001
From: Ryan Ziegler <rzig408@gmail.com>
Date: Tue, 4 Feb 2025 22:52:08 -0500
Subject: [PATCH] current state

---
 Cargo.lock                    | 122 ++++++++++++++++++++++++++++++++++
 Cargo.toml                    |   8 +++
 hercules_opt/src/gcm.rs       |   2 +-
 juno_scheduler/src/lib.rs     |   1 +
 torch/samples/mlp/build.rs    |  16 ++---
 torch/samples/mlp/src/cpu.sch |   2 +
 torch/src/convert.rs          |  15 ++++-
 7 files changed, 154 insertions(+), 12 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index af7902c6..4b11434b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -310,6 +310,15 @@ dependencies = [
  "wyz",
 ]
 
+[[package]]
+name = "block-buffer"
+version = "0.10.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
+dependencies = [
+ "generic-array",
+]
+
 [[package]]
 name = "blocking"
 version = "1.6.1"
@@ -510,6 +519,15 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "cpufeatures"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "crc32fast"
 version = "1.4.2"
@@ -556,6 +574,16 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
 
+[[package]]
+name = "crypto-common"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
+dependencies = [
+ "generic-array",
+ "typenum",
+]
+
 [[package]]
 name = "deranged"
 version = "0.3.11"
@@ -585,6 +613,16 @@ dependencies = [
  "syn 2.0.96",
 ]
 
+[[package]]
+name = "digest"
+version = "0.10.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
+dependencies = [
+ "block-buffer",
+ "crypto-common",
+]
+
 [[package]]
 name = "dot"
 version = "0.1.0"
@@ -774,6 +812,16 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "generic-array"
+version = "0.14.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
+dependencies = [
+ "typenum",
+ "version_check",
+]
+
 [[package]]
 name = "getopts"
 version = "0.2.21"
@@ -2036,6 +2084,17 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "sha2"
+version = "0.10.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -2291,6 +2350,63 @@ dependencies = [
  "winnow",
 ]
 
+[[package]]
+name = "torch"
+version = "0.1.0"
+dependencies = [
+ "itertools 0.14.0",
+ "torch_macros",
+]
+
+[[package]]
+name = "torch_llama"
+version = "0.1.0"
+dependencies = [
+ "hercules_rt",
+ "juno_build",
+ "torch",
+]
+
+[[package]]
+name = "torch_macros"
+version = "0.1.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.96",
+]
+
+[[package]]
+name = "torch_mlp"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "rand",
+ "sha2",
+ "torch",
+ "with_builtin_macros",
+]
+
+[[package]]
+name = "torch_mobilenet"
+version = "0.1.0"
+dependencies = [
+ "hercules_rt",
+ "juno_build",
+ "torch",
+]
+
+[[package]]
+name = "torch_shufflenet"
+version = "0.1.0"
+dependencies = [
+ "hercules_rt",
+ "juno_build",
+ "torch",
+]
+
 [[package]]
 name = "tracing"
 version = "0.1.41"
@@ -2307,6 +2423,12 @@ version = "0.1.33"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
 
+[[package]]
+name = "typenum"
+version = "1.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
+
 [[package]]
 name = "unicode-ident"
 version = "1.0.15"
diff --git a/Cargo.toml b/Cargo.toml
index 890d7924..7090c90c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,4 +32,12 @@ members = [
 	"juno_samples/schedule_test",
 	"juno_samples/edge_detection",
 	"juno_samples/fork_join_tests",
+
+	"torch",
+	"torch/samples/mlp",
+	"torch/samples/mobilenet",
+	"torch/samples/shufflenet",
+
+	# Llama runs into dc id issue
+	"torch/samples/llama",
 ]
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index b2f9767c..1ce96115 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1170,7 +1170,7 @@ fn color_nodes(
         if let Some(device) = device {
             colors.insert(*id, device);
         } else {
-            assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not.");
+            assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not. Name: {:?}, id: {:?}", _editor.func().name, id);
         }
     }
     if bad_node.is_some() {
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index d4ab432a..8ef26f93 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -1,4 +1,5 @@
 #![feature(exact_size_is_empty)]
+#![feature(entry_insert)]
 
 use std::collections::{HashMap, HashSet};
 use std::fs::File;
diff --git a/torch/samples/mlp/build.rs b/torch/samples/mlp/build.rs
index f8bbb8c3..c29750b2 100644
--- a/torch/samples/mlp/build.rs
+++ b/torch/samples/mlp/build.rs
@@ -50,12 +50,12 @@ fn compute_file_hash(path: &str) -> Result<String, std::io::Error> {
 }
 
 fn main() {
-        TorchCompiler::new().compile_py("mlp.py").unwrap();
-        JunoCompiler::new()
-            .file_in_src("mlp.jn")
-            .unwrap()
-            .schedule_in_src("cpu.sch")
-            .unwrap()
-            .build()
-            .unwrap();
+    // TorchCompiler::new().compile_py("mlp.py").unwrap();
+    JunoCompiler::new()
+        .file_in_src("mlp.jn")
+        .unwrap()
+        .schedule_in_src("cpu.sch")
+        .unwrap()
+        .build()
+        .unwrap();
 }
diff --git a/torch/samples/mlp/src/cpu.sch b/torch/samples/mlp/src/cpu.sch
index db5d70c5..3febebbd 100644
--- a/torch/samples/mlp/src/cpu.sch
+++ b/torch/samples/mlp/src/cpu.sch
@@ -6,6 +6,8 @@ macro juno-setup!(X) {
   lift-dc-math(X);
 }
 macro codegen-prep!(X) {
+  dce(X);
+  xdot[true](X);
   gcm(X);
   dce(X);
   float-collections(X);
diff --git a/torch/src/convert.rs b/torch/src/convert.rs
index 6cb1da34..fbeb9589 100644
--- a/torch/src/convert.rs
+++ b/torch/src/convert.rs
@@ -4059,11 +4059,17 @@ where
             let inputs_str = inputs.join(", ");
 
             if operator.is_idable() {
-                write!(juno_file, "let {} = {};", outputs[0], inputs_str);
+                if !activations.contains_key(&outputs[0]) {
+                    write!(juno_file, "let ").unwrap();
+                }
+                write!(juno_file, "{} = {};\n", outputs[0], inputs_str).unwrap();
             } else if outputs.len() == 1 {
+                if !activations.contains_key(&outputs[0]) {
+                    write!(juno_file, "let ").unwrap();
+                }
                 write!(
                     juno_file,
-                    "let {} = {}{}({});\n",
+                    "{} = {}{}({});\n",
                     outputs[0],
                     operator.function_name(),
                     operator_call,
@@ -4082,7 +4088,10 @@ where
                 )
                 .unwrap();
                 for (i, output) in outputs.iter().enumerate() {
-                    write!(juno_file, "let {} = {}.{};\n", output, outputs_str, i).unwrap();
+                    if !activations.contains_key(output) {
+                        write!(juno_file, "let ").unwrap();
+                    }
+                    write!(juno_file, "{} = {}.{};\n", output, outputs_str, i).unwrap();
                 }
             }
         }
-- 
GitLab