Skip to content
Snippets Groups Projects
Commit 3cd87810 authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'fix-cpu-tilings' into 'main'

Fix backprop and matmul cpu schedules

See merge request !223
parents 8e7b89c5 25db4a50
No related branches found
No related tags found
1 merge request!223Fix backprop and matmul cpu schedules
Pipeline #202084 passed
...@@ -25,6 +25,10 @@ macro forkify!(X) { ...@@ -25,6 +25,10 @@ macro forkify!(X) {
} }
} }
macro fork-chunk![n](X) {
fork-tile[n, 0, false, false](X);
}
macro fork-tile![n](X) { macro fork-tile![n](X) {
fork-tile[n, 0, false, true](X); fork-tile[n, 0, false, true](X);
} }
...@@ -66,8 +70,8 @@ if feature("cuda") { ...@@ -66,8 +70,8 @@ if feature("cuda") {
// Parallelize by computing output array as 16 chunks // Parallelize by computing output array as 16 chunks
let par = matmul@outer \ matmul@inner; let par = matmul@outer \ matmul@inner;
fork-tile![4](par); fork-chunk![4](par);
let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); let (outer, inner, _) = fork-reshape[[0, 2], [1], [3]](par);
parallelize!(outer \ inner); parallelize!(outer \ inner);
let body = outline(inner); let body = outline(inner);
......
...@@ -38,8 +38,8 @@ let forward_input = outline(backprop@forward_input); ...@@ -38,8 +38,8 @@ let forward_input = outline(backprop@forward_input);
let forward_hidden = outline(backprop@forward_hidden); let forward_hidden = outline(backprop@forward_hidden);
if !feature("seq") { if !feature("seq") {
fork-tile[16, 0, false, true](forward_input@outer_loop \ forward_input@inner_loop); fork-tile[16, 0, false, false](forward_input@outer_loop \ forward_input@inner_loop);
let (outer, inner) = fork-reshape[[1], [0]](forward_input@outer_loop \ forward_input@inner_loop); let (outer, inner) = fork-reshape[[0], [1]](forward_input@outer_loop \ forward_input@inner_loop);
forward_input = outline(inner); forward_input = outline(inner);
inline(backprop@forward_input); inline(backprop@forward_input);
} }
...@@ -53,8 +53,8 @@ let adjust_hidden = outline(backprop@adjust_hidden); ...@@ -53,8 +53,8 @@ let adjust_hidden = outline(backprop@adjust_hidden);
let adjust_input = outline(backprop@adjust_input); let adjust_input = outline(backprop@adjust_input);
if !feature("seq") { if !feature("seq") {
fork-tile[16, 0, false, true](adjust_input); fork-tile[16, 0, false, false](adjust_input);
let (outer, inner) = fork-reshape[[1], [0, 2]](adjust_input); let (outer, inner) = fork-reshape[[0], [1, 2]](adjust_input);
adjust_input = outline(inner); adjust_input = outline(inner);
inline(backprop@adjust_input); inline(backprop@adjust_input);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment