Ignore:
Timestamp:
May 10, 2017, 4:26:11 PM (2 years ago)
Author:
nmedfort
Message:

Large refactoring step. Removed IR generation code from Kernel (formally KernelBuilder?) and moved it into the new KernelBuilder? class.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/IR_Gen/idisa_nvptx_builder.cpp

    r5374 r5440  
    1919Value * IDISA_NVPTX20_Builder::bitblock_any(Value * val) {
    2020    Type * const int32ty = getInt32Ty();
    21     Function * barrierOrFunc = cast<Function>(mMod->getOrInsertFunction("llvm.nvvm.barrier0.or", int32ty, int32ty, nullptr));
     21    Function * barrierOrFunc = cast<Function>(getModule()->getOrInsertFunction("llvm.nvvm.barrier0.or", int32ty, int32ty, nullptr));
    2222    Value * nonZero_i1 = CreateICmpUGT(val, ConstantInt::getNullValue(mBitBlockType));
    2323    Value * nonZero_i32 = CreateZExt(CreateBitCast(nonZero_i1, getInt1Ty()), int32ty);
     
    6666
    6767void IDISA_NVPTX20_Builder::CreateGlobals(){
    68 
     68    Module * const m = getModule();
    6969    Type * const carryTy = ArrayType::get(mBitBlockType, groupThreads+1);
    70     carry = new GlobalVariable(*mMod,
     70    carry = new GlobalVariable(*m,
    7171        /*Type=*/carryTy,
    7272        /*isConstant=*/false,
     
    8181    Type * const bubbleTy = ArrayType::get(mBitBlockType, groupThreads);
    8282
    83     bubble = new GlobalVariable(*mMod,
     83    bubble = new GlobalVariable(*m,
    8484        /*Type=*/bubbleTy,
    8585        /*isConstant=*/false,
     
    102102    Type * const voidTy = getVoidTy();
    103103    Type * const int32ty = getInt32Ty();
    104     barrierFunc = cast<Function>(mMod->getOrInsertFunction("llvm.nvvm.barrier0", voidTy, nullptr));
    105     tidFunc = cast<Function>(mMod->getOrInsertFunction("llvm.nvvm.read.ptx.sreg.tid.x", int32ty, nullptr));
    106 
     104    Module * const m = getModule();
     105    barrierFunc = cast<Function>(m->getOrInsertFunction("llvm.nvvm.barrier0", voidTy, nullptr));
     106    tidFunc = cast<Function>(m->getOrInsertFunction("llvm.nvvm.read.ptx.sreg.tid.x", int32ty, nullptr));
    107107}
    108108
    109109void IDISA_NVPTX20_Builder::CreateLongAdvanceFunc(){
    110   Type * const int32ty = getInt32Ty();
    111   Type * returnType = StructType::get(mMod->getContext(), {mBitBlockType, mBitBlockType});
    112 
    113   mLongAdvanceFunc = cast<Function>(mMod->getOrInsertFunction("LongAdvance", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
    114   mLongAdvanceFunc->setCallingConv(CallingConv::C);
    115   Function::arg_iterator args = mLongAdvanceFunc->arg_begin();
    116 
    117   Value * const id = &*(args++);
    118   id->setName("id");
    119   Value * const val = &*(args++);
    120   val->setName("val");
    121   Value * const shftAmount = &*(args++);
    122   shftAmount->setName("shftAmount");
    123   Value * const blockCarry = &*(args++);
    124   blockCarry->setName("blockCarry");
    125 
    126   SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mLongAdvanceFunc,0));
    127 
    128   Value * firstCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(0)});
    129   CreateStore(blockCarry, firstCarryPtr);
    130 
    131   Value * adv0 = CreateShl(val, shftAmount);
    132   Value * nextid = CreateAdd(id, getInt32(1));
    133   Value * carryNextPtr = CreateGEP(carry, {getInt32(0), nextid});
    134   Value * lshr0 = CreateLShr(val, CreateSub(CreateBitCast(getInt64(64), mBitBlockType), shftAmount));
    135   CreateStore(lshr0, carryNextPtr);
    136 
    137   CreateCall(barrierFunc);
    138 
    139   Value * lastCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(groupThreads)});
    140   Value * blockCarryOut = CreateLoad(lastCarryPtr, "blockCarryOut");
    141 
    142   Value * carryPtr = CreateGEP(carry, {getInt32(0), id});
    143   Value * carryVal = CreateLoad(carryPtr, "carryVal");
    144   Value * adv1 = CreateOr(adv0, carryVal);
    145 
    146  
    147   Value * retVal = UndefValue::get(returnType);
    148   retVal = CreateInsertValue(retVal, adv1, 0);
    149   retVal = CreateInsertValue(retVal, blockCarryOut, 1);
    150   CreateRet(retVal);
     110    Type * const int32ty = getInt32Ty();
     111    Module * const m = getModule();
     112    Type * returnType = StructType::get(m->getContext(), {mBitBlockType, mBitBlockType});
     113    mLongAdvanceFunc = cast<Function>(m->getOrInsertFunction("LongAdvance", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
     114    mLongAdvanceFunc->setCallingConv(CallingConv::C);
     115    auto args = mLongAdvanceFunc->arg_begin();
     116
     117    Value * const id = &*(args++);
     118    id->setName("id");
     119    Value * const val = &*(args++);
     120    val->setName("val");
     121    Value * const shftAmount = &*(args++);
     122    shftAmount->setName("shftAmount");
     123    Value * const blockCarry = &*(args++);
     124    blockCarry->setName("blockCarry");
     125
     126    SetInsertPoint(BasicBlock::Create(m->getContext(), "entry", mLongAdvanceFunc,0));
     127
     128    Value * firstCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(0)});
     129    CreateStore(blockCarry, firstCarryPtr);
     130
     131    Value * adv0 = CreateShl(val, shftAmount);
     132    Value * nextid = CreateAdd(id, getInt32(1));
     133    Value * carryNextPtr = CreateGEP(carry, {getInt32(0), nextid});
     134    Value * lshr0 = CreateLShr(val, CreateSub(CreateBitCast(getInt64(64), mBitBlockType), shftAmount));
     135    CreateStore(lshr0, carryNextPtr);
     136
     137    CreateCall(barrierFunc);
     138
     139    Value * lastCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(groupThreads)});
     140    Value * blockCarryOut = CreateLoad(lastCarryPtr, "blockCarryOut");
     141
     142    Value * carryPtr = CreateGEP(carry, {getInt32(0), id});
     143    Value * carryVal = CreateLoad(carryPtr, "carryVal");
     144    Value * adv1 = CreateOr(adv0, carryVal);
     145
     146
     147    Value * retVal = UndefValue::get(returnType);
     148    retVal = CreateInsertValue(retVal, adv1, 0);
     149    retVal = CreateInsertValue(retVal, blockCarryOut, 1);
     150    CreateRet(retVal);
    151151
    152152}
     
    157157  Type * const int64ty = getInt64Ty();
    158158  Type * const int32ty = getInt32Ty();
    159   Type * returnType = StructType::get(mMod->getContext(), {mBitBlockType, mBitBlockType});
    160 
    161   mLongAddFunc = cast<Function>(mMod->getOrInsertFunction("LongAdd", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
     159  Module * const m = getModule();
     160
     161  Type * returnType = StructType::get(m->getContext(), {mBitBlockType, mBitBlockType});
     162
     163  mLongAddFunc = cast<Function>(m->getOrInsertFunction("LongAdd", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
    162164  mLongAddFunc->setCallingConv(CallingConv::C);
    163165  Function::arg_iterator args = mLongAddFunc->arg_begin();
     
    172174  blockCarry->setName("blockCarry");
    173175
    174   BasicBlock * entryBlock = BasicBlock::Create(mMod->getContext(), "entry", mLongAddFunc, 0);
    175   BasicBlock * bubbleCalculateBlock = BasicBlock::Create(mMod->getContext(), "bubbleCalculate", mLongAddFunc, 0);
    176   BasicBlock * bubbleSetBlock = BasicBlock::Create(mMod->getContext(), "bubbleSet", mLongAddFunc, 0);
     176  BasicBlock * entryBlock = BasicBlock::Create(m->getContext(), "entry", mLongAddFunc, 0);
     177  BasicBlock * bubbleCalculateBlock = BasicBlock::Create(m->getContext(), "bubbleCalculate", mLongAddFunc, 0);
     178  BasicBlock * bubbleSetBlock = BasicBlock::Create(m->getContext(), "bubbleSet", mLongAddFunc, 0);
    177179
    178180  SetInsertPoint(entryBlock);
     
    243245    Type * const int32ty = getInt32Ty();
    244246    Type * const int1ty = getInt1Ty();
    245     Function * const ballotFn = cast<Function>(mMod->getOrInsertFunction("ballot_nvptx", int32ty, int1ty, nullptr));
     247    Module * const m = getModule();
     248    Function * const ballotFn = cast<Function>(m->getOrInsertFunction("ballot_nvptx", int32ty, int1ty, nullptr));
    246249    ballotFn->setCallingConv(CallingConv::C);
    247250    Function::arg_iterator args = ballotFn->arg_begin();
     
    250253    input->setName("input");
    251254
    252     SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", ballotFn, 0));
     255    SetInsertPoint(BasicBlock::Create(m->getContext(), "entry", ballotFn, 0));
    253256
    254257    Value * conv = CreateZExt(input, int32ty);
Note: See TracChangeset for help on using the changeset viewer.