From 1d2c07d060a4d256d517665e25c183adba8e9a27 Mon Sep 17 00:00:00 2001
From: RafaeNoor <abdurrafae98@live.com>
Date: Fri, 15 Apr 2022 01:38:24 -0500
Subject: [PATCH] Hetero-C++: fixing bugs in generated IR with get_node_id

---
 hpvm/projects/hetero-c++/include/DFGUtils.h   |  3 +++
 hpvm/projects/hetero-c++/lib/DFGUtils.cpp     | 21 ++++++++++++++++
 hpvm/projects/hetero-c++/lib/HPVMCGen.cpp     | 24 +++++++++++++++++--
 .../hetero-c++/lib/HPVMExtractTask.cpp        |  6 ++++-
 4 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/hpvm/projects/hetero-c++/include/DFGUtils.h b/hpvm/projects/hetero-c++/include/DFGUtils.h
index ab695617c9..741077475c 100644
--- a/hpvm/projects/hetero-c++/include/DFGUtils.h
+++ b/hpvm/projects/hetero-c++/include/DFGUtils.h
@@ -130,3 +130,6 @@ bool isLoopInclusive(Loop* L, Value* InductionVar);
 
 
 bool isPrivCall(CallInst* CI);
+
+
+Value* castIntegerToBitwidth(Value* V, Instruction* InsertBefore, int BV);
diff --git a/hpvm/projects/hetero-c++/lib/DFGUtils.cpp b/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
index 428004cee6..6d1dc1d72c 100644
--- a/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
+++ b/hpvm/projects/hetero-c++/lib/DFGUtils.cpp
@@ -489,6 +489,7 @@ Function* CreateClone(Function* Orig, std::set<Argument*> Exclude, ValueToValueM
 
                 if(MappedInst){
                     LLVM_DEBUG(dbgs()<<"Removing :"<<*MappedInst<<"\n");
+                    MappedInst->replaceAllUsesWith(UndefValue::get(MappedInst->getType()));
                     MappedInst->eraseFromParent();
                 }
 
@@ -1264,6 +1265,26 @@ bool isPrivCall(CallInst* CI){
 
 
     return (CF->getName() == "__hpvm_priv") ||(CF->getName() == "__hetero_priv") ;
+}
+
+
+Value* castIntegerToBitwidth(Value* V, Instruction* InsertBefore, int BV){
+    LLVMContext& LC = V->getContext();
+
 
+    IntegerType* IType = IntegerType::get(LC, BV); 
 
+    if(V->getType()->isIntegerTy(BV)){
+        return V;
+    } else {
+        IntegerType* SizeTy = dyn_cast<IntegerType>(V->getType());
+        assert(SizeTy && "Dimension size must be an integer type");
+        Value* CastToI = nullptr;
+
+        if(SizeTy->getBitWidth() > BV){
+            return new TruncInst(V, IType, "trunc_", InsertBefore);
+        } else {
+            return new SExtInst(V, IType, "sext_", InsertBefore);
+        }
+    }
 }
diff --git a/hpvm/projects/hetero-c++/lib/HPVMCGen.cpp b/hpvm/projects/hetero-c++/lib/HPVMCGen.cpp
index 00dccd715c..863d056165 100644
--- a/hpvm/projects/hetero-c++/lib/HPVMCGen.cpp
+++ b/hpvm/projects/hetero-c++/lib/HPVMCGen.cpp
@@ -231,8 +231,28 @@ CallInst* HPVMCGenCreateNodeND::insertCallND(Instruction* insertBefore,
     ConstantInt* dimSizeConst = ConstantInt::get(dimSizeType, dimSizeVec.size()); 
     argsVec.push_back(dimSizeConst);
     argsVec.push_back(nodeFunc);
-    for (auto d: dimSizeVec)
-        argsVec.push_back(d);
+
+    IntegerType* I64Type = IntegerType::get(theContext.getLLVMContext(), 64); 
+    for (auto d: dimSizeVec){
+        if(!d->getType()->isIntegerTy(64)){
+
+            IntegerType* SizeTy = dyn_cast<IntegerType>(d->getType());
+            assert(SizeTy && "Dimension size must be an integer type");
+            Value* CastToI64 = nullptr;
+
+            if(SizeTy->getBitWidth() > 64){
+                // trunc
+                CastToI64 = new TruncInst(d, I64Type, "trunc_", insertBefore);
+            } else {
+
+                CastToI64 = new SExtInst(d, I64Type, "sext_", insertBefore);
+            }
+            argsVec.push_back(CastToI64);
+
+        } else {
+            argsVec.push_back(d);
+        }
+    }
     ArrayRef<Value *> argsList(argsVec); // safe: argsList not needed after return
     CallInst* newCall = this->insertCall(insertBefore, funcName, argsList);
     if (! newCall) HPVMWarnError("CreateNodeND function not VarArgs?");
diff --git a/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp b/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
index ab45fa8e28..8ce2125c07 100644
--- a/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
+++ b/hpvm/projects/hetero-c++/lib/HPVMExtractTask.cpp
@@ -2938,7 +2938,11 @@ CallInst* HPVMProgram::parallelizeLoop(/*Loop* ExtractedLoop, Loop* InnerLoop*/
         }
 
 
-        InductionVar->replaceAllUsesWith(getIDCall);
+        IntegerType* IVTy = dyn_cast<IntegerType>(InductionVar->getType());
+        assert(IVTy && "Induction Variable must be of integer type");
+
+        Value* MatchedType = castIntegerToBitwidth(getIDCall, getIDCall->getNextNode(), IVTy->getBitWidth());
+        InductionVar->replaceAllUsesWith(MatchedType);
         cast<Instruction>(InductionVar)->eraseFromParent();
 
 
-- 
GitLab