source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 6184

Last change on this file since 6184 was 6184, checked in by nmedfort, 8 months ago

Initial version of PipelineKernel? + revised StreamSet? model.

File size: 38.0 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#ifdef USE_CARRYPACK_MANAGER
29#include <pablo/carrypack_manager.h>
30#else
31#include <pablo/carry_manager.h>
32#endif
33#include <kernels/kernel_builder.h>
34#include <kernels/streamset.h>
35#include <llvm/IR/Module.h>
36#include <llvm/IR/Type.h>
37#include <llvm/Support/raw_os_ostream.h>
38#include <llvm/ADT/STLExtras.h> // for make_unique
39
40using namespace llvm;
41
42namespace pablo {
43
44using TypeId = PabloAST::ClassTypeId;
45
46inline static unsigned getAlignment(const Type * const type) {
47    return type->getPrimitiveSizeInBits() / 8;
48}
49
50inline static unsigned getAlignment(const Value * const expr) {
51    return getAlignment(expr->getType());
52}
53
54inline static unsigned getPointerElementAlignment(const Value * const ptr) {
55    return getAlignment(ptr->getType()->getPointerElementType());
56}
57
58void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
59    mBranchCount = 0;
60    examineBlock(b, mKernel->getEntryScope());
61    mCarryManager->initializeCarryData(b, mKernel);
62    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
63        const auto count = (mBranchCount * 2) + 1;
64        mKernel->addInternalScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
65        mBasicBlock.reserve(count);
66    }
67}
68
69void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
70    mCarryManager->releaseCarryData(b);
71}
72
73void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
74    mCarryManager->clearCarryData(b);
75}
76
77void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
78    mCarryManager->initializeCodeGen(b);
79    PabloBlock * const entryBlock = mKernel->getEntryScope(); assert (entryBlock);
80    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
81    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
82    mBranchCount = 0;
83    addBranchCounter(b);
84    compileBlock(b, entryBlock);
85    mCarryManager->finalizeCodeGen(b);
86}
87
88void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
89    for (const Statement * stmt : *block) {
90        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
91            const Lookahead * const la = cast<Lookahead>(stmt);
92            PabloAST * input = la->getExpression();
93            if (isa<Extract>(input)) {
94                input = cast<Extract>(input)->getArray();
95            }
96            bool notFound = true;
97            if (LLVM_LIKELY(isa<Var>(input))) {
98                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
99                    if (input == mKernel->getInput(i)) {
100                        const auto & binding = mKernel->getInputStreamSetBinding(i);
101                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
102                            std::string tmp;
103                            raw_string_ostream out(tmp);
104                            input->print(out);
105                            out << " must have a lookahead attribute of at least " << la->getAmount();
106                            report_fatal_error(out.str());
107                        }
108                        notFound = false;
109                        break;
110                    }
111                }
112            }
113            if (LLVM_UNLIKELY(notFound)) {
114                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
115            }
116        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
117            ++mBranchCount;
118            examineBlock(b, cast<Branch>(stmt)->getBody());
119        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
120            mKernel->addInternalScalar(stmt->getType(), stmt->getName().str());
121        }
122    }   
123}
124
125void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
126    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {       
127        Value * ptr = b->getScalarFieldPtr("profile");
128        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
129        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
130        const auto alignment = getPointerElementAlignment(ptr);
131        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
132        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
133        b->CreateAlignedStore(value, ptr, alignment);
134        mBasicBlock.push_back(b->GetInsertBlock());
135    }
136}
137
138inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
139    for (const Statement * statement : *block) {
140        compileStatement(b, statement);
141    }
142}
143
144void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
145    //
146    //  The If-ElseZero stmt:
147    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
148    //  If the value of the predicate is nonzero, then determine the values of variables
149    //  <var>* by executing the given statements.  Otherwise, the value of the
150    //  variables are all zero.  Requirements: (a) no variable that is defined within
151    //  the body of the if may be accessed outside unless it is explicitly
152    //  listed in the variable list, (b) every variable in the defined list receives
153    //  a value within the body, and (c) the logical consequence of executing
154    //  the statements in the event that the predicate is zero is that the
155    //  values of all defined variables indeed work out to be 0.
156    //
157    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
158    //  is inserted for each variable in the defined variable list.  It receives
159    //  a zero value from the ifentry block and the defined value from the if
160    //  body.
161    //
162
163    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
164    ++mBranchCount;
165    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
166    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
167   
168    std::vector<std::pair<const Var *, Value *>> incoming;
169
170    for (const Var * var : ifStatement->getEscaped()) {
171        if (LLVM_UNLIKELY(var->isKernelParameter())) {
172            Value * marker = nullptr;
173            if (var->isScalar()) {
174                marker = b->getScalarFieldPtr(var->getName());
175            } else if (var->isReadOnly()) {
176                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
177            } else if (var->isReadNone()) {
178                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
179            }
180            mMarker[var] = marker;
181        } else {
182            auto f = mMarker.find(var);
183            if (LLVM_UNLIKELY(f == mMarker.end())) {
184                std::string tmp;
185                raw_string_ostream out(tmp);
186                var->print(out);
187                out << " is uninitialized prior to entering ";
188                ifStatement->print(out);
189                report_fatal_error(out.str());
190            }
191            incoming.emplace_back(var, f->second);
192        }
193    }
194
195    const PabloBlock * ifBody = ifStatement->getBody();
196   
197    mCarryManager->enterIfScope(b, ifBody);
198
199    Value * condition = compileExpression(b, ifStatement->getCondition());
200    if (condition->getType() == b->getBitBlockType()) {
201        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
202    }
203   
204    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
205   
206    // Entry processing is complete, now handle the body of the if.
207    b->SetInsertPoint(ifBodyBlock);
208
209    mCarryManager->enterIfBody(b, ifEntryBlock);
210
211    addBranchCounter(b);
212
213    compileBlock(b, ifBody);
214
215    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
216
217    BasicBlock * ifExitBlock = b->GetInsertBlock();
218
219    b->CreateBr(ifEndBlock);
220
221    ifEndBlock->moveAfter(ifExitBlock);
222
223    //End Block
224    b->SetInsertPoint(ifEndBlock);
225
226    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
227
228    for (const auto i : incoming) {
229        const Var * var; Value * incoming;
230        std::tie(var, incoming) = i;
231
232        auto f = mMarker.find(var);
233        if (LLVM_UNLIKELY(f == mMarker.end())) {
234            std::string tmp;
235            raw_string_ostream out(tmp);
236            out << "PHINode creation error: ";
237            var->print(out);
238            out << " was not assigned an outgoing value.";
239            report_fatal_error(out.str());
240        }
241
242        Value * const outgoing = f->second;
243        if (LLVM_UNLIKELY(incoming == outgoing)) {
244            continue;
245        }
246
247        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
248            std::string tmp;
249            raw_string_ostream out(tmp);
250            out << "PHINode creation error: incoming type of ";
251            var->print(out);
252            out << " (";
253            incoming->getType()->print(out);
254            out << ") differs from the outgoing type (";
255            outgoing->getType()->print(out);
256            out << ") within ";
257            ifStatement->print(out);
258            report_fatal_error(out.str());
259        }
260
261        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
262        phi->addIncoming(incoming, ifEntryBlock);
263        phi->addIncoming(outgoing, ifExitBlock);
264        f->second = phi;
265    }
266
267    addBranchCounter(b);
268}
269
270void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
271
272    const PabloBlock * const whileBody = whileStatement->getBody();
273
274    BasicBlock * whileEntryBlock = b->GetInsertBlock();
275
276    const auto escaped = whileStatement->getEscaped();
277
278#ifdef ENABLE_BOUNDED_WHILE
279    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
280#endif
281    // On entry to the while structure, proceed to execute the first iteration
282    // of the loop body unconditionally. The while condition is tested at the end of
283    // the loop.
284
285    for (const Var * var : escaped) {
286        if (LLVM_UNLIKELY(var->isKernelParameter())) {
287            Value * marker = nullptr;
288            if (var->isScalar()) {
289                marker = b->getScalarFieldPtr(var->getName());
290            } else if (var->isReadOnly()) {
291                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
292            } else if (var->isReadNone()) {
293                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
294            }
295            mMarker[var] = marker;
296        }
297    }
298
299    mCarryManager->enterLoopScope(b, whileBody);
300
301    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
302    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
303    ++mBranchCount;
304
305    b->CreateBr(whileBodyBlock);
306
307    b->SetInsertPoint(whileBodyBlock);
308
309    //
310    // There are 3 sets of Phi nodes for the while loop.
311    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
312    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
313    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
314#ifdef ENABLE_BOUNDED_WHILE
315    // (4) The loop bound, if any.
316#endif
317
318    std::vector<std::pair<const Var *, PHINode *>> variants;
319
320    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
321    for (const auto var : escaped) {
322        auto f = mMarker.find(var);
323        if (LLVM_UNLIKELY(f == mMarker.end())) {
324            std::string tmp;
325            raw_string_ostream out(tmp);
326            out << "PHINode creation error: ";
327            var->print(out);
328            out << " is uninitialized prior to entering ";
329            whileStatement->print(out);
330            report_fatal_error(out.str());
331        }
332        Value * entryValue = f->second;
333        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
334        phi->addIncoming(entryValue, whileEntryBlock);
335        f->second = phi;
336        assert(mMarker[var] == phi);
337        variants.emplace_back(var, phi);
338    }
339#ifdef ENABLE_BOUNDED_WHILE
340    if (whileStatement->getBound()) {
341        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
342        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
343    }
344#endif
345
346    mCarryManager->enterLoopBody(b, whileEntryBlock);
347
348    addBranchCounter(b);
349
350    compileBlock(b, whileBody);
351
352    // After the whileBody has been compiled, we may be in a different basic block.
353
354    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
355
356
357#ifdef ENABLE_BOUNDED_WHILE
358    if (whileStatement->getBound()) {
359        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
360        bound_phi->addIncoming(new_bound, whileExitBlock);
361        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
362    }
363#endif
364
365    BasicBlock * const whileExitBlock = b->GetInsertBlock();
366
367    // and for any variant nodes in the loop body
368    for (const auto variant : variants) {
369        const Var * var; PHINode * incomingPhi;
370        std::tie(var, incomingPhi) = variant;
371        const auto f = mMarker.find(var);
372        if (LLVM_UNLIKELY(f == mMarker.end())) {
373            std::string tmp;
374            raw_string_ostream out(tmp);
375            out << "PHINode creation error: ";
376            var->print(out);
377            out << " is no longer assigned a value.";
378            report_fatal_error(out.str());
379        }
380
381        Value * const outgoingValue = f->second;
382
383        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
384            std::string tmp;
385            raw_string_ostream out(tmp);
386            out << "PHINode creation error: incoming type of ";
387            var->print(out);
388            out << " (";
389            incomingPhi->getType()->print(out);
390            out << ") differs from the outgoing type (";
391            outgoingValue->getType()->print(out);
392            out << ") within ";
393            whileStatement->print(out);
394            report_fatal_error(out.str());
395        }
396
397        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
398    }
399
400    // Terminate the while loop body with a conditional branch back.
401    Value * condition = compileExpression(b, whileStatement->getCondition());
402    if (condition->getType() == b->getBitBlockType()) {
403        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
404    }
405
406    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
407
408    whileEndBlock->moveAfter(whileExitBlock);
409
410    b->SetInsertPoint(whileEndBlock);
411
412    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
413
414    addBranchCounter(b);
415}
416
417void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
418
419    if (LLVM_UNLIKELY(isa<If>(stmt))) {
420        compileIf(b, cast<If>(stmt));
421    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
422        compileWhile(b, cast<While>(stmt));
423    } else {
424        const PabloAST * expr = stmt;
425        Value * value = nullptr;
426        if (isa<And>(stmt)) {
427            Value * const op0 = compileExpression(b, stmt->getOperand(0));
428            Value * const op1 = compileExpression(b, stmt->getOperand(1));
429            value = b->simd_and(op0, op1);
430        } else if (isa<Or>(stmt)) {
431            Value * const op0 = compileExpression(b, stmt->getOperand(0));
432            Value * const op1 = compileExpression(b, stmt->getOperand(1));
433            value = b->simd_or(op0, op1);
434        } else if (isa<Xor>(stmt)) {
435            Value * const op0 = compileExpression(b, stmt->getOperand(0));
436            Value * const op1 = compileExpression(b, stmt->getOperand(1));
437            value = b->simd_xor(op0, op1);
438        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
439            Value* ifMask = compileExpression(b, sel->getCondition());
440            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
441            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
442            value = b->simd_or(ifTrue, ifFalse);
443        } else if (isa<Not>(stmt)) {
444            value = b->simd_not(compileExpression(b, stmt->getOperand(0)));
445        } else if (isa<Advance>(stmt)) {
446            const Advance * const adv = cast<Advance>(stmt);
447            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
448            // manager so that it properly selects the correct carry bit.
449            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
450        } else if (isa<IndexedAdvance>(stmt)) {
451            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
452            Value * strm = compileExpression(b, adv->getExpression());
453            Value * index_strm = compileExpression(b, adv->getIndex());
454            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
455            // manager so that it properly selects the correct carry bit.
456            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
457        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
458            Value * const marker = compileExpression(b, mstar->getMarker());
459            Value * const cc = compileExpression(b, mstar->getCharClass());
460            Value * const marker_and_cc = b->simd_and(marker, cc);
461            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
462            value = b->simd_or(b->simd_xor(sum, cc), marker);
463        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
464            Value * const from = compileExpression(b, sthru->getScanFrom());
465            Value * const thru = compileExpression(b, sthru->getScanThru());
466            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
467            value = b->simd_and(sum, b->simd_not(thru));
468        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
469            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
470            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
471            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
472            value = b->simd_and(sum, to);
473        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
474            Value * const from = compileExpression(b, sthru->getScanFrom());
475            Value * const thru = compileExpression(b, sthru->getScanThru());
476            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
477            value = b->simd_and(sum, b->simd_not(thru));
478        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
479            Value * const from = compileExpression(b, sthru->getScanFrom());
480            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
481            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
482            value = b->simd_and(sum, to);
483        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
484            expr = cast<Assign>(stmt)->getVariable();
485            value = compileExpression(b, cast<Assign>(stmt)->getValue());
486            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
487                Value * const ptr = compileExpression(b, expr, false);
488                b->CreateAlignedStore(value, ptr, getAlignment(value));
489                value = ptr;
490            }
491        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
492            Value * EOFmask = b->getScalarField("EOFmask");
493            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
494        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
495            Value * EOFbit = b->getScalarField("EOFbit");
496            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
497        } else if (const Count * c = dyn_cast<Count>(stmt)) {
498            Value * EOFbit = b->getScalarField("EOFbit");
499            Value * EOFmask = b->getScalarField("EOFmask");
500            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));
501            Value * const ptr = b->getScalarFieldPtr(stmt->getName().str());
502            const auto alignment = getPointerElementAlignment(ptr);
503            Value * const countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
504            const auto fieldWidth = b->getSizeTy()->getBitWidth();
505            Value * bitBlockCount = b->simd_popcount(b->getBitBlockWidth(), to_count);
506            value = b->CreateAdd(b->mvmd_extract(fieldWidth, bitBlockCount, 0), countSoFar, "countSoFar");
507            b->CreateAlignedStore(value, ptr, alignment);
508        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
509            PabloAST * stream = l->getExpression();
510            Value * index = nullptr;
511            if (LLVM_UNLIKELY(isa<Extract>(stream))) {               
512                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
513                stream = cast<Extract>(stream)->getArray();
514            } else {
515                index = b->getInt32(0);
516            }
517            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
518            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
519            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
520            // Value * base = b->CreatePointerCast(b->getBaseAddress(cast<Var>(stream)->getName()), ptr->getType());
521            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
522            if (LLVM_UNLIKELY(bit_shift == 0)) {  // Simple case with no intra-block shifting.
523                value = lookAhead;
524            } else { // Need to form shift result from two adjacent blocks.
525                Value * ptr1 = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
526                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr1);
527                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
528                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
529                } else {
530                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
531                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
532                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
533                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
534                    value = b->CreateBitCast(result, b->getBitBlockType());
535                }
536            }
537        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
538            value = compileExpression(b, s->getValue());
539            Type * const ty = s->getType();
540            if (LLVM_LIKELY(ty->isVectorTy())) {
541                const auto fw = s->getFieldWidth()->value();
542                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
543                value = b->simd_fill(fw, value);
544            } else {
545                value = b->CreateZExtOrTrunc(value, ty);
546            }
547        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
548            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
549            const auto packWidth = p->getFieldWidth()->value();
550            assert (sourceWidth == packWidth);
551            Value * const base = compileExpression(b, p->getValue(), false);
552            const auto result_packs = sourceWidth/2;
553            if (LLVM_LIKELY(result_packs > 1)) {
554                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
555            }
556            Constant * const ZERO = b->getInt32(0);
557            for (unsigned i = 0; i < result_packs; ++i) {
558                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
559                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
560                Value * P = b->bitCast(b->hsimd_packh(packWidth, A, B));
561                if (LLVM_UNLIKELY(result_packs == 1)) {
562                    value = P;
563                    break;
564                }
565                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
566            }
567        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
568            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
569            const auto packWidth = p->getFieldWidth()->value();
570            assert (sourceWidth == packWidth);
571            Value * const base = compileExpression(b, p->getValue(), false);
572            const auto result_packs = sourceWidth/2;
573            if (LLVM_LIKELY(result_packs > 1)) {
574                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
575            }
576            Constant * const ZERO = b->getInt32(0);
577            for (unsigned i = 0; i < result_packs; ++i) {
578                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
579                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
580                Value * P = b->bitCast(b->hsimd_packl(packWidth, A, B));
581                if (LLVM_UNLIKELY(result_packs == 1)) {
582                    value = P;
583                    break;
584                }
585                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
586            }
587        } else {
588            std::string tmp;
589            raw_string_ostream out(tmp);
590            out << "PabloCompiler: ";
591            stmt->print(out);
592            out << " was not recognized by the compiler";
593            report_fatal_error(out.str());
594        }
595        assert (expr);
596        assert (value);
597        mMarker[expr] = value;
598        if (DebugOptionIsSet(DumpTrace)) {
599            std::string tmp;
600            raw_string_ostream name(tmp);
601            expr->print(name);
602            if (value->getType()->isVectorTy()) {
603                b->CallPrintRegister(name.str(), value);
604            } else if (value->getType()->isIntegerTy()) {
605                b->CallPrintInt(name.str(), value);
606            }
607        }
608    }
609}
610
611unsigned getIntegerBitWidth(const Type * ty) {
612    if (ty->isArrayTy()) {
613        assert (ty->getArrayNumElements() == 1);
614        ty = ty->getArrayElementType();
615    }
616    if (ty->isVectorTy()) {
617        assert (ty->getVectorNumElements() == 0);
618        ty = ty->getVectorElementType();
619    }
620    return ty->getIntegerBitWidth();
621}
622
623Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
624    const auto f = mMarker.find(expr);   
625    Value * value = nullptr;
626    if (LLVM_LIKELY(f != mMarker.end())) {
627        value = f->second;
628    } else {
629        if (isa<Integer>(expr)) {
630            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
631        } else if (isa<Zeroes>(expr)) {
632            value = b->allZeroes();
633        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
634            value = b->allOnes();
635        } else if (isa<Extract>(expr)) {
636            const Extract * const extract = cast<Extract>(expr);
637            const Var * const var = cast<Var>(extract->getArray());
638            Value * const index = compileExpression(b, extract->getIndex());
639            value = getPointerToVar(b, var, index);
640        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
641            const Var * const var = cast<Var>(expr);
642            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
643                value = b->getScalarFieldPtr(var->getName());
644            } else { // use before def error
645                std::string tmp;
646                raw_string_ostream out(tmp);
647                out << "PabloCompiler: ";
648                expr->print(out);
649                out << " is not a scalar value or was used before definition";
650                report_fatal_error(out.str());
651            }
652        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
653            const Operator * const op = cast<Operator>(expr);
654            const PabloAST * lh = op->getLH();
655            const PabloAST * rh = op->getRH();
656            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
657                if (getIntegerBitWidth(lh->getType()) != getIntegerBitWidth(rh->getType())) {
658                    llvm::report_fatal_error("Integer types must be identical!");
659                }
660                const unsigned intWidth = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
661                const unsigned maskWidth = b->getBitBlockWidth() / intWidth;
662                IntegerType * const maskTy = b->getIntNTy(maskWidth);
663                VectorType * const vTy = VectorType::get(b->getIntNTy(intWidth), maskWidth);
664
665                Value * baseLhv = nullptr;
666                Value * lhvStreamIndex = nullptr;
667                if (isa<Var>(lh)) {
668                    lhvStreamIndex = b->getInt32(0);
669                } else if (isa<Extract>(lh)) {
670                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
671                    lh = cast<Extract>(lh)->getArray();
672                } else {
673                    baseLhv = compileExpression(b, lh);
674                }
675
676                Value * baseRhv = nullptr;
677                Value * rhvStreamIndex = nullptr;
678                if (isa<Var>(rh)) {
679                    rhvStreamIndex = b->getInt32(0);
680                } else if (isa<Extract>(rh)) {
681                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
682                    rh = cast<Extract>(rh)->getArray();
683                } else {
684                    baseRhv = compileExpression(b, rh);
685                }
686
687                const TypeId typeId = op->getClassTypeId();
688
689                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
690
691                    value = b->CreateAlloca(vTy, b->getInt32(intWidth));
692
693                    for (unsigned i = 0; i < intWidth; ++i) {
694                        llvm::Constant * const index = b->getInt32(i);
695                        Value * lhv = nullptr;
696                        if (baseLhv) {
697                            lhv = baseLhv;
698                        } else {
699                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
700                            lhv = b->CreateBlockAlignedLoad(lhv);
701                        }
702                        lhv = b->CreateBitCast(lhv, vTy);
703
704                        Value * rhv = nullptr;
705                        if (baseRhv) {
706                            rhv = baseRhv;
707                        } else {
708                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
709                            rhv = b->CreateBlockAlignedLoad(rhv);
710                        }
711                        rhv = b->CreateBitCast(rhv, vTy);
712
713                        Value * result = nullptr;
714                        if (typeId == TypeId::Add) {
715                            result = b->CreateAdd(lhv, rhv);
716                        } else { // if (typeId == TypeId::Subtract) {
717                            result = b->CreateSub(lhv, rhv);
718                        }
719                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
720                    }
721
722                } else {
723
724                    value = UndefValue::get(VectorType::get(maskTy, intWidth));
725
726                    for (unsigned i = 0; i < intWidth; ++i) {
727                        llvm::Constant * const index = b->getInt32(i);
728                        Value * lhv = nullptr;
729                        if (baseLhv) {
730                            lhv = baseLhv;
731                        } else {
732                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
733                            lhv = b->CreateBlockAlignedLoad(lhv);
734                        }
735                        lhv = b->CreateBitCast(lhv, vTy);
736
737                        Value * rhv = nullptr;
738                        if (baseRhv) {
739                            rhv = baseRhv;
740                        } else {
741                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
742                            rhv = b->CreateBlockAlignedLoad(rhv);
743                        }
744                        rhv = b->CreateBitCast(rhv, vTy);
745                        Value * comp = nullptr;
746                        switch (typeId) {
747                            case TypeId::GreaterThanEquals:
748                            case TypeId::LessThan:
749                                comp = b->simd_ult(intWidth, lhv, rhv);
750                                break;
751                            case TypeId::Equals:
752                            case TypeId::NotEquals:
753                                comp = b->simd_eq(intWidth, lhv, rhv);
754                                break;
755                            case TypeId::LessThanEquals:
756                            case TypeId::GreaterThan:
757                                comp = b->simd_ugt(intWidth, lhv, rhv);
758                                break;
759                            default: llvm_unreachable("invalid vector operator id");
760                        }
761                        Value * const mask = b->CreateZExtOrTrunc(b->hsimd_signmask(intWidth, comp), maskTy);
762                        value = b->mvmd_insert(maskWidth, value, mask, i);
763                    }
764                    value = b->CreateBitCast(value, b->getBitBlockType());
765                    switch (typeId) {
766                        case TypeId::GreaterThanEquals:
767                        case TypeId::LessThanEquals:
768                        case TypeId::NotEquals:
769                            value = b->CreateNot(value);
770                        default: break;
771                    }
772                }
773
774            } else {
775                Value * const lhv = compileExpression(b, lh);
776                Value * const rhv = compileExpression(b, rh);
777                switch (op->getClassTypeId()) {
778                    case TypeId::Add:
779                        value = b->CreateAdd(lhv, rhv); break;
780                    case TypeId::Subtract:
781                        value = b->CreateSub(lhv, rhv); break;
782                    case TypeId::LessThan:
783                        value = b->CreateICmpULT(lhv, rhv); break;
784                    case TypeId::LessThanEquals:
785                        value = b->CreateICmpULE(lhv, rhv); break;
786                    case TypeId::Equals:
787                        value = b->CreateICmpEQ(lhv, rhv); break;
788                    case TypeId::GreaterThanEquals:
789                        value = b->CreateICmpUGE(lhv, rhv); break;
790                    case TypeId::GreaterThan:
791                        value = b->CreateICmpUGT(lhv, rhv); break;
792                    case TypeId::NotEquals:
793                        value = b->CreateICmpNE(lhv, rhv); break;
794                    default: llvm_unreachable("invalid scalar operator id");
795                }
796            }
797        } else { // use before def error
798            std::string tmp;
799            raw_string_ostream out(tmp);
800            out << "PabloCompiler: ";
801            expr->print(out);
802            out << " was used before definition";
803            report_fatal_error(out.str());
804        }
805        assert (value);
806        // mMarker.insert({expr, value});
807    }
808    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
809        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
810    }
811    return value;
812}
813
814Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
815    assert (var && index1);
816    if (LLVM_LIKELY(var->isKernelParameter())) {
817        if (LLVM_UNLIKELY(var->isScalar())) {
818            std::string tmp;
819            raw_string_ostream out(tmp);
820            out << mKernel->getName();
821            out << ": cannot index scalar value ";
822            var->print(out);
823            report_fatal_error(out.str());
824        } else if (var->isReadOnly()) {
825            if (index2) {
826                return b->getInputStreamPackPtr(var->getName(), index1, index2);
827            } else {
828                return b->getInputStreamBlockPtr(var->getName(), index1);
829            }
830        } else if (var->isReadNone()) {
831            if (index2) {
832                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
833            } else {
834                return b->getOutputStreamBlockPtr(var->getName(), index1);
835            }
836        } else {
837            std::string tmp;
838            raw_string_ostream out(tmp);
839            out << mKernel->getName();
840            out << ": stream ";
841            var->print(out);
842            out << " cannot be read from or written to";
843            report_fatal_error(out.str());
844        }
845    } else {
846        Value * const ptr = compileExpression(b, var, false);
847        std::vector<Value *> offsets;
848        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
849        offsets.push_back(index1);
850        if (index2) offsets.push_back(index2);
851        return b->CreateGEP(ptr, offsets);
852    }
853}
854
855PabloCompiler::PabloCompiler(PabloKernel * const kernel)
856: mKernel(kernel)
857, mCarryManager(make_unique<CarryManager>())
858, mBranchCount(0) {
859    assert ("PabloKernel cannot be null!" && kernel);
860}
861
862PabloCompiler::~PabloCompiler() {
863}
864
865}
Note: See TracBrowser for help on using the repository browser.