From c6e4587014294885dd8919b58ddedecebca5c8a4 Mon Sep 17 00:00:00 2001
From: RafaeNoor <abdurrafae98@live.com>
Date: Sat, 16 Apr 2022 13:40:35 -0500
Subject: [PATCH] Added support for generating getNumNodes in Hetero-C++

---
 hpvm/projects/hetero-c++/include/DFGUtils.h   |  3 ++
 .../hetero-c++/include/HPVMCGenFunctions.h    |  2 +-
 hpvm/projects/hetero-c++/lib/DFGUtils.cpp     | 23 ++++++++++++
 .../hetero-c++/lib/HPVMCGenFunctions.cpp      | 13 +++----
 .../hetero-c++/lib/HPVMExtractTask.cpp        | 36 +++++++++++++++++--
 5 files changed, 67 insertions(+), 10 deletions(-)

diff --git a/hpvm/projects/hetero-c++/include/DFGUtils.h b/hpvm/projects/hetero-c++/include/DFGUtils.h
index c2f90c39d1..95f0dbfff2 100644
--- a/hpvm/projects/hetero-c++/include/DFGUtils.h
+++ b/hpvm/projects/hetero-c++/include/DFGUtils.h
@@ -134,3 +134,6 @@ bool isPrivCall(CallInst* CI);
 
 bool isNonZeroCall(CallInst* CI);
 Value* castIntegerToBitwidth(Value* V, Instruction* InsertBefore, int BV);
+
+
+bool isHPVMGraphIntrinsic(Value* V);
diff --git a/hpvm/projects/hetero-c++/include/HPVMCGenFunctions.h b/hpvm/projects/hetero-c++/include/HPVMCGenFunctions.h
index caa9311476..d2e9928b81 100644
--- a/hpvm/projects/hetero-c++/include/HPVMCGenFunctions.h
+++ b/hpvm/projects/hetero-c++/include/HPVMCGenFunctions.h
@@ -103,7 +103,7 @@ class HPVMCGenGetNumNodeInstances : public HPVMCGenBase {
   public:
   HPVMCGenGetNumNodeInstances(HPVMCGenContext & HCGC) : HPVMCGenBase(HCGC) { }
 
-  CallInst* insertCall(Instruction* InsertPoint, Argument * Node, unsigned dimension);
+  CallInst* insertCall(Instruction* InsertPoint, Value * Node, unsigned dimension);
 };
 
 class HPVMCGenLaunch : public HPVMCGenBase {
diff --git a/hpvm/projects/hetero-c++/lib/DFGUtils.cpp b/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
index 63be8a449c..a1747b5329 100644
--- a/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
+++ b/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
@@ -1302,4 +1302,27 @@ Value* castIntegerToBitwidth(Value* V, Instruction* InsertBefore, int BV){
 }
 
 
+bool isHPVMGraphIntrinsic(Value* V){
+    CallInst* CI = dyn_cast<CallInst>(V);
+    if(!CI) return false;
+
+    Function* CF = CI->getCalledFunction();
+
+    if(!CF) return false;
+
+    if(isTaskBeginMarker(CI)) return true;
+    if(isParallelLoopMarker(CI)) return true;
+    if(isLaunchBeginMarker(CI)) return true;
+
+
+
+    if(CF->getName().str() == "__hpvm__attributes" || CF->getName().str() == "__hpvm__return" 
+        || CF->getName().str() == "__hpvm__order"
+            ){
+        return true;
+    }
+
+    return false;
+
+}
 
diff --git a/hpvm/projects/hetero-c++/lib/HPVMCGenFunctions.cpp b/hpvm/projects/hetero-c++/lib/HPVMCGenFunctions.cpp
index 5fe4bae961..8c4a450d5a 100644
--- a/hpvm/projects/hetero-c++/lib/HPVMCGenFunctions.cpp
+++ b/hpvm/projects/hetero-c++/lib/HPVMCGenFunctions.cpp
@@ -314,27 +314,28 @@ CallInst *HPVMCGenGetNodeInstanceID::insertCall(Instruction* InsertPoint, Value
     return NewCall;
 }
 
-CallInst *HPVMCGenGetNumNodeInstances::insertCall(Instruction* InsertPoint, Argument * Node, unsigned dimension) {
+CallInst *HPVMCGenGetNumNodeInstances::insertCall(Instruction* InsertPoint, Value * Node, unsigned dimension) {
     if(dimension > 2) {
         HPVMFatalError("Detected invalid dimension: " + std::to_string(dimension));
     }
 
-    char dim;
+    std::string dim;
     switch(dimension) {
         case 0:
-            dim = 'x';
+            dim = "__hpvm__getNumNodeInstances_x";
             break;
         
         case 1:
-            dim = 'y';
+            dim = "__hpvm__getNumNodeInstances_y";
             break;
 
         case 2:
-            dim = 'z';
+            dim = "__hpvm__getNumNodeInstances_z";
             break;
     }
 
-    const StringRef FuncName("__hpvm__getNumNodeInstances_" + dim);
+
+    const StringRef FuncName(dim);
     Function * HpvmGetNode = theContext.getModule().getFunction(FuncName);
     if (!HpvmGetNode) {
         HPVMFatalError("__hpvm__getNumNodeInstances function not found in context");
diff --git a/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp b/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
index a751d0ec8e..0edd768a3f 100644
--- a/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
+++ b/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
@@ -2856,13 +2856,17 @@ CallInst* HPVMProgram::parallelizeLoop(/*Loop* ExtractedLoop, Loop* InnerLoop*/
     HPVMCGenContext cgenContext(M);
     HPVMCGenGetNode cgenGetNode(cgenContext);
     HPVMCGenGetNodeInstanceID cgenGetID(cgenContext);
+    HPVMCGenGetNumNodeInstances cgenGetNumNode(cgenContext);
 
     CallInst* getNodeCall = cgenGetNode.insertCall(OrigF->getEntryBlock().getTerminator());
 
     for(int i = 0; i < Limits.size(); i++){
         CallInst* getIDCall = cgenGetID.insertCall(OrigF->getEntryBlock().getTerminator(),
             getNodeCall, i);
-        
+
+        CallInst* getNumCall = cgenGetNumNode.insertCall(OrigF->getEntryBlock().getTerminator(),
+            getNodeCall, i);
+
         Value* InductionVar = InductionVars[i];
 
         Loop* L = LoopNest[i];
@@ -2954,7 +2958,7 @@ CallInst* HPVMProgram::parallelizeLoop(/*Loop* ExtractedLoop, Loop* InnerLoop*/
         Value* MatchedType = castIntegerToBitwidth(getIDCall, getIDCall->getNextNode(), IVTy->getBitWidth());
 
 
-        auto shouldReplace = [&](Use &U)->bool {
+        auto shouldReplaceIV = [&](Use &U)->bool {
             auto useInst = U.getUser();
             Instruction* IU = dyn_cast<Instruction>(useInst);
 
@@ -2965,12 +2969,38 @@ CallInst* HPVMProgram::parallelizeLoop(/*Loop* ExtractedLoop, Loop* InnerLoop*/
 
         };
 
-        InductionVar->replaceUsesWithIf(MatchedType, shouldReplace);
+
+        InductionVar->replaceUsesWithIf(MatchedType, shouldReplaceIV);
 
         cast<Instruction>(InductionVar)->eraseFromParent();
 
 
 
+        IntegerType* LimTy = dyn_cast<IntegerType>(InductionVar->getType());
+        assert(LimTy && "Loop bounds  must be of integer type");
+
+        Value* MatchedLimitType = castIntegerToBitwidth(getNumCall, getNumCall->getNextNode(), LimTy->getBitWidth());
+
+        auto shouldReplaceLimit = [&](Use &U)->bool {
+            auto useInst = U.getUser();
+            Instruction* IU = dyn_cast<Instruction>(useInst);
+
+            if(!IU) return false;
+            if(IU->getParent()->getParent() != OrigF) return false;
+            if(isHPVMGraphIntrinsic(IU)) return false;
+
+            Instruction* LimInst = dyn_cast<Instruction>(MatchedLimitType);
+
+            return IU && DTCache.dominates(LimInst, IU);
+
+        };
+
+
+        Limits[i]->replaceUsesWithIf(MatchedLimitType, shouldReplaceLimit);
+
+
+
+
         LLVM_DEBUG(errs()<<"Transformed Function: "<<*OrigF<<"\n");
 
 
-- 
GitLab