Skip to content
Snippets Groups Projects
Commit 9015a730 authored by daniilp2's avatar daniilp2
Browse files

Added an extra level of indirection for the GPU benchmark to work

parent e6a62f5f
No related branches found
No related tags found
No related merge requests found
......@@ -105,7 +105,7 @@ void basicSgemmLvl2(float *A, size_t bytes_A, int lda, float *B, size_t bytes_B,
int ldb, float *C, size_t bytes_C, int ldc, int k,
float alpha, float beta, size_t dim_X1, size_t dim_Y1,
size_t dim_X2, size_t dim_Y2) {
__hpvm__hint(hpvm::REMOTE_TARGET);
__hpvm__hint(hpvm::CPU_TARGET);
__hpvm__attributes(3, A, B, C, 1, C);
void *sgemm_node =
__hpvm__createNodeND(2, basicSgemmLvl1, (size_t)dim_X2, (size_t)dim_Y2);
......@@ -130,7 +130,7 @@ void basicSgemmLvl3(float *A, size_t bytes_A, int lda, float *B, size_t bytes_B,
int ldb, float *C, size_t bytes_C, int ldc, int k,
float alpha, float beta, size_t dim_X1, size_t dim_Y1,
size_t dim_X2, size_t dim_Y2) {
__hpvm__hint(hpvm::CPU_TARGET);
__hpvm__hint(hpvm::REMOTE_TARGET);
__hpvm__attributes(3, A, B, C, 1, C);
void *sgemm_node = __hpvm__createNodeND(0, basicSgemmLvl2);
__hpvm__bindIn(sgemm_node, 0, 0, 0);
......@@ -151,6 +151,32 @@ void basicSgemmLvl3(float *A, size_t bytes_A, int lda, float *B, size_t bytes_B,
__hpvm__bindIn(sgemm_node, 15, 15, 0);
}
// A wrapper level used in codegen for some backends.
//
// basicSgemmLvl4 does no computation of its own: it creates a single child
// node running basicSgemmLvl3 and forwards each of its 16 arguments to the
// argument at the same position in that child.
//
// Parameters (all passed through unchanged, in declaration order):
//   A, B, C                   - float buffers; all three are listed as inputs
//                               and C is also listed as the output in the
//                               __hpvm__attributes call below.
//   bytes_A, bytes_B, bytes_C - presumably the byte sizes of A, B and C
//                               (TODO confirm against the caller).
//   lda, ldb, ldc, k          - integer parameters, not interpreted here.
//   alpha, beta               - float parameters, not interpreted here.
//   dim_X1, dim_Y1,
//   dim_X2, dim_Y2            - size parameters, not interpreted here.
void basicSgemmLvl4(float *A, size_t bytes_A, int lda, float *B, size_t bytes_B,
int ldb, float *C, size_t bytes_C, int ldc, int k,
float alpha, float beta, size_t dim_X1, size_t dim_Y1,
size_t dim_X2, size_t dim_Y2) {
// Target hint for this wrapper node itself (CPU).
__hpvm__hint(hpvm::CPU_TARGET);
// Declares 3 pointer inputs (A, B, C) followed by 1 pointer output (C).
__hpvm__attributes(3, A, B, C, 1, C);
// Create the child node that runs basicSgemmLvl3. The leading 0 is
// presumably the number of node dimensions (no dimension sizes are passed),
// i.e. a single instance - TODO confirm against the HPVM-C spec.
void *sgemm_node = __hpvm__createNodeND(0, basicSgemmLvl3);
// Forward argument i of this function to argument i of the child node, for
// all 16 arguments (indices 0..15). The trailing 0 is presumably the
// "is streaming" flag set to false - TODO confirm against the HPVM-C spec.
__hpvm__bindIn(sgemm_node, 0, 0, 0);
__hpvm__bindIn(sgemm_node, 1, 1, 0);
__hpvm__bindIn(sgemm_node, 2, 2, 0);
__hpvm__bindIn(sgemm_node, 3, 3, 0);
__hpvm__bindIn(sgemm_node, 4, 4, 0);
__hpvm__bindIn(sgemm_node, 5, 5, 0);
__hpvm__bindIn(sgemm_node, 6, 6, 0);
__hpvm__bindIn(sgemm_node, 7, 7, 0);
__hpvm__bindIn(sgemm_node, 8, 8, 0);
__hpvm__bindIn(sgemm_node, 9, 9, 0);
__hpvm__bindIn(sgemm_node, 10, 10, 0);
__hpvm__bindIn(sgemm_node, 11, 11, 0);
__hpvm__bindIn(sgemm_node, 12, 12, 0);
__hpvm__bindIn(sgemm_node, 13, 13, 0);
__hpvm__bindIn(sgemm_node, 14, 14, 0);
__hpvm__bindIn(sgemm_node, 15, 15, 0);
}
__attribute__((noinline)) void basicSgemm(char transa, char transb, int m,
int n, int k, float alpha, float *A,
size_t bytesA, int lda, float *B,
......@@ -194,7 +220,7 @@ __attribute__((noinline)) void basicSgemm(char transa, char transb, int m,
dg[0] / db[0],
dg[1] / db[1]};
*(RootIn *)root_in = root_in_local;
void *sgemmDFG = __hpvm__launch(0, basicSgemmLvl3, root_in);
void *sgemmDFG = __hpvm__launch(0, basicSgemmLvl4, root_in);
__hpvm__wait(sgemmDFG);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment