Ignore:
Timestamp:
Dec 16, 2016, 4:16:28 PM (3 years ago)
Author:
nmedfort
Message:

Rewrite of the CarryManager? to support non-carry-collapsing loops.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp

    r5217 r5227  
    2323namespace pablo {
    2424
    25 PabloCompiler::PabloCompiler(PabloKernel * kernel)
    26 : iBuilder(kernel->getBuilder())
    27 , mBitBlockType(iBuilder->getBitBlockType())
    28 , mCarryManager(nullptr)
    29 , mKernel(kernel)
    30 , mWhileDepth(0)
    31 , mIfDepth(0)
    32 , mFunction(nullptr)
    33 , mMaxWhileDepth(0) {
     25using TypeId = PabloAST::ClassTypeId;
     26
     27void PabloCompiler::initializeKernelData() {
     28    Examine();
     29    mCarryManager->initializeCarryData(mKernel);
     30}
    3431   
    35 }
    36 
    37 Type * PabloCompiler::initializeKernelData() {
    38     Examine();
    39     mCarryManager = std::unique_ptr<CarryManager>(new CarryManager(iBuilder));
    40     return mCarryManager->initializeCarryData(mKernel);
    41 }
    42    
    43 void PabloCompiler::compile(Value * const self, Function * doBlockFunction) {
     32void PabloCompiler::compile(Value * const self, Function * function) {
    4433
    4534    // Make sure that we generate code into the right module.
    46     mFunction = doBlockFunction;
    4735    mSelf = self;
    48 
    49     #ifdef PRINT_TIMING_INFORMATION
    50     const timestamp_t pablo_compilation_start = read_cycle_counter();
    51     #endif
     36    mFunction = function;
    5237
    5338    //Generate Kernel//
    54     iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doBlockFunction, 0));
    55 
    56     mCarryManager->initializeCodeGen(mKernel, mSelf);
     39    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", function, 0));
     40
     41    mCarryManager->initializeCodeGen(self, function);
    5742     
    5843    PabloBlock * const entryBlock = mKernel->getEntryBlock(); assert (entryBlock);
     
    9580
    9681inline void PabloCompiler::Examine() {
    97     mWhileDepth = 0;
    98     mIfDepth = 0;
    99     mMaxWhileDepth = 0;
    10082    Examine(mKernel->getEntryBlock());
    10183}
     
    10991                mKernel->setLookAhead(la->getAmount());
    11092            }
    111         } else {
    112             if (LLVM_UNLIKELY(isa<If>(stmt))) {
    113                 Examine(cast<If>(stmt)->getBody());
    114             } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
    115                 mMaxWhileDepth = std::max(mMaxWhileDepth, ++mWhileDepth);
    116                 Examine(cast<While>(stmt)->getBody());
    117                 --mWhileDepth;
    118             }
     93        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
     94            Examine(cast<Branch>(stmt)->getBody());
    11995        }
    12096    }   
     
    155131    //
    156132
    157     Module * const mod = iBuilder->getModule();
    158133    BasicBlock * const ifEntryBlock = iBuilder->GetInsertBlock();
    159     BasicBlock * const ifBodyBlock = BasicBlock::Create(mod->getContext(), "if.body", mFunction, 0);
    160     BasicBlock * const ifEndBlock = BasicBlock::Create(mod->getContext(), "if.end", mFunction, 0);
     134    BasicBlock * const ifBodyBlock = BasicBlock::Create(mFunction->getContext(), "if.body", mFunction);
     135    BasicBlock * const ifEndBlock = BasicBlock::Create(mFunction->getContext(), "if.end", mFunction);
    161136   
    162137    std::vector<std::pair<const Var *, Value *>> incoming;
     
    177152    PabloBlock * ifBody = ifStatement->getBody();
    178153   
    179     Value * const condition = compileExpression(ifStatement->getCondition());
     154    mCarryManager->enterIfScope(ifBody);
     155
     156    Value * condition = compileExpression(ifStatement->getCondition());
     157    if (condition->getType() == iBuilder->getBitBlockType()) {
     158        condition = iBuilder->bitblock_any(mCarryManager->generateSummaryTest(condition));
     159    }
    180160   
    181     mCarryManager->enterScope(ifBody);
    182     iBuilder->CreateCondBr(mCarryManager->generateSummaryTest(condition), ifBodyBlock, ifEndBlock);
     161    iBuilder->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
    183162   
    184163    // Entry processing is complete, now handle the body of the if.
    185164    iBuilder->SetInsertPoint(ifBodyBlock);
    186    
     165
     166    mCarryManager->enterIfBody(ifEntryBlock);
     167
    187168    compileBlock(ifBody);
    188169
    189     BasicBlock * ifExitBlock = iBuilder->GetInsertBlock();
    190 
    191     if (mCarryManager->hasCarries()) {
    192         mCarryManager->storeCarryOutSummary();
    193     }
    194     mCarryManager->addOuterSummaryToNestedSummary();
     170    BasicBlock * ifExitBlock = iBuilder->GetInsertBlock();   
     171
     172    mCarryManager->leaveIfBody(ifExitBlock);
    195173
    196174    iBuilder->CreateBr(ifEndBlock);
    197175    //End Block
    198176    iBuilder->SetInsertPoint(ifEndBlock);
     177
     178    mCarryManager->leaveIfScope(ifEntryBlock, ifExitBlock);
     179
    199180    for (const auto i : incoming) {
    200181        const Var * var; Value * value;
     
    210191        }
    211192
    212         PHINode * phi = iBuilder->CreatePHI(mBitBlockType, 2, getName(var));
     193        Value * const next = f->second;
     194
     195        assert (value->getType() == next->getType());
     196
     197        PHINode * phi = iBuilder->CreatePHI(value->getType(), 2, getName(var));
    213198        phi->addIncoming(value, ifEntryBlock);
    214         phi->addIncoming(f->second, ifExitBlock);
     199        phi->addIncoming(next, ifExitBlock);
    215200        f->second = phi;
    216201
    217202        assert (mMarkerMap[var] == phi);
    218     }
    219     // Create the phi Node for the summary variable, if needed.
    220     mCarryManager->buildCarryDataPhisAfterIfBody(ifEntryBlock, ifExitBlock);
    221     mCarryManager->leaveScope();
     203    }   
    222204}
    223205
     
    228210    BasicBlock * whileEntryBlock = iBuilder->GetInsertBlock();
    229211
    230     Module * const mod = iBuilder->getModule();
    231     BasicBlock * whileBodyBlock = BasicBlock::Create(mod->getContext(), "while.body", mFunction, 0);
    232     BasicBlock * whileEndBlock = BasicBlock::Create(mod->getContext(), "while.end", mFunction, 0);
    233 
    234     mCarryManager->enterScope(whileBody);
    235     mCarryManager->ensureCarriesLoadedRecursive();
     212    BasicBlock * whileBodyBlock = BasicBlock::Create(iBuilder->getContext(), "while.body", mFunction);
     213
     214    BasicBlock * whileEndBlock = BasicBlock::Create(iBuilder->getContext(), "while.end", mFunction);
     215
     216    const auto escaped = whileStatement->getEscaped();
    236217
    237218#ifdef ENABLE_BOUNDED_WHILE
     
    242223    // the loop.
    243224
     225    mCarryManager->enterLoopScope(whileBody);
     226
    244227    iBuilder->CreateBr(whileBodyBlock);
     228
    245229    iBuilder->SetInsertPoint(whileBodyBlock);
    246230
     
    254238#endif
    255239
    256     mCarryManager->initializeWhileEntryCarryDataPhis(whileEntryBlock);
    257 
    258240    std::vector<std::pair<const Var *, PHINode *>> variants;
    259241
    260242    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
    261     for (const auto var : whileStatement->getEscaped()) {       
     243    for (const auto var : escaped) {
    262244        auto f = mMarkerMap.find(var);
    263245        if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
     
    269251            llvm::report_fatal_error(out.str());
    270252        }
    271 
    272         PHINode * phi = iBuilder->CreatePHI(mBitBlockType, 2, getName(var));
    273         phi->addIncoming(f->second, whileEntryBlock);
     253        Value * entryValue = f->second;
     254        PHINode * phi = iBuilder->CreatePHI(entryValue->getType(), 2, getName(var));
     255        phi->addIncoming(entryValue, whileEntryBlock);
    274256        f->second = phi;
    275257        assert(mMarkerMap[var] == phi);
     
    282264    }
    283265#endif
     266
     267    mCarryManager->enterLoopBody(whileEntryBlock);
     268
    284269    //
    285270    // Now compile the loop body proper.  Carry-out accumulated values
    286271    // and iterated values of Next nodes will be computed.
    287     ++mWhileDepth;
    288272    compileBlock(whileBody);
    289273
     
    291275    BasicBlock * whileExitBlock = iBuilder->GetInsertBlock();
    292276
    293     if (mCarryManager->hasCarries()) {
    294         mCarryManager->storeCarryOutSummary();
    295     }
    296     mCarryManager->finalizeWhileBlockCarryDataPhis(whileExitBlock);
     277    mCarryManager->leaveLoopBody(whileExitBlock);
    297278
    298279    // Terminate the while loop body with a conditional branch back.
    299     Value * cond_expr = iBuilder->bitblock_any(compileExpression(whileStatement->getCondition()));
     280    Value * condition = compileExpression(whileStatement->getCondition());
     281    if (condition->getType() == iBuilder->getBitBlockType()) {
     282        condition = iBuilder->bitblock_any(condition);
     283    }
    300284#ifdef ENABLE_BOUNDED_WHILE
    301285    if (whileStatement->getBound()) {
    302286        Value * new_bound = iBuilder->CreateSub(bound_phi, ConstantInt::get(iBuilder->getSizeTy(), 1));
    303287        bound_phi->addIncoming(new_bound, whileExitBlock);
    304         cond_expr = iBuilder->CreateAnd(cond_expr, iBuilder->CreateICmpUGT(new_bound, ConstantInt::getNullValue(iBuilder->getSizeTy())));
    305     }
    306 #endif   
    307     iBuilder->CreateCondBr(cond_expr, whileBodyBlock, whileEndBlock);
     288        condition = iBuilder->CreateAnd(condition, iBuilder->CreateICmpUGT(new_bound, ConstantInt::getNullValue(iBuilder->getSizeTy())));
     289    }
     290#endif
    308291
    309292    // and for any variant nodes in the loop body
     
    319302            llvm::report_fatal_error(out.str());
    320303        }
    321         phi->addIncoming(f->second, whileExitBlock);
     304        Value * exitValue = f->second;
     305        assert (phi->getType() == exitValue->getType());
     306        phi->addIncoming(exitValue, whileExitBlock);
    322307        f->second = phi;
    323308    }
    324309
     310    iBuilder->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
     311
    325312    iBuilder->SetInsertPoint(whileEndBlock);
    326     --mWhileDepth;
    327 
    328     mCarryManager->ensureCarriesStoredRecursive();
    329     mCarryManager->leaveScope();
     313
     314    mCarryManager->leaveLoopScope(whileEntryBlock, whileExitBlock);
     315
    330316}
    331317
     
    359345                    std::string tmp;
    360346                    raw_string_ostream out(tmp);
    361                     out << "Use-before-definition error: ";
     347                    out << "PabloCompiler: use-before-definition error: ";
    362348                    expr->print(out);
    363349                    out << " does not dominate ";
    364350                    stmt->print(out);
    365                     throw std::runtime_error(out.str());
     351                    llvm::report_fatal_error(out.str());
    366352                }
    367353                Value * const ptr = f->second;
     
    374360                    value = iBuilder->CreateAdd(value, count);
    375361                }
    376 
    377 //                cast<PointerType>(ptr->getType())->getElementType()->getPrimitiveSizeInBits() / 8;
    378362
    379363                const Type * const type = value->getType();
     
    392376            Value * index = compileExpression(extract->getIndex());
    393377            value = iBuilder->CreateGEP(array, {ConstantInt::getNullValue(index->getType()), index}, getName(stmt));
    394         } else if (const And * pablo_and = dyn_cast<And>(stmt)) {
    395             value = iBuilder->simd_and(compileExpression(pablo_and->getOperand(0)), compileExpression(pablo_and->getOperand(1)));
    396         } else if (const Or * pablo_or = dyn_cast<Or>(stmt)) {
    397             value = iBuilder->simd_or(compileExpression(pablo_or->getOperand(0)), compileExpression(pablo_or->getOperand(1)));
    398         } else if (const Xor * pablo_xor = dyn_cast<Xor>(stmt)) {
    399             value = iBuilder->simd_xor(compileExpression(pablo_xor->getOperand(0)), compileExpression(pablo_xor->getOperand(1)));
     378        } else if (isa<And>(stmt)) {
     379            value = compileExpression(stmt->getOperand(0));
     380            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
     381                value = iBuilder->simd_and(value, compileExpression(stmt->getOperand(1)));
     382            }
     383        } else if (isa<Or>(stmt)) {
     384            value = compileExpression(stmt->getOperand(0));
     385            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
     386                value = iBuilder->simd_or(value, compileExpression(stmt->getOperand(1)));
     387            }
     388        } else if (isa<Xor>(stmt)) {
     389            value = compileExpression(stmt->getOperand(0));
     390            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
     391                value = iBuilder->simd_xor(value, compileExpression(stmt->getOperand(1)));
     392            }
    400393        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
    401394            Value* ifMask = compileExpression(sel->getCondition());
     
    403396            Value* ifFalse = iBuilder->simd_and(iBuilder->simd_not(ifMask), compileExpression(sel->getFalseExpr()));
    404397            value = iBuilder->simd_or(ifTrue, ifFalse);
    405         } else if (const Not * pablo_not = dyn_cast<Not>(stmt)) {
    406             value = iBuilder->simd_not(compileExpression(pablo_not->getExpr()));
    407         } else if (const Advance * adv = dyn_cast<Advance>(stmt)) {
    408             Value * const strm_value = compileExpression(adv->getExpr());
    409             value = mCarryManager->advanceCarryInCarryOut(adv->getLocalIndex(), adv->getAmount(), strm_value);
     398        } else if (isa<Not>(stmt)) {
     399            value = iBuilder->simd_not(compileExpression(stmt->getOperand(0)));
     400        } else if (isa<Advance>(stmt)) {
     401            const Advance * const adv = cast<Advance>(stmt);
     402            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
     403            // manager so that it properly selects the correct carry bit.
     404            value = mCarryManager->advanceCarryInCarryOut(adv, compileExpression(adv->getExpression()));
    410405        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
    411406            Value * const marker = compileExpression(mstar->getMarker());
    412407            Value * const cc = compileExpression(mstar->getCharClass());
    413408            Value * const marker_and_cc = iBuilder->simd_and(marker, cc);
    414             Value * const sum = mCarryManager->addCarryInCarryOut(mstar->getLocalCarryIndex(), marker_and_cc, cc);
     409            Value * const sum = mCarryManager->addCarryInCarryOut(mstar, marker_and_cc, cc);
    415410            value = iBuilder->simd_or(iBuilder->simd_xor(sum, cc), marker);
    416411        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
    417             Value * const  marker_expr = compileExpression(sthru->getScanFrom());
    418             Value * const  cc_expr = compileExpression(sthru->getScanThru());
    419             Value * const  sum = mCarryManager->addCarryInCarryOut(sthru->getLocalCarryIndex(), marker_expr, cc_expr);
     412            Value * const marker_expr = compileExpression(sthru->getScanFrom());
     413            Value * const cc_expr = compileExpression(sthru->getScanThru());
     414            Value * const sum = mCarryManager->addCarryInCarryOut(sthru, marker_expr, cc_expr);
    420415            value = iBuilder->simd_and(sum, iBuilder->simd_not(cc_expr));
    421416        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
     
    473468                    Value * b0 = iBuilder->CreateBitCast(lookAhead, streamType);
    474469                    Value * result = iBuilder->CreateOr(iBuilder->CreateShl(b1, iBuilder->getBitBlockWidth() - bit_shift), iBuilder->CreateLShr(b0, bit_shift));
    475                     value = iBuilder->CreateBitCast(result, mBitBlockType);
     470                    value = iBuilder->CreateBitCast(result, iBuilder->getBitBlockType());
    476471                }
    477472            }
     
    501496    } else if (LLVM_UNLIKELY(isa<Integer>(expr))) {
    502497        return iBuilder->getInt64(cast<Integer>(expr)->value());
     498    } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
     499        const Operator * op = cast<Operator>(expr);
     500        Value * lh = compileExpression(op->getLH());
     501        Value * rh = compileExpression(op->getRH());
     502        assert (lh->getType() == rh->getType());
     503        switch (op->getClassTypeId()) {
     504            case TypeId::Add:
     505                return iBuilder->CreateAdd(lh, rh);
     506            case TypeId::Subtract:
     507                return iBuilder->CreateSub(lh, rh);
     508            case TypeId::LessThan:
     509                return iBuilder->CreateICmpSLT(lh, rh);
     510            case TypeId::LessThanEquals:
     511                return iBuilder->CreateICmpSLE(lh, rh);
     512            case TypeId::Equals:
     513                return iBuilder->CreateICmpEQ(lh, rh);
     514            case TypeId::GreaterThanEquals:
     515                return iBuilder->CreateICmpSGE(lh, rh);
     516            case TypeId::GreaterThan:
     517                return iBuilder->CreateICmpSGT(lh, rh);
     518            case TypeId::NotEquals:
     519                return iBuilder->CreateICmpNE(lh, rh);
     520            default:
     521                break;
     522        }
     523        std::string tmp;
     524        raw_string_ostream out(tmp);
     525        expr->print(out);
     526        out << " is not a valid Operator";
     527        llvm::report_fatal_error(out.str());
    503528    }
    504529    const auto f = mMarkerMap.find(expr);
     
    517542}
    518543
    519 }
     544PabloCompiler::PabloCompiler(PabloKernel * kernel)
     545: iBuilder(kernel->getBuilder())
     546, mCarryManager(new CarryManager(iBuilder))
     547, mKernel(kernel)
     548, mFunction(nullptr) {
     549
     550}
     551
     552PabloCompiler::~PabloCompiler() {
     553    delete mCarryManager;
     554}
     555
     556}
Note: See TracChangeset for help on using the changeset viewer.