source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp

Last change on this file was 6252, checked in by nmedfort, 7 months ago

Bug fix for consumer information + slight simplification of copyback space calculation

File size: 39.4 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#include <pablo/ps_terminate.h>
29#ifdef USE_CARRYPACK_MANAGER
30#include <pablo/carrypack_manager.h>
31#else
32#include <pablo/carry_manager.h>
33#endif
34#include <kernels/kernel_builder.h>
35#include <kernels/streamset.h>
36#include <llvm/IR/Module.h>
37#include <llvm/IR/Type.h>
38#include <llvm/Support/raw_os_ostream.h>
39#include <llvm/ADT/STLExtras.h> // for make_unique
40
41using namespace llvm;
42
43namespace pablo {
44
45using TypeId = PabloAST::ClassTypeId;
46
47inline static unsigned getAlignment(const Type * const type) {
48    return type->getPrimitiveSizeInBits() / 8;
49}
50
51inline static unsigned getAlignment(const Value * const expr) {
52    return getAlignment(expr->getType());
53}
54
55inline static unsigned getPointerElementAlignment(const Value * const ptr) {
56    return getAlignment(ptr->getType()->getPointerElementType());
57}
58
59void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
60    mBranchCount = 0;
61    examineBlock(b, mKernel->getEntryScope());
62    mCarryManager->initializeCarryData(b, mKernel);
63    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
64        const auto count = (mBranchCount * 2) + 1;
65        mKernel->addInternalScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
66        mBasicBlock.reserve(count);
67    }
68}
69
70void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
71    mCarryManager->releaseCarryData(b);
72}
73
74void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
75    mCarryManager->clearCarryData(b);
76}
77
78void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
79    mCarryManager->initializeCodeGen(b);
80    PabloBlock * const entryBlock = mKernel->getEntryScope(); assert (entryBlock);
81    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
82    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
83    mBranchCount = 0;
84    addBranchCounter(b);
85    compileBlock(b, entryBlock);
86    mCarryManager->finalizeCodeGen(b);
87}
88
89void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
90    for (const Statement * stmt : *block) {
91        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
92            const Lookahead * const la = cast<Lookahead>(stmt);
93            PabloAST * input = la->getExpression();
94            if (isa<Extract>(input)) {
95                input = cast<Extract>(input)->getArray();
96            }
97            bool notFound = true;
98            if (LLVM_LIKELY(isa<Var>(input))) {
99                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
100                    if (input == mKernel->getInput(i)) {
101                        const auto & binding = mKernel->getInputStreamSetBinding(i);
102                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
103                            std::string tmp;
104                            raw_string_ostream out(tmp);
105                            input->print(out);
106                            out << " must have a lookahead attribute of at least " << la->getAmount();
107                            report_fatal_error(out.str());
108                        }
109                        notFound = false;
110                        break;
111                    }
112                }
113            }
114            if (LLVM_UNLIKELY(notFound)) {
115                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
116            }
117        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
118            ++mBranchCount;
119            examineBlock(b, cast<Branch>(stmt)->getBody());
120        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
121            mKernel->addInternalScalar(stmt->getType(), stmt->getName().str());
122        }
123    }
124}
125
126void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
127    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
128        Value * ptr = b->getScalarFieldPtr("profile");
129        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
130        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
131        const auto alignment = getPointerElementAlignment(ptr);
132        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
133        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
134        b->CreateAlignedStore(value, ptr, alignment);
135        mBasicBlock.push_back(b->GetInsertBlock());
136    }
137}
138
139inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
140    for (const Statement * statement : *block) {
141        compileStatement(b, statement);
142    }
143}
144
145void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
146    //
147    //  The If-ElseZero stmt:
148    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
149    //  If the value of the predicate is nonzero, then determine the values of variables
150    //  <var>* by executing the given statements.  Otherwise, the value of the
151    //  variables are all zero.  Requirements: (a) no variable that is defined within
152    //  the body of the if may be accessed outside unless it is explicitly
153    //  listed in the variable list, (b) every variable in the defined list receives
154    //  a value within the body, and (c) the logical consequence of executing
155    //  the statements in the event that the predicate is zero is that the
156    //  values of all defined variables indeed work out to be 0.
157    //
158    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
159    //  is inserted for each variable in the defined variable list.  It receives
160    //  a zero value from the ifentry block and the defined value from the if
161    //  body.
162    //
163
164    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
165    ++mBranchCount;
166    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
167    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
168
169    std::vector<std::pair<const Var *, Value *>> incoming;
170
171    for (const Var * var : ifStatement->getEscaped()) {
172        if (LLVM_UNLIKELY(var->isKernelParameter())) {
173            Value * marker = nullptr;
174            if (var->isScalar()) {
175                marker = b->getScalarFieldPtr(var->getName());
176            } else if (var->isReadOnly()) {
177                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
178            } else if (var->isReadNone()) {
179                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
180            }
181            mMarker[var] = marker;
182        } else {
183            auto f = mMarker.find(var);
184            if (LLVM_UNLIKELY(f == mMarker.end())) {
185                std::string tmp;
186                raw_string_ostream out(tmp);
187                var->print(out);
188                out << " is uninitialized prior to entering ";
189                ifStatement->print(out);
190                report_fatal_error(out.str());
191            }
192            incoming.emplace_back(var, f->second);
193        }
194    }
195
196    const PabloBlock * ifBody = ifStatement->getBody();
197
198    mCarryManager->enterIfScope(b, ifBody);
199
200    Value * condition = compileExpression(b, ifStatement->getCondition());
201    if (condition->getType() == b->getBitBlockType()) {
202        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
203    }
204
205    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
206
207    // Entry processing is complete, now handle the body of the if.
208    b->SetInsertPoint(ifBodyBlock);
209
210    mCarryManager->enterIfBody(b, ifEntryBlock);
211
212    addBranchCounter(b);
213
214    compileBlock(b, ifBody);
215
216    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
217
218    BasicBlock * ifExitBlock = b->GetInsertBlock();
219
220    b->CreateBr(ifEndBlock);
221
222    ifEndBlock->moveAfter(ifExitBlock);
223
224    //End Block
225    b->SetInsertPoint(ifEndBlock);
226
227    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
228
229    for (const auto i : incoming) {
230        const Var * var; Value * incoming;
231        std::tie(var, incoming) = i;
232
233        auto f = mMarker.find(var);
234        if (LLVM_UNLIKELY(f == mMarker.end())) {
235            std::string tmp;
236            raw_string_ostream out(tmp);
237            out << "PHINode creation error: ";
238            var->print(out);
239            out << " was not assigned an outgoing value.";
240            report_fatal_error(out.str());
241        }
242
243        Value * const outgoing = f->second;
244        if (LLVM_UNLIKELY(incoming == outgoing)) {
245            continue;
246        }
247
248        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
249            std::string tmp;
250            raw_string_ostream out(tmp);
251            out << "PHINode creation error: incoming type of ";
252            var->print(out);
253            out << " (";
254            incoming->getType()->print(out);
255            out << ") differs from the outgoing type (";
256            outgoing->getType()->print(out);
257            out << ") within ";
258            ifStatement->print(out);
259            report_fatal_error(out.str());
260        }
261
262        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
263        phi->addIncoming(incoming, ifEntryBlock);
264        phi->addIncoming(outgoing, ifExitBlock);
265        f->second = phi;
266    }
267
268    addBranchCounter(b);
269}
270
271void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
272
273    const PabloBlock * const whileBody = whileStatement->getBody();
274
275    BasicBlock * whileEntryBlock = b->GetInsertBlock();
276
277    const auto escaped = whileStatement->getEscaped();
278
279#ifdef ENABLE_BOUNDED_WHILE
280    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
281#endif
282    // On entry to the while structure, proceed to execute the first iteration
283    // of the loop body unconditionally. The while condition is tested at the end of
284    // the loop.
285
286    for (const Var * var : escaped) {
287        if (LLVM_UNLIKELY(var->isKernelParameter())) {
288            Value * marker = nullptr;
289            if (var->isScalar()) {
290                marker = b->getScalarFieldPtr(var->getName());
291            } else if (var->isReadOnly()) {
292                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
293            } else if (var->isReadNone()) {
294                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
295            }
296            mMarker[var] = marker;
297        }
298    }
299
300    mCarryManager->enterLoopScope(b, whileBody);
301
302    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
303    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
304    ++mBranchCount;
305
306    b->CreateBr(whileBodyBlock);
307
308    b->SetInsertPoint(whileBodyBlock);
309
310    //
311    // There are 3 sets of Phi nodes for the while loop.
312    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
313    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
314    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
315#ifdef ENABLE_BOUNDED_WHILE
316    // (4) The loop bound, if any.
317#endif
318
319    std::vector<std::pair<const Var *, PHINode *>> variants;
320
321    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
322    for (const auto var : escaped) {
323        auto f = mMarker.find(var);
324        if (LLVM_UNLIKELY(f == mMarker.end())) {
325            std::string tmp;
326            raw_string_ostream out(tmp);
327            out << "PHINode creation error: ";
328            var->print(out);
329            out << " is uninitialized prior to entering ";
330            whileStatement->print(out);
331            report_fatal_error(out.str());
332        }
333        Value * entryValue = f->second;
334        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
335        phi->addIncoming(entryValue, whileEntryBlock);
336        f->second = phi;
337        assert(mMarker[var] == phi);
338        variants.emplace_back(var, phi);
339    }
340#ifdef ENABLE_BOUNDED_WHILE
341    if (whileStatement->getBound()) {
342        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
343        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
344    }
345#endif
346
347    mCarryManager->enterLoopBody(b, whileEntryBlock);
348
349    addBranchCounter(b);
350
351    compileBlock(b, whileBody);
352
353    // After the whileBody has been compiled, we may be in a different basic block.
354
355    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
356
357
358#ifdef ENABLE_BOUNDED_WHILE
359    if (whileStatement->getBound()) {
360        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
361        bound_phi->addIncoming(new_bound, whileExitBlock);
362        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
363    }
364#endif
365
366    BasicBlock * const whileExitBlock = b->GetInsertBlock();
367
368    // and for any variant nodes in the loop body
369    for (const auto variant : variants) {
370        const Var * var; PHINode * incomingPhi;
371        std::tie(var, incomingPhi) = variant;
372        const auto f = mMarker.find(var);
373        if (LLVM_UNLIKELY(f == mMarker.end())) {
374            std::string tmp;
375            raw_string_ostream out(tmp);
376            out << "PHINode creation error: ";
377            var->print(out);
378            out << " is no longer assigned a value.";
379            report_fatal_error(out.str());
380        }
381
382        Value * const outgoingValue = f->second;
383
384        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
385            std::string tmp;
386            raw_string_ostream out(tmp);
387            out << "PHINode creation error: incoming type of ";
388            var->print(out);
389            out << " (";
390            incomingPhi->getType()->print(out);
391            out << ") differs from the outgoing type (";
392            outgoingValue->getType()->print(out);
393            out << ") within ";
394            whileStatement->print(out);
395            report_fatal_error(out.str());
396        }
397
398        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
399    }
400
401    // Terminate the while loop body with a conditional branch back.
402    Value * condition = compileExpression(b, whileStatement->getCondition());
403    if (condition->getType() == b->getBitBlockType()) {
404        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
405    }
406
407    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
408
409    whileEndBlock->moveAfter(whileExitBlock);
410
411    b->SetInsertPoint(whileEndBlock);
412
413    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
414
415    addBranchCounter(b);
416}
417
418void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
419
420    if (LLVM_UNLIKELY(isa<If>(stmt))) {
421        compileIf(b, cast<If>(stmt));
422    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
423        compileWhile(b, cast<While>(stmt));
424    } else {
425        const PabloAST * expr = stmt;
426        Value * value = nullptr;
427        if (isa<And>(stmt)) {
428            Value * const op0 = compileExpression(b, stmt->getOperand(0));
429            Value * const op1 = compileExpression(b, stmt->getOperand(1));
430            value = b->simd_and(op0, op1, stmt->getName());
431        } else if (isa<Or>(stmt)) {
432            Value * const op0 = compileExpression(b, stmt->getOperand(0));
433            Value * const op1 = compileExpression(b, stmt->getOperand(1));
434            value = b->simd_or(op0, op1, stmt->getName());
435        } else if (isa<Xor>(stmt)) {
436            Value * const op0 = compileExpression(b, stmt->getOperand(0));
437            Value * const op1 = compileExpression(b, stmt->getOperand(1));
438            value = b->simd_xor(op0, op1, stmt->getName());
439        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
440            Value* ifMask = compileExpression(b, sel->getCondition());
441            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
442            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
443            value = b->simd_or(ifTrue, ifFalse);
444        } else if (isa<Not>(stmt)) {
445            value = b->simd_not(compileExpression(b, stmt->getOperand(0)), stmt->getName());
446        } else if (isa<Advance>(stmt)) {
447            const Advance * const adv = cast<Advance>(stmt);
448            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
449            // manager so that it properly selects the correct carry bit.
450            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
451        } else if (isa<IndexedAdvance>(stmt)) {
452            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
453            Value * strm = compileExpression(b, adv->getExpression());
454            Value * index_strm = compileExpression(b, adv->getIndex());
455            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
456            // manager so that it properly selects the correct carry bit.
457            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
458        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
459            Value * const marker = compileExpression(b, mstar->getMarker());
460            Value * const cc = compileExpression(b, mstar->getCharClass());
461            Value * const marker_and_cc = b->simd_and(marker, cc);
462            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
463            value = b->simd_or(b->simd_xor(sum, cc), marker, mstar->getName());
464        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
465            Value * const from = compileExpression(b, sthru->getScanFrom());
466            Value * const thru = compileExpression(b, sthru->getScanThru());
467            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
468            value = b->simd_and(sum, b->simd_not(thru), sthru->getName());
469        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
470            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
471            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
472            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
473            value = b->simd_and(sum, to, sthru->getName());
474        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
475            Value * const from = compileExpression(b, sthru->getScanFrom());
476            Value * const thru = compileExpression(b, sthru->getScanThru());
477            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
478            value = b->simd_and(sum, b->simd_not(thru), sthru->getName());
479        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
480            Value * const from = compileExpression(b, sthru->getScanFrom());
481            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
482            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
483            value = b->simd_and(sum, to, sthru->getName());
484        } else if (const TerminateAt * s = dyn_cast<TerminateAt>(stmt)) {
485            Value * signal_strm = compileExpression(b, s->getExpr());
486            BasicBlock * signalCallBack = b->CreateBasicBlock("signalCallBack");
487            BasicBlock * postSignal = b->CreateBasicBlock("postSignal");
488            b->CreateCondBr(b->bitblock_any(signal_strm), signalCallBack, postSignal);
489            b->SetInsertPoint(signalCallBack);
490            // Perhaps check for handler address and skip call back if none???
491            Value * handler = b->getScalarField("handler_address");
492            Function * const dispatcher = b->getModule()->getFunction("signal_dispatcher"); assert (dispatcher);
493            b->CreateCall(dispatcher, {handler, ConstantInt::get(b->getInt32Ty(), s->getSignalCode())});
494            //Value * rel_position = b->createCountForwardZeroes(signal_strm);
495            //Value * position = b->CreateAdd(b->getProcessedItemCount(), rel_position);
496            b->setTerminationSignal();
497            b->CreateBr(postSignal);
498            b->SetInsertPoint(postSignal);
499            value = signal_strm;
500        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
501            expr = cast<Assign>(stmt)->getVariable();
502            value = compileExpression(b, cast<Assign>(stmt)->getValue());
503            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
504                Value * const ptr = compileExpression(b, expr, false);
505                Type * const elemTy = ptr->getType()->getPointerElementType();
506                b->CreateAlignedStore(b->CreateZExt(value, elemTy), ptr, getAlignment(elemTy));
507                value = ptr;
508            }
509        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
510            Value * EOFmask = b->getScalarField("EOFmask");
511            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
512        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
513            Value * EOFbit = b->getScalarField("EOFbit");
514            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
515        } else if (const Count * c = dyn_cast<Count>(stmt)) {
516            Value * EOFbit = b->getScalarField("EOFbit");
517            Value * EOFmask = b->getScalarField("EOFmask");
518            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));
519            Value * const ptr = b->getScalarFieldPtr(stmt->getName().str());
520            const auto alignment = getPointerElementAlignment(ptr);
521            Value * const countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
522            const auto fieldWidth = b->getSizeTy()->getBitWidth();
523            Value * bitBlockCount = b->simd_popcount(b->getBitBlockWidth(), to_count);
524            value = b->CreateAdd(b->mvmd_extract(fieldWidth, bitBlockCount, 0), countSoFar, "countSoFar");
525            b->CreateAlignedStore(value, ptr, alignment);
526        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
527            PabloAST * stream = l->getExpression();
528            Value * index = nullptr;
529            if (LLVM_UNLIKELY(isa<Extract>(stream))) {
530                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
531                stream = cast<Extract>(stream)->getArray();
532            } else {
533                index = b->getInt32(0);
534            }
535            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
536            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
537            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
538            // Value * base = b->CreatePointerCast(b->getBaseAddress(cast<Var>(stream)->getName()), ptr->getType());
539            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
540            if (LLVM_UNLIKELY(bit_shift == 0)) {  // Simple case with no intra-block shifting.
541                value = lookAhead;
542            } else { // Need to form shift result from two adjacent blocks.
543                Value * ptr1 = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
544                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr1);
545                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
546                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
547                } else {
548                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
549                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
550                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
551                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
552                    value = b->CreateBitCast(result, b->getBitBlockType());
553                }
554            }
555        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
556            value = compileExpression(b, s->getValue());
557            Type * const ty = s->getType();
558            if (LLVM_LIKELY(ty->isVectorTy())) {
559                const auto fw = s->getFieldWidth()->value();
560                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
561                value = b->simd_fill(fw, value);
562            } else {
563                value = b->CreateZExtOrTrunc(value, ty);
564            }
565        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
566            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
567            const auto packWidth = p->getFieldWidth()->value();
568            assert (sourceWidth == packWidth);
569            Value * const base = compileExpression(b, p->getValue(), false);
570            const auto result_packs = sourceWidth/2;
571            if (LLVM_LIKELY(result_packs > 1)) {
572                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
573            }
574            Constant * const ZERO = b->getInt32(0);
575            for (unsigned i = 0; i < result_packs; ++i) {
576                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
577                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
578                Value * P = b->bitCast(b->hsimd_packh(packWidth, A, B));
579                if (LLVM_UNLIKELY(result_packs == 1)) {
580                    value = P;
581                    break;
582                }
583                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
584            }
585        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
586            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
587            const auto packWidth = p->getFieldWidth()->value();
588            assert (sourceWidth == packWidth);
589            Value * const base = compileExpression(b, p->getValue(), false);
590            const auto result_packs = sourceWidth/2;
591            if (LLVM_LIKELY(result_packs > 1)) {
592                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
593            }
594            Constant * const ZERO = b->getInt32(0);
595            for (unsigned i = 0; i < result_packs; ++i) {
596                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
597                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
598                Value * P = b->bitCast(b->hsimd_packl(packWidth, A, B));
599                if (LLVM_UNLIKELY(result_packs == 1)) {
600                    value = P;
601                    break;
602                }
603                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
604            }
605        } else {
606            std::string tmp;
607            raw_string_ostream out(tmp);
608            out << "PabloCompiler: ";
609            stmt->print(out);
610            out << " was not recognized by the compiler";
611            report_fatal_error(out.str());
612        }
613        assert (expr);
614        assert (value);
615        mMarker[expr] = value;
616        if (DebugOptionIsSet(DumpTrace)) {
617            std::string tmp;
618            raw_string_ostream name(tmp);
619            expr->print(name);
620            if (value->getType()->isVectorTy()) {
621                b->CallPrintRegister(name.str(), value);
622            } else if (value->getType()->isIntegerTy()) {
623                b->CallPrintInt(name.str(), value);
624            }
625        }
626    }
627}
628
629unsigned getIntegerBitWidth(const Type * ty) {
630    if (ty->isArrayTy()) {
631        assert (ty->getArrayNumElements() == 1);
632        ty = ty->getArrayElementType();
633    }
634    if (ty->isVectorTy()) {
635        assert (ty->getVectorNumElements() == 0);
636        ty = ty->getVectorElementType();
637    }
638    return ty->getIntegerBitWidth();
639}
640
641Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
642    const auto f = mMarker.find(expr);
643    Value * value = nullptr;
644    if (LLVM_LIKELY(f != mMarker.end())) {
645        value = f->second;
646    } else {
647        if (isa<Integer>(expr)) {
648            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
649        } else if (isa<Zeroes>(expr)) {
650            value = b->allZeroes();
651        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
652            value = b->allOnes();
653        } else if (isa<Extract>(expr)) {
654            const Extract * const extract = cast<Extract>(expr);
655            const Var * const var = cast<Var>(extract->getArray());
656            Value * const index = compileExpression(b, extract->getIndex());
657            value = getPointerToVar(b, var, index);
658        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
659            const Var * const var = cast<Var>(expr);
660            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
661                value = b->getScalarFieldPtr(var->getName());
662            } else { // use before def error
663                std::string tmp;
664                raw_string_ostream out(tmp);
665                out << "PabloCompiler: ";
666                expr->print(out);
667                out << " is not a scalar value or was used before definition";
668                report_fatal_error(out.str());
669            }
670        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
671            const Operator * const op = cast<Operator>(expr);
672            const PabloAST * lh = op->getLH();
673            const PabloAST * rh = op->getRH();
674            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
675                if (getIntegerBitWidth(lh->getType()) != getIntegerBitWidth(rh->getType())) {
676                    llvm::report_fatal_error("Integer types must be identical!");
677                }
678                const unsigned intWidth = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
679                const unsigned maskWidth = b->getBitBlockWidth() / intWidth;
680                IntegerType * const maskTy = b->getIntNTy(maskWidth);
681                VectorType * const vTy = VectorType::get(b->getIntNTy(intWidth), maskWidth);
682
683                Value * baseLhv = nullptr;
684                Value * lhvStreamIndex = nullptr;
685                if (isa<Var>(lh)) {
686                    lhvStreamIndex = b->getInt32(0);
687                } else if (isa<Extract>(lh)) {
688                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
689                    lh = cast<Extract>(lh)->getArray();
690                } else {
691                    baseLhv = compileExpression(b, lh);
692                }
693
694                Value * baseRhv = nullptr;
695                Value * rhvStreamIndex = nullptr;
696                if (isa<Var>(rh)) {
697                    rhvStreamIndex = b->getInt32(0);
698                } else if (isa<Extract>(rh)) {
699                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
700                    rh = cast<Extract>(rh)->getArray();
701                } else {
702                    baseRhv = compileExpression(b, rh);
703                }
704
705                const TypeId typeId = op->getClassTypeId();
706
707                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
708
709                    value = b->CreateAlloca(vTy, b->getInt32(intWidth));
710
711                    for (unsigned i = 0; i < intWidth; ++i) {
712                        llvm::Constant * const index = b->getInt32(i);
713                        Value * lhv = nullptr;
714                        if (baseLhv) {
715                            lhv = baseLhv;
716                        } else {
717                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
718                            lhv = b->CreateBlockAlignedLoad(lhv);
719                        }
720                        lhv = b->CreateBitCast(lhv, vTy);
721
722                        Value * rhv = nullptr;
723                        if (baseRhv) {
724                            rhv = baseRhv;
725                        } else {
726                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
727                            rhv = b->CreateBlockAlignedLoad(rhv);
728                        }
729                        rhv = b->CreateBitCast(rhv, vTy);
730
731                        Value * result = nullptr;
732                        if (typeId == TypeId::Add) {
733                            result = b->CreateAdd(lhv, rhv);
734                        } else { // if (typeId == TypeId::Subtract) {
735                            result = b->CreateSub(lhv, rhv);
736                        }
737                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
738                    }
739
740                } else {
741
742                    value = UndefValue::get(VectorType::get(maskTy, intWidth));
743
744                    for (unsigned i = 0; i < intWidth; ++i) {
745                        llvm::Constant * const index = b->getInt32(i);
746                        Value * lhv = nullptr;
747                        if (baseLhv) {
748                            lhv = baseLhv;
749                        } else {
750                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
751                            lhv = b->CreateBlockAlignedLoad(lhv);
752                        }
753                        lhv = b->CreateBitCast(lhv, vTy);
754
755                        Value * rhv = nullptr;
756                        if (baseRhv) {
757                            rhv = baseRhv;
758                        } else {
759                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
760                            rhv = b->CreateBlockAlignedLoad(rhv);
761                        }
762                        rhv = b->CreateBitCast(rhv, vTy);
763                        Value * comp = nullptr;
764                        switch (typeId) {
765                            case TypeId::GreaterThanEquals:
766                            case TypeId::LessThan:
767                                comp = b->simd_ult(intWidth, lhv, rhv);
768                                break;
769                            case TypeId::Equals:
770                            case TypeId::NotEquals:
771                                comp = b->simd_eq(intWidth, lhv, rhv);
772                                break;
773                            case TypeId::LessThanEquals:
774                            case TypeId::GreaterThan:
775                                comp = b->simd_ugt(intWidth, lhv, rhv);
776                                break;
777                            default: llvm_unreachable("invalid vector operator id");
778                        }
779                        Value * const mask = b->CreateZExtOrTrunc(b->hsimd_signmask(intWidth, comp), maskTy);
780                        value = b->mvmd_insert(maskWidth, value, mask, i);
781                    }
782                    value = b->CreateBitCast(value, b->getBitBlockType());
783                    switch (typeId) {
784                        case TypeId::GreaterThanEquals:
785                        case TypeId::LessThanEquals:
786                        case TypeId::NotEquals:
787                            value = b->CreateNot(value);
788                        default: break;
789                    }
790                }
791
792            } else {
793                Value * const lhv = compileExpression(b, lh);
794                Value * const rhv = compileExpression(b, rh);
795                switch (op->getClassTypeId()) {
796                    case TypeId::Add:
797                        value = b->CreateAdd(lhv, rhv); break;
798                    case TypeId::Subtract:
799                        value = b->CreateSub(lhv, rhv); break;
800                    case TypeId::LessThan:
801                        value = b->CreateICmpULT(lhv, rhv); break;
802                    case TypeId::LessThanEquals:
803                        value = b->CreateICmpULE(lhv, rhv); break;
804                    case TypeId::Equals:
805                        value = b->CreateICmpEQ(lhv, rhv); break;
806                    case TypeId::GreaterThanEquals:
807                        value = b->CreateICmpUGE(lhv, rhv); break;
808                    case TypeId::GreaterThan:
809                        value = b->CreateICmpUGT(lhv, rhv); break;
810                    case TypeId::NotEquals:
811                        value = b->CreateICmpNE(lhv, rhv); break;
812                    default: llvm_unreachable("invalid scalar operator id");
813                }
814            }
815        } else { // use before def error
816            std::string tmp;
817            raw_string_ostream out(tmp);
818            out << "PabloCompiler: ";
819            expr->print(out);
820            out << " was used before definition";
821            report_fatal_error(out.str());
822        }
823        assert (value);
824        // mMarker.insert({expr, value});
825    }
826    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
827        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
828    }
829    return value;
830}
831
832Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
833    assert (var && index1);
834    if (LLVM_LIKELY(var->isKernelParameter())) {
835        if (LLVM_UNLIKELY(var->isScalar())) {
836            std::string tmp;
837            raw_string_ostream out(tmp);
838            out << mKernel->getName();
839            out << ": cannot index scalar value ";
840            var->print(out);
841            report_fatal_error(out.str());
842        } else if (var->isReadOnly()) {
843            if (index2) {
844                return b->getInputStreamPackPtr(var->getName(), index1, index2);
845            } else {
846                return b->getInputStreamBlockPtr(var->getName(), index1);
847            }
848        } else if (var->isReadNone()) {
849            if (index2) {
850                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
851            } else {
852                return b->getOutputStreamBlockPtr(var->getName(), index1);
853            }
854        } else {
855            std::string tmp;
856            raw_string_ostream out(tmp);
857            out << mKernel->getName();
858            out << ": stream ";
859            var->print(out);
860            out << " cannot be read from or written to";
861            report_fatal_error(out.str());
862        }
863    } else {
864        Value * const ptr = compileExpression(b, var, false);
865        std::vector<Value *> offsets;
866        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
867        offsets.push_back(index1);
868        if (index2) offsets.push_back(index2);
869        return b->CreateGEP(ptr, offsets);
870    }
871}
872
873PabloCompiler::PabloCompiler(PabloKernel * const kernel)
874: mKernel(kernel)
875, mCarryManager(make_unique<CarryManager>())
876, mBranchCount(0) {
877    assert ("PabloKernel cannot be null!" && kernel);
878}
879
880PabloCompiler::~PabloCompiler() {
881}
882
883}
Note: See TracBrowser for help on using the repository browser.