Ignore:
Timestamp:
Nov 15, 2016, 3:32:01 PM (3 years ago)
Author:
lindanl
Message:

edtid GPU optimization: merging result in GPU.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/editd/editd.cpp

    r5213 r5215  
    142142        curMatch.dist = dist;
    143143        matchList.push_back(curMatch);
    144         std::cout << "pos: " << match_pos << ", dist:" << dist << "\n";
     144        // std::cout << "pos: " << match_pos << ", dist:" << dist << "\n";
    145145    }
    146146
     
    501501}
    502502
     503void mergeGPUCodeGen(){
     504        LLVMContext TheContext;
     505    Module * M = new Module("editd-gpu", TheContext);
     506    IDISA::IDISA_Builder * iBuilder = IDISA::GetIDISA_GPU_Builder(M);
     507    M->setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64");
     508    M->setTargetTriple("nvptx64-nvidia-cuda");
     509
     510    Type * const mBitBlockType = iBuilder->getBitBlockType();
     511    Type * const int32ty = iBuilder->getInt32Ty();
     512    Type * const voidTy = Type::getVoidTy(M->getContext());
     513    Type * const resultTy = PointerType::get(ArrayType::get(mBitBlockType, editDistance+1), 1);
     514    Type * const stridesTy = PointerType::get(int32ty, 1);
     515
     516    Function * const main = cast<Function>(M->getOrInsertFunction("mergeResult", voidTy, resultTy, stridesTy, nullptr));
     517    main->setCallingConv(CallingConv::C);
     518    Function::arg_iterator args = main->arg_begin();
     519   
     520    Value * const resultStream = &*(args++);
     521    resultStream->setName("resultStream");
     522    Value * const stridesPtr = &*(args++);
     523    stridesPtr->setName("stridesPtr");
     524
     525    BasicBlock * entryBlock = BasicBlock::Create(iBuilder->getContext(), "entryBlock", main, 0);
     526    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", main, 0);
     527    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", main, 0);
     528    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", main, 0);
     529   
     530    iBuilder->SetInsertPoint(entryBlock);
     531
     532    Function * tidFunc = M->getFunction("llvm.nvvm.read.ptx.sreg.tid.x");
     533    Value * tid = iBuilder->CreateCall(tidFunc);
     534
     535    Function * bidFunc = cast<Function>(M->getOrInsertFunction("llvm.nvvm.read.ptx.sreg.ctaid.x", int32ty, nullptr));
     536    Value * bid = iBuilder->CreateCall(bidFunc);
     537    Value * strides = iBuilder->CreateLoad(stridesPtr);
     538    Value * strideBlocks = ConstantInt::get(int32ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
     539    Value * outputBlocks = iBuilder->CreateMul(strides, strideBlocks);
     540    Value * resultStreamPtr = iBuilder->CreateGEP(resultStream, tid);
     541
     542    iBuilder->CreateBr(strideLoopCond);
     543    iBuilder->SetInsertPoint(strideLoopCond);
     544    PHINode * strideNo = iBuilder->CreatePHI(int32ty, 2, "strideNo");
     545    strideNo->addIncoming(ConstantInt::get(int32ty, 0), entryBlock);
     546    Value * notDone = iBuilder->CreateICmpULT(strideNo, strides);
     547    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
     548 
     549    iBuilder->SetInsertPoint(strideLoopBody);
     550    Value * myResultStreamPtr = iBuilder->CreateGEP(resultStreamPtr, {iBuilder->CreateMul(strideBlocks, strideNo)});
     551    Value * myResultStream = iBuilder->CreateLoad(iBuilder->CreateGEP(myResultStreamPtr, {iBuilder->getInt32(0), bid}));
     552    for (unsigned i=1; i<GROUPBLOCKS; i++){
     553        Value * nextStreamPtr = iBuilder->CreateGEP(myResultStreamPtr, {iBuilder->CreateMul(outputBlocks, iBuilder->getInt32(i)), bid});
     554        myResultStream = iBuilder->CreateOr(myResultStream, iBuilder->CreateLoad(nextStreamPtr));
     555    }   
     556    iBuilder->CreateStore(myResultStream, iBuilder->CreateGEP(myResultStreamPtr, {iBuilder->getInt32(0), bid}));
     557    strideNo->addIncoming(iBuilder->CreateAdd(strideNo, ConstantInt::get(int32ty, 1)), strideLoopBody);
     558    iBuilder->CreateBr(strideLoopCond);
     559   
     560    iBuilder->SetInsertPoint(stridesDone);
     561    iBuilder->CreateRetVoid();
     562   
     563    MDNode * Node = MDNode::get(M->getContext(),
     564                                {llvm::ValueAsMetadata::get(main),
     565                                 MDString::get(M->getContext(), "kernel"),
     566                                 ConstantAsMetadata::get(ConstantInt::get(iBuilder->getInt32Ty(), 1))});
     567    NamedMDNode *NMD = M->getOrInsertNamedMetadata("nvvm.annotations");
     568    NMD->addOperand(Node);
     569
     570    Compile2PTX(M, "merge.ll", "merge.ptx");
     571
     572}
     573
    503574editdFunctionType editdScanCPUCodeGen() {
    504575                           
     
    542613}
    543614
    544 void mergeResult(ulong * rslt){
    545     int strideSize = GROUPTHREADS * sizeof(ulong) * 8;
    546     int strides = size/strideSize + 1;
    547     int groupItems = strides * GROUPTHREADS * (editDistance + 1);
    548     for(int i=0; i<groupItems; i++){
    549         for(int j=1; j<GROUPBLOCKS; j++){
    550             rslt[i] = rslt[i] | rslt[j * groupItems + i];
    551         }
    552     }
    553 }
    554615#endif
    555616
     
    583644        editdGPUCodeGen(patterns.length()/GROUPTHREADS - 1);
    584645
     646        mergeGPUCodeGen();
     647
    585648        ulong * rslt = RunPTX(PTXFilename, chStream, size, patterns.c_str(), patterns.length(), editDistance);
    586649
    587650        editdFunctionType editd_ptr = editdScanCPUCodeGen();
    588651
    589         mergeResult(rslt);
    590652        editd(editd_ptr, (char*)rslt, size);
     653       
    591654        run_second_filter(pattern_segs, total_len, 0.15);
    592655
Note: See TracChangeset for help on using the changeset viewer.