source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 5445

Last change on this file since 5445 was 5440, checked in by nmedfort, 2 years ago

Large refactoring step. Removed IR generation code from Kernel (formerly KernelBuilder) and moved it into the new KernelBuilder class.

File size: 13.2 KB
Line 
1#include "kernel_builder.h"
2#include <kernels/kernel.h>
3#include <kernels/streamset.h>
4
5using namespace llvm;
6using namespace parabix;
7
8using Value = Value;
9
10namespace kernel {
11
12Value * KernelBuilder::getScalarFieldPtr(llvm::Value * instance, Value * const index) {
13    return CreateGEP(instance, {getInt32(0), index});
14}
15
16Value * KernelBuilder::getScalarFieldPtr(llvm::Value * instance, const std::string & fieldName) {
17    return getScalarFieldPtr(instance, getInt32(mKernel->getScalarIndex(fieldName)));
18}
19
20llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * index) {
21    return getScalarFieldPtr(mKernel->getInstance(), index);
22}
23
24llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
25    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
26}
27
28Value * KernelBuilder::getScalarField(const std::string & fieldName) {
29    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
30}
31
32void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
33    CreateStore(value, getScalarFieldPtr(fieldName));
34}
35
36Value * KernelBuilder::getStreamSetBufferPtr(const std::string & name) {
37    return getScalarField(name + Kernel::BUFFER_PTR_SUFFIX);
38}
39
40LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
41    return CreateAtomicLoadAcquire(getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
42}
43
44void KernelBuilder::releaseLogicalSegmentNo(Value * nextSegNo) {
45    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
46}
47
48Value * KernelBuilder::getProducedItemCount(const std::string & name, Value * doFinal) {
49    Kernel::Port port; unsigned index;
50    std::tie(port, index) = mKernel->getStreamPort(name);
51    assert (port == Kernel::Port::Output);
52    const auto rate = mKernel->getStreamOutput(index).rate;
53    if (rate.isExact()) {
54        const auto & refSet = rate.referenceStreamSet();
55        std::string principalField;
56        if (refSet.empty()) {
57            if (mKernel->getStreamInputs().empty()) {
58                principalField = mKernel->getStreamOutput(0).name + Kernel::PRODUCED_ITEM_COUNT_SUFFIX;
59            } else {
60                principalField = mKernel->getStreamInput(0).name + Kernel::PROCESSED_ITEM_COUNT_SUFFIX;
61            }
62        } else {
63            std::tie(port, index) = mKernel->getStreamPort(refSet);
64            if (port == Kernel::Port::Input) {
65               principalField = refSet + Kernel::PROCESSED_ITEM_COUNT_SUFFIX;
66            } else {
67               principalField = refSet + Kernel::PRODUCED_ITEM_COUNT_SUFFIX;
68            }
69        }
70        Value * const principleCount = getScalarField(principalField);
71        return rate.CreateRatioCalculation(this, principleCount, doFinal);
72    }
73    return getScalarField(name + Kernel::PRODUCED_ITEM_COUNT_SUFFIX);
74}
75
76Value * KernelBuilder::getProcessedItemCount(const std::string & name) {
77    Kernel::Port port; unsigned index;
78    std::tie(port, index) = mKernel->getStreamPort(name);
79    assert (port == Kernel::Port::Input);
80    const auto & rate = mKernel->getStreamInput(index).rate;
81    if (rate.isExact()) {
82        std::string refSet = rate.referenceStreamSet();
83        if (refSet.empty()) {
84            refSet = mKernel->getStreamInput(0).name;
85        }
86        Value * const principleCount = getScalarField(refSet + Kernel::PROCESSED_ITEM_COUNT_SUFFIX);
87        return rate.CreateRatioCalculation(this, principleCount);
88    }
89    return getScalarField(name + Kernel::PROCESSED_ITEM_COUNT_SUFFIX);
90}
91
92Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
93    const auto & inputs = mKernel->getStreamInputs();
94    for (unsigned i = 0; i < inputs.size(); ++i) {
95        if (inputs[i].name == name) {
96            return mKernel->getAvailableItemCount(i);
97        }
98    }
99    return nullptr;
100}
101
102Value * KernelBuilder::getConsumedItemCount(const std::string & name) {
103    return getScalarField(name + Kernel::CONSUMED_ITEM_COUNT_SUFFIX);
104}
105
106void KernelBuilder::setProducedItemCount(const std::string & name, Value * value) {
107    setScalarField(name + Kernel::PRODUCED_ITEM_COUNT_SUFFIX, value);
108}
109
110void KernelBuilder::setProcessedItemCount(const std::string & name, Value * value) {
111    setScalarField(name + Kernel::PROCESSED_ITEM_COUNT_SUFFIX, value);
112}
113
114void KernelBuilder::setConsumedItemCount(const std::string & name, Value * value) {
115    setScalarField(name + Kernel::CONSUMED_ITEM_COUNT_SUFFIX, value);
116}
117
118Value * KernelBuilder::getTerminationSignal() {
119    return getScalarField(Kernel::TERMINATION_SIGNAL);
120}
121
122void KernelBuilder::setTerminationSignal() {
123    setScalarField(Kernel::TERMINATION_SIGNAL, getTrue());
124}
125
126Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition) {
127    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
128    return buf->getLinearlyAccessibleItems(this, fromPosition);
129}
130
131Value * KernelBuilder::getConsumerLock(const std::string & name) {
132    return getScalarField(name + Kernel::CONSUMER_SUFFIX);
133}
134
135void KernelBuilder::setConsumerLock(const std::string & name, Value * value) {
136    setScalarField(name + Kernel::CONSUMER_SUFFIX, value);
137}
138
139inline Value * KernelBuilder::computeBlockIndex(Value * itemCount) {
140    const auto divisor = getBitBlockWidth();
141    if (LLVM_LIKELY((divisor & (divisor - 1)) == 0)) {
142        return CreateLShr(itemCount, std::log2(divisor));
143    } else {
144        return CreateUDiv(itemCount, getSize(divisor));
145    }
146}
147
148Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * streamIndex) {
149    Value * const blockIndex = computeBlockIndex(getProcessedItemCount(name));
150    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
151    return buf->getStreamBlockPtr(this, getStreamSetBufferPtr(name), streamIndex, blockIndex, true);
152}
153
154Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * streamIndex) {
155    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex));
156}
157
158Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
159    Value * const blockIndex = computeBlockIndex(getProcessedItemCount(name));
160    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
161    return buf->getStreamPackPtr(this, getStreamSetBufferPtr(name), streamIndex, blockIndex, packIndex, true);
162}
163
164Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex) {
165    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex));
166}
167
168Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
169    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
170    return buf->getStreamSetCount(this, getStreamSetBufferPtr(name));
171}
172
173Value * KernelBuilder::getAdjustedInputStreamBlockPtr(Value * blockAdjustment, const std::string & name, Value * streamIndex) {
174    Value * const blockIndex = CreateAdd(computeBlockIndex(getProcessedItemCount(name)), blockAdjustment);
175    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
176    return buf->getStreamBlockPtr(this, getStreamSetBufferPtr(name), streamIndex, blockIndex, true);
177}
178
179Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex) {
180    Value * const blockIndex = computeBlockIndex(getProducedItemCount(name));
181    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
182    return buf->getStreamBlockPtr(this, getStreamSetBufferPtr(name), streamIndex, blockIndex, false);
183}
184
185void KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * toStore) {
186    return CreateBlockAlignedStore(toStore, getOutputStreamBlockPtr(name, streamIndex));
187}
188
189Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
190    Value * const blockIndex = computeBlockIndex(getProducedItemCount(name));
191    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
192    return buf->getStreamPackPtr(this, getStreamSetBufferPtr(name), streamIndex, blockIndex, packIndex, false);
193}
194
195void KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * toStore) {
196    return CreateBlockAlignedStore(toStore, getOutputStreamPackPtr(name, streamIndex, packIndex));
197}
198
199Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
200    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
201    return buf->getStreamSetCount(this, getStreamSetBufferPtr(name));
202}
203
204Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * streamIndex, Value * absolutePosition) {
205    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
206    return buf->getRawItemPointer(this, getStreamSetBufferPtr(name), streamIndex, absolutePosition);
207}
208
209Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * streamIndex, Value * absolutePosition) {
210    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
211    return buf->getRawItemPointer(this, getStreamSetBufferPtr(name), streamIndex, absolutePosition);
212}
213
214Value * KernelBuilder::getBaseAddress(const std::string & name) {
215    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamSetBufferPtr(name));
216}
217
218void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
219    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamSetBufferPtr(name), addr);
220}
221
222Value * KernelBuilder::getBufferedSize(const std::string & name) {
223    return mKernel->getAnyStreamSetBuffer(name)->getBufferedSize(this, getStreamSetBufferPtr(name));
224}
225
226void KernelBuilder::setBufferedSize(const std::string & name, Value * size) {
227    mKernel->getAnyStreamSetBuffer(name)->setBufferedSize(this, getStreamSetBufferPtr(name), size);
228}
229
230
231CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
232    Function * const doSegment = mKernel->getDoSegmentFunction(getModule());
233    assert (doSegment->getArgumentList().size() == args.size());
234    return CreateCall(doSegment, args);
235}
236
237Value * KernelBuilder::getAccumulator(const std::string & accumName) {
238    auto results = mKernel->mOutputScalarResult;
239    if (LLVM_UNLIKELY(results == nullptr)) {
240        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
241    }
242    const auto & outputs = mKernel->getScalarOutputs();
243    const auto n = outputs.size();
244    if (LLVM_UNLIKELY(n == 0)) {
245        report_fatal_error(mKernel->getName() + " has no output scalars.");
246    } else {
247        for (unsigned i = 0; i < n; ++i) {
248            const Binding & b = outputs[i];
249            if (b.name == accumName) {
250                if (n == 1) {
251                    return results;
252                } else {
253                    return CreateExtractValue(results, {i});
254                }
255            }
256        }
257        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
258    }
259}
260
261BasicBlock * KernelBuilder::CreateConsumerWait() {
262    const auto consumers = mKernel->getStreamOutputs();
263    BasicBlock * const entry = GetInsertBlock();
264    if (consumers.empty()) {
265        return entry;
266    } else {
267        Function * const parent = entry->getParent();
268        IntegerType * const sizeTy = getSizeTy();
269        ConstantInt * const zero = getInt32(0);
270        ConstantInt * const one = getInt32(1);
271        ConstantInt * const size0 = getSize(0);
272
273        Value * const segNo = acquireLogicalSegmentNo();
274        const auto n = consumers.size();
275        BasicBlock * load[n + 1];
276        BasicBlock * wait[n];
277        for (unsigned i = 0; i < n; ++i) {
278            load[i] = BasicBlock::Create(getContext(), consumers[i].name + "Load", parent);
279            wait[i] = BasicBlock::Create(getContext(), consumers[i].name + "Wait", parent);
280        }
281        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
282        CreateBr(load[0]);
283        for (unsigned i = 0; i < n; ++i) {
284
285            SetInsertPoint(load[i]);
286            Value * const outputConsumers = getConsumerLock(consumers[i].name);
287
288            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
289            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
290            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
291            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
292
293            SetInsertPoint(wait[i]);
294            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
295            consumerPhi->addIncoming(size0, load[i]);
296
297            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
298            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
299            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
300            assert (ready->getType() == getInt1Ty());
301            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
302            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
303            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
304            CreateCondBr(next, load[i + 1], wait[i]);
305        }
306
307        BasicBlock * const exit = load[n];
308        SetInsertPoint(exit);
309        return exit;
310    }
311}
312
313}
Note: See TracBrowser for help on using the repository browser.