source: icGREP/icgrep-devel/icgrep/IR_Gen/idisa_nvptx_builder.cpp @ 5463

Last change on this file since 5463 was 5440, checked in by nmedfort, 2 years ago

Large refactoring step. Removed IR generation code from Kernel (formerly KernelBuilder) and moved it into the new KernelBuilder class.

File size: 11.5 KB
RevLine 
[5128]1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "idisa_nvptx_builder.h"
[5165]8#include <llvm/IR/InlineAsm.h>
[5260]9#include <llvm/IR/Module.h>
[5128]10
11namespace IDISA {
[5374]12   
13std::string IDISA_NVPTX20_Builder::getBuilderUniqueName() { return "NVPTX20_" + std::to_string(groupThreads);}
[5128]14
15int IDISA_NVPTX20_Builder::getGroupThreads(){
16    return groupThreads;
17}
18
19Value * IDISA_NVPTX20_Builder::bitblock_any(Value * val) {
20    Type * const int32ty = getInt32Ty();
[5440]21    Function * barrierOrFunc = cast<Function>(getModule()->getOrInsertFunction("llvm.nvvm.barrier0.or", int32ty, int32ty, nullptr));
[5309]22    Value * nonZero_i1 = CreateICmpUGT(val, ConstantInt::getNullValue(mBitBlockType));
[5128]23    Value * nonZero_i32 = CreateZExt(CreateBitCast(nonZero_i1, getInt1Ty()), int32ty);
24    Value * anyNonZero = CreateCall(barrierOrFunc, nonZero_i32);
[5309]25    return CreateICmpNE(anyNonZero,  ConstantInt::getNullValue(int32ty));
[5128]26}
27
28Value * IDISA_NVPTX20_Builder::bitblock_mask_from(Value * pos){
29    Type * const int64ty = getInt64Ty();
30    Value * id = CreateCall(tidFunc);
31    Value * id64 = CreateZExt(id, int64ty);
32    Value * threadSize = getInt64(groupThreads);
33    Value * fullBlocks = CreateUDiv(pos, threadSize);
34    Value * finalBlockSelect = CreateSExt(CreateICmpEQ(id64, fullBlocks), int64ty);
35    Value * finalBlockMask = CreateShl(getInt64(-1), CreateURem(pos, threadSize));
36    Value * unusedBlockMask = CreateSExt(CreateICmpUGT(id64, fullBlocks), int64ty);
37    return CreateBitCast(CreateOr(CreateAnd(finalBlockMask, finalBlockSelect), unusedBlockMask), mBitBlockType);
38}
39
40Value * IDISA_NVPTX20_Builder::bitblock_set_bit(Value * pos){
41    Type * const int64ty = getInt64Ty();
42    Value * id = CreateCall(tidFunc);
43    Value * id64 = CreateZExt(id, int64ty);
44    Value * threadSize = getInt64(groupThreads);
45    Value * fullBlocks = CreateUDiv(pos, threadSize);
46    Value * finalBlockSelect = CreateSExt(CreateICmpEQ(id64, fullBlocks), int64ty);
47    Value * finalBlockMask = CreateShl(getInt64(1), CreateURem(pos, threadSize));
48    return CreateBitCast(CreateAnd(finalBlockMask, finalBlockSelect), mBitBlockType);
49}
50   
51std::pair<Value *, Value *> IDISA_NVPTX20_Builder::bitblock_advance(Value * a, Value * shiftin, unsigned shift) {
52    Value * id = CreateCall(tidFunc);
[5292]53    Value * retVal = CreateCall(mLongAdvanceFunc, {id, a, CreateBitCast(getInt64(shift), mBitBlockType), shiftin});
54    Value * shifted = CreateExtractValue(retVal, {0});
55    Value * shiftOut = CreateExtractValue(retVal, {1});
[5128]56    return std::pair<Value *, Value *>(shiftOut, shifted);
57}
58
59std::pair<Value *, Value *> IDISA_NVPTX20_Builder::bitblock_add_with_carry(Value * a, Value * b, Value * carryIn) {
60    Value * id = CreateCall(tidFunc);
[5292]61    Value * retVal = CreateCall(mLongAddFunc, {id, a, b, carryIn});
62    Value * sum = CreateExtractValue(retVal, {0});
63    Value * carry_out_strm = CreateExtractValue(retVal, {1});
[5128]64    return std::pair<Value *, Value *>(carry_out_strm, sum);
65}
66
67void IDISA_NVPTX20_Builder::CreateGlobals(){
[5440]68    Module * const m = getModule();
[5128]69    Type * const carryTy = ArrayType::get(mBitBlockType, groupThreads+1);
[5440]70    carry = new GlobalVariable(*m,
[5128]71        /*Type=*/carryTy,
72        /*isConstant=*/false,
73        /*Linkage=*/llvm::GlobalValue::InternalLinkage,
74        /*Initializer=*/0, 
75        /*Name=*/"carry",
76        /*InsertBefore*/nullptr,
77        /*TLMode */llvm::GlobalValue::NotThreadLocal,
78        /*AddressSpace*/ 3,
79        /*isExternallyInitialized*/false);
80
81    Type * const bubbleTy = ArrayType::get(mBitBlockType, groupThreads);
82
[5440]83    bubble = new GlobalVariable(*m,
[5128]84        /*Type=*/bubbleTy,
85        /*isConstant=*/false,
86        /*Linkage=*/llvm::GlobalValue::InternalLinkage,
87        /*Initializer=*/0, 
88        /*Name=*/"bubble",
89        /*InsertBefore*/nullptr,
90        /*TLMode */llvm::GlobalValue::NotThreadLocal,
91        /*AddressSpace*/ 3,
92        /*isExternallyInitialized*/false);
93   
94    ConstantAggregateZero* carryConstArray = ConstantAggregateZero::get(carryTy);
95    carry->setInitializer(carryConstArray);
96    ConstantAggregateZero* bubbleConstAray = ConstantAggregateZero::get(bubbleTy);
97    bubble->setInitializer(bubbleConstAray);
98
99}
100
[5309]101void IDISA_NVPTX20_Builder::CreateBuiltinFunctions(){
[5230]102    Type * const voidTy = getVoidTy();
[5128]103    Type * const int32ty = getInt32Ty();
[5440]104    Module * const m = getModule();
105    barrierFunc = cast<Function>(m->getOrInsertFunction("llvm.nvvm.barrier0", voidTy, nullptr));
106    tidFunc = cast<Function>(m->getOrInsertFunction("llvm.nvvm.read.ptx.sreg.tid.x", int32ty, nullptr));
[5128]107}
108
109void IDISA_NVPTX20_Builder::CreateLongAdvanceFunc(){
[5440]110    Type * const int32ty = getInt32Ty();
111    Module * const m = getModule();
112    Type * returnType = StructType::get(m->getContext(), {mBitBlockType, mBitBlockType});
113    mLongAdvanceFunc = cast<Function>(m->getOrInsertFunction("LongAdvance", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
114    mLongAdvanceFunc->setCallingConv(CallingConv::C);
115    auto args = mLongAdvanceFunc->arg_begin();
[5128]116
[5440]117    Value * const id = &*(args++);
118    id->setName("id");
119    Value * const val = &*(args++);
120    val->setName("val");
121    Value * const shftAmount = &*(args++);
122    shftAmount->setName("shftAmount");
123    Value * const blockCarry = &*(args++);
124    blockCarry->setName("blockCarry");
[5128]125
[5440]126    SetInsertPoint(BasicBlock::Create(m->getContext(), "entry", mLongAdvanceFunc,0));
[5128]127
[5440]128    Value * firstCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(0)});
129    CreateStore(blockCarry, firstCarryPtr);
[5128]130
[5440]131    Value * adv0 = CreateShl(val, shftAmount);
132    Value * nextid = CreateAdd(id, getInt32(1));
133    Value * carryNextPtr = CreateGEP(carry, {getInt32(0), nextid});
134    Value * lshr0 = CreateLShr(val, CreateSub(CreateBitCast(getInt64(64), mBitBlockType), shftAmount));
135    CreateStore(lshr0, carryNextPtr);
[5128]136
[5440]137    CreateCall(barrierFunc);
[5128]138
[5440]139    Value * lastCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(groupThreads)});
140    Value * blockCarryOut = CreateLoad(lastCarryPtr, "blockCarryOut");
[5128]141
[5440]142    Value * carryPtr = CreateGEP(carry, {getInt32(0), id});
143    Value * carryVal = CreateLoad(carryPtr, "carryVal");
144    Value * adv1 = CreateOr(adv0, carryVal);
[5128]145
146
[5440]147    Value * retVal = UndefValue::get(returnType);
148    retVal = CreateInsertValue(retVal, adv1, 0);
149    retVal = CreateInsertValue(retVal, blockCarryOut, 1);
150    CreateRet(retVal);
[5128]151
152}
153
154                                           
155                                           
156void IDISA_NVPTX20_Builder::CreateLongAddFunc(){
157  Type * const int64ty = getInt64Ty();
158  Type * const int32ty = getInt32Ty();
[5440]159  Module * const m = getModule();
[5128]160
[5440]161  Type * returnType = StructType::get(m->getContext(), {mBitBlockType, mBitBlockType});
162
163  mLongAddFunc = cast<Function>(m->getOrInsertFunction("LongAdd", returnType, int32ty, mBitBlockType, mBitBlockType, mBitBlockType, nullptr));
[5128]164  mLongAddFunc->setCallingConv(CallingConv::C);
165  Function::arg_iterator args = mLongAddFunc->arg_begin();
166
167  Value * const id = &*(args++);
168  id->setName("id");
169  Value * const valA = &*(args++);
170  valA->setName("valA");
171  Value * const valB = &*(args++);
172  valB->setName("valB");
173  Value * const blockCarry = &*(args++);
174  blockCarry->setName("blockCarry");
175
[5440]176  BasicBlock * entryBlock = BasicBlock::Create(m->getContext(), "entry", mLongAddFunc, 0);
177  BasicBlock * bubbleCalculateBlock = BasicBlock::Create(m->getContext(), "bubbleCalculate", mLongAddFunc, 0);
178  BasicBlock * bubbleSetBlock = BasicBlock::Create(m->getContext(), "bubbleSet", mLongAddFunc, 0);
[5128]179
180  SetInsertPoint(entryBlock);
181
182  Value * id64 = CreateZExt(id, int64ty);
183
184  Value * partial_sum = CreateAdd(valA, valB);
185  Value * gen = CreateAnd(valA, valB);
186  Value * prop = CreateXor(valA, valB);
187
188  Value * carryPtr = CreateGEP(carry, {getInt32(0), id});
189  Value * carryInitVal = CreateAnd(CreateOr(gen, CreateAnd(prop, CreateNot(partial_sum))), CreateBitCast(getInt64(0x8000000000000000), mBitBlockType));
190  carryInitVal = CreateLShr(carryInitVal, CreateBitCast(CreateSub(getInt64(63), id64), mBitBlockType));
191  CreateStore(carryInitVal, carryPtr);
192
193  Value * bubbleCond = CreateICmpEQ(CreateAdd(CreateBitCast(partial_sum, int64ty), getInt64(1)), getInt64(0));
194  CreateCondBr(bubbleCond, bubbleCalculateBlock, bubbleSetBlock);
195
196  SetInsertPoint(bubbleCalculateBlock);
197  Value * calcBubble = CreateBitCast(CreateShl(getInt64(1), id64), mBitBlockType);
198  CreateBr(bubbleSetBlock);
199
200  SetInsertPoint(bubbleSetBlock);
201  PHINode * bubbleInitVal = CreatePHI(mBitBlockType, 2, "bubbleInitVal");
202  bubbleInitVal->addIncoming(CreateBitCast(getInt64(0), mBitBlockType), entryBlock);
203  bubbleInitVal->addIncoming(calcBubble, bubbleCalculateBlock);
204
205  Value * bubblePtr = CreateGEP(bubble, {getInt32(0), id});
206  CreateStore(bubbleInitVal, bubblePtr);
207
208  CreateCall(barrierFunc);
209
210  Value * carryOffsetPtr = nullptr;
211  Value * carryVal = carryInitVal;
212  Value * bubbleOffsetPtr = nullptr;
213  Value * bubbleVal = bubbleInitVal;
214
215  for (int offset=groupThreads/2; offset>0; offset=offset>>1){
216    carryOffsetPtr = CreateGEP(carry, {getInt32(0), CreateXor(id, getInt32(offset))});
217    carryVal = CreateOr(carryVal, CreateLoad(carryOffsetPtr));
218    CreateStore(carryVal, carryPtr);
219    bubbleOffsetPtr = CreateGEP(bubble, {getInt32(0), CreateXor(id, getInt32(offset))});
220    bubbleVal = CreateOr(bubbleVal, CreateLoad(bubbleOffsetPtr));
221    CreateStore(bubbleVal, bubblePtr);
222    CreateCall(barrierFunc);
223  }
224
225  Value * firstCarryPtr = CreateGEP(carry, {getInt32(0), getInt32(0)});
226  Value * carryVal0 = CreateLoad(firstCarryPtr, "carry0");
227  Value * carry_mask = CreateOr(CreateShl(carryVal0, 1), blockCarry);
228  Value * firstBubblePtr = CreateGEP(bubble, {getInt32(0), getInt32(0)});
229  Value * bubble_mask = CreateLoad(firstBubblePtr, "bubble_mask");
230
231  Value * s = CreateAnd(CreateAdd(carry_mask, bubble_mask), CreateNot(bubble_mask));
232  Value * inc = CreateOr(s, CreateSub(s, carry_mask));
233  Value * rslt = CreateAdd(partial_sum, CreateAnd(CreateLShr(inc, CreateBitCast(id64, mBitBlockType)), CreateBitCast(getInt64(1), mBitBlockType)));
234
235  Value * blockCarryOut = CreateLShr(CreateOr(carryVal0, CreateAnd(bubble_mask, inc)), 63);
236
237  Value * retVal = UndefValue::get(returnType);
238  retVal = CreateInsertValue(retVal, rslt, 0);
239  retVal = CreateInsertValue(retVal, blockCarryOut, 1);
240  CreateRet(retVal);
241
242}
243
[5165]244void IDISA_NVPTX20_Builder::CreateBallotFunc(){
245    Type * const int32ty = getInt32Ty();
246    Type * const int1ty = getInt1Ty();
[5440]247    Module * const m = getModule();
248    Function * const ballotFn = cast<Function>(m->getOrInsertFunction("ballot_nvptx", int32ty, int1ty, nullptr));
[5165]249    ballotFn->setCallingConv(CallingConv::C);
250    Function::arg_iterator args = ballotFn->arg_begin();
251
252    Value * const input = &*(args++);
253    input->setName("input");
254
[5440]255    SetInsertPoint(BasicBlock::Create(m->getContext(), "entry", ballotFn, 0));
[5165]256
257    Value * conv = CreateZExt(input, int32ty);
258
[5260]259    const char * AsmStream = "{.reg .pred %p1;"
260                             "setp.ne.u32 %p1, $1, 0;"
261                             "vote.ballot.b32  $0, %p1;}";
[5165]262    FunctionType * AsmFnTy = FunctionType::get(int32ty, int32ty, false);
[5260]263    llvm::InlineAsm *IA = llvm::InlineAsm::get(AsmFnTy, AsmStream, "=r,r", true, false);
[5165]264    llvm::CallInst * result = CreateCall(IA, conv);
265    result->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::NoUnwind);
266
267    CreateRet(result);
[5128]268}
[5165]269
[5192]270LoadInst * IDISA_NVPTX20_Builder::CreateAtomicLoadAcquire(Value * ptr) {
271    return CreateLoad(ptr);
272   
[5165]273}
[5192]274StoreInst * IDISA_NVPTX20_Builder::CreateAtomicStoreRelease(Value * val, Value * ptr) {
275    return CreateStore(val, ptr);
276}
277
278   
279}
Note: See TracBrowser for help on using the repository browser.