From 851042f86d1af09250b5c8522296bb5ba1bb763c Mon Sep 17 00:00:00 2001 From: Ryan Ziegler <rzig408@gmail.com> Date: Sun, 16 Feb 2025 17:13:53 -0500 Subject: [PATCH] huh --- llama/src/llama.jn | 19 +++++++++++++++++-- llama/src/main.rs | 4 ++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/llama/src/llama.jn b/llama/src/llama.jn index e89a1645..2ecce7da 100644 --- a/llama/src/llama.jn +++ b/llama/src/llama.jn @@ -1,4 +1,19 @@ +fn embedding_lookup<T:number, num_embeddings:usize, embedding_dim:usize, query_dim:usize>( + weights: T[num_embeddings, embedding_dim], + query: usize[query_dim] +) -> T[query_dim, embedding_dim] { + let res : T[query_dim, embedding_dim]; + for q = 0 to query_dim { + for e = 0 to embedding_dim { + res[q,e] = weights[q,e]; + } + } + return res; +} + #[entry] -fn llama() -> i32 { - return 1; +fn llama_f32<num_embeddings: usize, embedding_dim: usize>(weights_embedding: f32[num_embeddings, embedding_dim]) -> f32 { + let query: usize[1]; + query[0] = 0; + return embedding_lookup::<f32, num_embeddings, embedding_dim, 1>(weights_embedding, query)[0,0]; } diff --git a/llama/src/main.rs b/llama/src/main.rs index 1fef9c3d..06b5a486 100644 --- a/llama/src/main.rs +++ b/llama/src/main.rs @@ -2,11 +2,11 @@ #![feature(entry_insert)] use hercules_rt::runner; -juno_build::juno!("llama"); +juno_build::juno!("llama_f32"); fn main() { async_std::task::block_on(async { - let mut r = runner!(llama); + let mut r = runner!(llama_f32); let result = r.run().await; assert_eq!(result, 1); }); -- GitLab