Ignore:
Timestamp:
Jul 7, 2017, 8:58:29 AM (2 years ago)
Author:
cameron
Message:

Processing rate analysis and validation

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/kernel.cpp

    r5551 r5552  
    139139    }
    140140    const auto blockSize = idb->getBitBlockWidth();
     141    if (mStride == 0) {
     142        // Set the default kernel stride.
     143        mStride = blockSize;
     144    }
    141145    const auto requiredBlocks = codegen::SegmentSize + ((blockSize + mLookAheadPositions - 1) / blockSize);
    142146
    143147    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
    144         // Default reference stream set is the principal input stream set.
    145         if (mStreamSetInputs[i].rate.referenceStreamSet() == "") {
    146             mStreamSetInputs[i].rate.setReferenceStreamSet(mStreamSetInputs[0].name);
    147         }
    148148        if ((mStreamSetInputBuffers[i]->getBufferBlocks() != 0) && (mStreamSetInputBuffers[i]->getBufferBlocks() < requiredBlocks)) {
    149149            report_fatal_error(getName() + ": " + mStreamSetInputs[i].name + " requires buffer size " + std::to_string(requiredBlocks));
     
    157157    IntegerType * const sizeTy = idb->getSizeTy();
    158158    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
    159         // Default reference stream set is the principal input stream set for the principal output stream set.
    160         // Default reference stream set is the principal output stream set for other output stream sets.
    161         if (mStreamSetOutputs[i].rate.referenceStreamSet() == "") {
    162             if ((mStreamSetInputs.size() > 0) && (i == 0)) {
    163                 mStreamSetOutputs[i].rate.setReferenceStreamSet(mStreamSetInputs[0].name);
    164             }
    165             else {
    166                 mStreamSetOutputs[i].rate.setReferenceStreamSet(mStreamSetOutputs[0].name);
    167             }
    168 
    169         }
    170159        mScalarInputs.emplace_back(mStreamSetOutputBuffers[i]->getStreamSetHandle()->getType(), mStreamSetOutputs[i].name + BUFFER_PTR_SUFFIX);
    171160        if ((mStreamSetInputs.empty() && (i == 0)) || !mStreamSetOutputs[i].rate.isExact()) {
     
    204193   
    205194    mKernelStateType = StructType::create(idb->getContext(), mKernelFields, getName());
    206 }
     195   
     196    processingRateAnalysis();
     197}
     198   
     199   
     200void Kernel::processingRateAnalysis() {
     201   
     202    const unsigned inputSetCount = mStreamSetInputs.size();
     203    const unsigned outputSetCount = mStreamSetOutputs.size();
     204    const unsigned totalSetCount = inputSetCount + outputSetCount;
     205   
     206    mItemsPerStride.resize(totalSetCount);
     207    mIsDerived.resize(totalSetCount);
     208
     209    mItemsPerStride[0] = mStride;
     210    mIsDerived[0] = true;
     211   
     212    for (unsigned i = 0; i < inputSetCount; i++) {
     213        // Default reference stream set is the principal input stream set.
     214        auto & rate = mStreamSetInputs[i].rate;
     215        if (rate.referenceStreamSet() == "") {
     216            rate.setReferenceStreamSet(mStreamSetInputs[0].name);
     217        }
     218        Port port; unsigned ssIdx;
     219        std::tie(port, ssIdx) = getStreamPort(rate.referenceStreamSet());
     220        if ((port == Port::Output) || (ssIdx > i) || ((ssIdx == i) && (i > 0))) {
     221            report_fatal_error(getName() + ": input set " + mStreamSetInputs[i].name + ": forward or circular rate dependency");
     222        }
     223        if ((rate.isExact() || rate.isMaxRatio()) && mIsDerived[ssIdx]) {
     224            if ((mItemsPerStride[ssIdx] % rate.getRatioDenominator()) != 0) {
     225                report_fatal_error(getName() + ": " + mStreamSetInputs[i].name + " processing rate denominator does not exactly divide items per stride.");
     226            }
     227            mItemsPerStride[i] = rate.calculateRatio(mItemsPerStride[ssIdx]);
     228            mIsDerived[i] = rate.isExact();
     229        }
     230        else {
     231            mIsDerived[i] = false;
     232            mItemsPerStride[i] = mStride;
     233        }
     234    }
     235   
     236    for (unsigned i = inputSetCount; i < totalSetCount; i++) {
     237        auto & rate = mStreamSetOutputs[i-inputSetCount].rate;
     238        // Default reference stream set is the principal input stream set for the principal output stream set.
     239        // Default reference stream set is the principal output stream set for other output stream sets.
     240        if (rate.referenceStreamSet() == "") {
     241            if ((mStreamSetInputs.size() > 0) && (i == inputSetCount)) {
     242                rate.setReferenceStreamSet(mStreamSetInputs[0].name);
     243            }
     244            else {
     245                rate.setReferenceStreamSet(mStreamSetOutputs[0].name);
     246            }
     247        }
     248        Port port; unsigned ssIdx;
     249        std::tie(port, ssIdx) = getStreamPort(rate.referenceStreamSet());
     250        if (port == Port::Output) ssIdx += inputSetCount;
     251        if ((ssIdx > i) || ((ssIdx == i) && (i > 0))) {
     252            report_fatal_error(getName() + ": output set " + mStreamSetOutputs[i].name + ": forward or circular rate dependency");
     253        }
     254        if ((rate.isExact() || rate.isMaxRatio()) && mIsDerived[ssIdx]) {
     255            if ((mItemsPerStride[ssIdx] % rate.getRatioDenominator()) != 0) {
     256                report_fatal_error(getName() + ": " + mStreamSetOutputs[i-inputSetCount].name + " processing rate denominator does not exactly divide items per stride.");
     257            }
     258            mItemsPerStride[i] = rate.calculateRatio(mItemsPerStride[ssIdx]);
     259            mIsDerived[i] = rate.isExact();
     260        }
     261        else {
     262            mIsDerived[i] = false;
     263            mItemsPerStride[i] = mStride;
     264        }
     265    }
     266}
     267
    207268
    208269// Default kernel signature: generate the IR and emit as byte code.
     
    672733    const unsigned totalSetCount = inputSetCount + outputSetCount;
    673734   
    674     unsigned itemsPerStride[totalSetCount];
    675     bool isDerived[totalSetCount];
    676    
    677     if (mStride == 0) mStride = bitBlockWidth;
    678 
    679     itemsPerStride[0] = mStride;
    680     isDerived[0] = true;
    681     for (unsigned i = 1; i < inputSetCount; i++) {
    682         auto & rate = mStreamSetInputs[i].rate;
    683         std::string refSet = mStreamSetInputs[i].rate.referenceStreamSet();
    684         if (rate.isExact()) {
    685             Port port; unsigned ssIdx;
    686             std::tie(port, ssIdx) = getStreamPort(refSet);
    687             assert (port == Port::Input && ssIdx < i);
    688             if ((ssIdx == 0) || isDerived[ssIdx]) {
    689                 itemsPerStride[i] = rate.calculateRatio(itemsPerStride[ssIdx]);
    690                 isDerived[i] = true;
    691                 continue;
    692             }
    693         }
    694         isDerived[i] = false;
    695     }
    696     for (unsigned i = inputSetCount; i < totalSetCount; i++) {
    697         auto & rate = mStreamSetOutputs[i-inputSetCount].rate;
    698         std::string refSet = rate.referenceStreamSet();
    699         if (rate.isExact() || rate.isMaxRatio()) {
    700             Port port; unsigned ssIdx;
    701             std::tie(port, ssIdx) = getStreamPort(refSet);
    702             if (port == Port::Output) ssIdx += inputSetCount;
    703             if ((ssIdx == 0) || isDerived[ssIdx]) {
    704                 itemsPerStride[i] = rate.calculateRatio(itemsPerStride[ssIdx]);
    705                 isDerived[i] = rate.isExact();
    706                 continue;
    707             }
    708         }
    709         isDerived[i] = false;
    710     }
    711735    int maxBlocksToCopy[totalSetCount];
    712736    for (unsigned i = 0; i < totalSetCount; i++) {
    713         if (isDerived[i]) {
    714             if (itemsPerStride[i] % bitBlockWidth == 0) {
    715                 maxBlocksToCopy[i] = itemsPerStride[i] / bitBlockWidth;
     737        if (mIsDerived[i]) {
     738            if (mItemsPerStride[i] % bitBlockWidth == 0) {
     739                maxBlocksToCopy[i] = mItemsPerStride[i] / bitBlockWidth;
    716740            }
    717741            else {
    718742                // May not be block aligned, can overlap partial blocks at both ends.
    719                 maxBlocksToCopy[i] = itemsPerStride[i]/bitBlockWidth + 2;
     743                maxBlocksToCopy[i] = mItemsPerStride[i]/bitBlockWidth + 2;
    720744            }
    721745        }
     
    739763    multiBlockParmTypes.push_back(kb->getSizeTy());
    740764    for (unsigned i = 1; i < mStreamSetInputs.size(); i++) {
    741         if (!isDerived[i]) multiBlockParmTypes.push_back(kb->getSizeTy());
     765        if (!mIsDerived[i]) multiBlockParmTypes.push_back(kb->getSizeTy());
    742766    }
    743767    for (auto buffer : mStreamSetInputBuffers) {
     
    760784    (++args)->setName("itemsToDo");
    761785    for (unsigned i = 1; i < mStreamSetInputs.size(); i++) {
    762         if (!isDerived[i]) (++args)->setName(mStreamSetInputs[i].name + "_availItems");
     786        if (!mIsDerived[i]) (++args)->setName(mStreamSetInputs[i].name + "_availItems");
    763787    }
    764788    for (auto binding : mStreamSetInputs) {
     
    816840        Value * a = mAvailableItemCount[i];
    817841        auto & rate = mStreamSetInputs[i].rate;
    818         if (isDerived[i]) {
     842        if (mIsDerived[i]) {
    819843            Value * maxItems = rate.CreateMaxReferenceItemsCalculation(kb.get(), a);
    820844            itemsAvail = kb->CreateSelect(kb->CreateICmpULT(itemsAvail, maxItems), itemsAvail, maxItems);
     
    858882        processedItemCount.push_back(p);
    859883        inputBlockPtr.push_back(b);
    860         if (isDerived[i]) {
     884        if (mIsDerived[i]) {
    861885            auto & rate = mStreamSetInputs[i].rate;
    862886            Value * maxReferenceItems = nullptr;
     
    881905        producedItemCount.push_back(p);
    882906        outputBlockPtr.push_back(b);
    883         if (isDerived[inputSetCount + i]) {
     907        if (mIsDerived[inputSetCount + i]) {
    884908            auto & rate = mStreamSetOutputs[i].rate;
    885909            Value * maxReferenceItems = nullptr;
     
    906930    doMultiBlockArgs.push_back(getInstance());
    907931    doMultiBlockArgs.push_back(linearlyAvailItems);
    908     for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
    909         if (!isDerived[i]) {
     932    for (unsigned i = 1; i < mStreamSetInputs.size(); i++) {
     933        if (!mIsDerived[i]) {
    910934            Value * avail = kb->CreateSub(mAvailableItemCount[i], processedItemCount[i]);
    911935            Value * linearlyAvail = kb->getLinearlyAccessibleItems(mStreamSetInputs[i].name, processedItemCount[i]);
     
    9921016    tempArgs.push_back(tempBlockItems);
    9931017    // For non-derived inputs, add the available items.
    994     for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
    995         if (!isDerived[i]) {
     1018    for (unsigned i = 1; i < mStreamSetInputs.size(); i++) {
     1019        if (!mIsDerived[i]) {
    9961020            Value * avail = kb->CreateSub(mAvailableItemCount[i], processedItemCount[i]);
    9971021            tempArgs.push_back(kb->CreateSelect(kb->CreateICmpULT(avail, strideSize), avail, strideSize));
     
    10101034    for (unsigned i = 0; i < mStreamSetInputBuffers.size(); i++) {
    10111035        Type * bufPtrType = mStreamSetInputBuffers[i]->getPointerType();
    1012         if (isDerived[i]) {
     1036        if ((i == 0) || mIsDerived[i]) {
    10131037            Value * tempBufPtr = kb->CreateGEP(tempParameterArea, {kb->getInt32(0), kb->getInt32(i)});
    10141038            tempBufPtr = kb->CreatePointerCast(tempBufPtr, bufPtrType);
    1015             ConstantInt * strideItems = kb->getSize(itemsPerStride[i]);
     1039            ConstantInt * strideItems = kb->getSize(mItemsPerStride[i]);
    10161040            Value * strideBasePos = kb->CreateSub(processedItemCount[i], kb->CreateURem(processedItemCount[i], strideItems));
    1017             Value * blockBasePos = (itemsPerStride[i] % bitBlockWidth == 0) ? strideBasePos : kb->CreateAnd(strideBasePos, blockBaseMask);
     1041            Value * blockBasePos = (mItemsPerStride[i] % bitBlockWidth == 0) ? strideBasePos : kb->CreateAnd(strideBasePos, blockBaseMask);
    10181042
    10191043            // The number of items to copy is determined by the processing rate requirements.
     
    11891213                                   std::vector<Binding> && scalar_outputs,
    11901214                                   std::vector<Binding> && internal_scalars)
    1191 : Kernel(std::move(kernelName), std::move(stream_inputs), std::move(stream_outputs), std::move(scalar_parameters), std::move(scalar_outputs), std::move(internal_scalars))
    1192 , mStride(0) {
     1215: Kernel(std::move(kernelName), std::move(stream_inputs), std::move(stream_outputs), std::move(scalar_parameters), std::move(scalar_outputs), std::move(internal_scalars)) {
    11931216}
    11941217
Note: See TracChangeset for help on using the changeset viewer.