diff --git a/llama/src/llama.jn b/llama/src/llama.jn index e89a1645b5c5323740be71634076b2f10b19eef1..2ecce7da2626d17e811d316d6a7d2ce1f16dd6fc 100644 --- a/llama/src/llama.jn +++ b/llama/src/llama.jn @@ -1,4 +1,19 @@ +fn embedding_lookup<T:number, num_embeddings:usize, embedding_dim:usize, query_dim:usize>( + weights: T[num_embeddings, embedding_dim], + query: usize[query_dim] +) -> T[query_dim, embedding_dim] { + let res : T[query_dim, embedding_dim]; + for q = 0 to query_dim { + for e = 0 to embedding_dim { + res[q,e] = weights[q,e]; + } + } + return res; +} + #[entry] -fn llama() -> i32 { - return 1; +fn llama_f32<num_embeddings: usize, embedding_dim: usize>(weights_embedding: f32[num_embeddings, embedding_dim]) -> f32 { + let query: usize[1]; + query[0] = 0; + return embedding_lookup::<f32, num_embeddings, embedding_dim, 1>(weights_embedding, query)[0,0]; } diff --git a/llama/src/main.rs b/llama/src/main.rs index 1fef9c3d8078f68d44bae4bae3687f3ea61df25d..06b5a486d66cdde94bb91af68d0491aa2ea22661 100644 --- a/llama/src/main.rs +++ b/llama/src/main.rs @@ -2,11 +2,11 @@ #![feature(entry_insert)] use hercules_rt::runner; -juno_build::juno!("llama"); +juno_build::juno!("llama_f32"); fn main() { async_std::task::block_on(async { - let mut r = runner!(llama); + let mut r = runner!(llama_f32); let result = r.run().await; assert_eq!(result, 1); });