source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 5994

Last change on this file since 5994 was 5992, checked in by cameron, 17 months ago

Setting BinaryFilesMode? to Text (temporary); conversion to unique_ptr progress

File size: 37.8 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#ifdef USE_CARRYPACK_MANAGER
29#include <pablo/carrypack_manager.h>
30#else
31#include <pablo/carry_manager.h>
32#endif
33#include <kernels/kernel_builder.h>
34#include <kernels/streamset.h>
35#include <llvm/IR/Module.h>
36#include <llvm/IR/Type.h>
37#include <llvm/Support/raw_os_ostream.h>
38
39using namespace llvm;
40
41namespace pablo {
42
43using TypeId = PabloAST::ClassTypeId;
44
45inline static unsigned getAlignment(const Type * const type) {
46    return type->getPrimitiveSizeInBits() / 8;
47}
48
49inline static unsigned getAlignment(const Value * const expr) {
50    return getAlignment(expr->getType());
51}
52
53inline static unsigned getPointerElementAlignment(const Value * const ptr) {
54    return getAlignment(ptr->getType()->getPointerElementType());
55}
56
57void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
58    mBranchCount = 0;
59    examineBlock(b, mKernel->getEntryScope());
60    mCarryManager->initializeCarryData(b, mKernel);
61    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
62        const auto count = (mBranchCount * 2) + 1;
63        mKernel->addScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
64        mBasicBlock.reserve(count);
65    }
66}
67
68void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
69    mCarryManager->releaseCarryData(b);
70}
71
72void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
73    mCarryManager->clearCarryData(b);
74}
75
76void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
77    mCarryManager->initializeCodeGen(b);
78    PabloBlock * const entryBlock = mKernel->getEntryScope(); assert (entryBlock);
79    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
80    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
81    mBranchCount = 0;
82    addBranchCounter(b);
83    compileBlock(b, entryBlock);
84    mCarryManager->finalizeCodeGen(b);
85}
86
87void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
88    for (const Statement * stmt : *block) {
89        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
90            const Lookahead * const la = cast<Lookahead>(stmt);
91            PabloAST * input = la->getExpression();
92            if (isa<Extract>(input)) {
93                input = cast<Extract>(input)->getArray();
94            }
95            bool notFound = true;
96            if (LLVM_LIKELY(isa<Var>(input))) {
97                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
98                    if (input == mKernel->getInput(i)) {
99                        const auto & binding = mKernel->getStreamInput(i);
100                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
101                            std::string tmp;
102                            raw_string_ostream out(tmp);
103                            input->print(out);
104                            out << " must have a lookahead attribute of at least " << la->getAmount();
105                            report_fatal_error(out.str());
106                        }
107                        notFound = false;
108                        break;
109                    }
110                }
111            }
112            if (LLVM_UNLIKELY(notFound)) {
113                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
114            }
115        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
116            ++mBranchCount;
117            examineBlock(b, cast<Branch>(stmt)->getBody());
118        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
119            mAccumulator.insert(std::make_pair(stmt, b->getInt32(mKernel->addUnnamedScalar(stmt->getType()))));
120        }
121    }   
122}
123
124void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
125    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {       
126        Value * ptr = b->getScalarFieldPtr("profile");
127        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
128        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
129        const auto alignment = getPointerElementAlignment(ptr);
130        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
131        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
132        b->CreateAlignedStore(value, ptr, alignment);
133        mBasicBlock.push_back(b->GetInsertBlock());
134    }
135}
136
137inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
138    for (const Statement * statement : *block) {
139        compileStatement(b, statement);
140    }
141}
142
143void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
144    //
145    //  The If-ElseZero stmt:
146    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
147    //  If the value of the predicate is nonzero, then determine the values of variables
148    //  <var>* by executing the given statements.  Otherwise, the value of the
149    //  variables are all zero.  Requirements: (a) no variable that is defined within
150    //  the body of the if may be accessed outside unless it is explicitly
151    //  listed in the variable list, (b) every variable in the defined list receives
152    //  a value within the body, and (c) the logical consequence of executing
153    //  the statements in the event that the predicate is zero is that the
154    //  values of all defined variables indeed work out to be 0.
155    //
156    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
157    //  is inserted for each variable in the defined variable list.  It receives
158    //  a zero value from the ifentry block and the defined value from the if
159    //  body.
160    //
161
162    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
163    ++mBranchCount;
164    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
165    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
166   
167    std::vector<std::pair<const Var *, Value *>> incoming;
168
169    for (const Var * var : ifStatement->getEscaped()) {
170        if (LLVM_UNLIKELY(var->isKernelParameter())) {
171            Value * marker = nullptr;
172            if (var->isScalar()) {
173                marker = b->getScalarFieldPtr(var->getName());
174            } else if (var->isReadOnly()) {
175                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
176            } else if (var->isReadNone()) {
177                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
178            }
179            mMarker[var] = marker;
180        } else {
181            auto f = mMarker.find(var);
182            if (LLVM_UNLIKELY(f == mMarker.end())) {
183                std::string tmp;
184                raw_string_ostream out(tmp);
185                var->print(out);
186                out << " is uninitialized prior to entering ";
187                ifStatement->print(out);
188                report_fatal_error(out.str());
189            }
190            incoming.emplace_back(var, f->second);
191        }
192    }
193
194    const PabloBlock * ifBody = ifStatement->getBody();
195   
196    mCarryManager->enterIfScope(b, ifBody);
197
198    Value * condition = compileExpression(b, ifStatement->getCondition());
199    if (condition->getType() == b->getBitBlockType()) {
200        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
201    }
202   
203    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
204   
205    // Entry processing is complete, now handle the body of the if.
206    b->SetInsertPoint(ifBodyBlock);
207
208    mCarryManager->enterIfBody(b, ifEntryBlock);
209
210    addBranchCounter(b);
211
212    compileBlock(b, ifBody);
213
214    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
215
216    BasicBlock * ifExitBlock = b->GetInsertBlock();
217
218    b->CreateBr(ifEndBlock);
219
220    ifEndBlock->moveAfter(ifExitBlock);
221
222    //End Block
223    b->SetInsertPoint(ifEndBlock);
224
225    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
226
227    for (const auto i : incoming) {
228        const Var * var; Value * incoming;
229        std::tie(var, incoming) = i;
230
231        auto f = mMarker.find(var);
232        if (LLVM_UNLIKELY(f == mMarker.end())) {
233            std::string tmp;
234            raw_string_ostream out(tmp);
235            out << "PHINode creation error: ";
236            var->print(out);
237            out << " was not assigned an outgoing value.";
238            report_fatal_error(out.str());
239        }
240
241        Value * const outgoing = f->second;
242        if (LLVM_UNLIKELY(incoming == outgoing)) {
243            continue;
244        }
245
246        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
247            std::string tmp;
248            raw_string_ostream out(tmp);
249            out << "PHINode creation error: incoming type of ";
250            var->print(out);
251            out << " (";
252            incoming->getType()->print(out);
253            out << ") differs from the outgoing type (";
254            outgoing->getType()->print(out);
255            out << ") within ";
256            ifStatement->print(out);
257            report_fatal_error(out.str());
258        }
259
260        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
261        phi->addIncoming(incoming, ifEntryBlock);
262        phi->addIncoming(outgoing, ifExitBlock);
263        f->second = phi;
264    }
265
266    addBranchCounter(b);
267}
268
269void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
270
271    const PabloBlock * const whileBody = whileStatement->getBody();
272
273    BasicBlock * whileEntryBlock = b->GetInsertBlock();
274
275    const auto escaped = whileStatement->getEscaped();
276
277#ifdef ENABLE_BOUNDED_WHILE
278    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
279#endif
280    // On entry to the while structure, proceed to execute the first iteration
281    // of the loop body unconditionally. The while condition is tested at the end of
282    // the loop.
283
284    for (const Var * var : escaped) {
285        if (LLVM_UNLIKELY(var->isKernelParameter())) {
286            Value * marker = nullptr;
287            if (var->isScalar()) {
288                marker = b->getScalarFieldPtr(var->getName());
289            } else if (var->isReadOnly()) {
290                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
291            } else if (var->isReadNone()) {
292                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
293            }
294            mMarker[var] = marker;
295        }
296    }
297
298    mCarryManager->enterLoopScope(b, whileBody);
299
300    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
301    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
302    ++mBranchCount;
303
304    b->CreateBr(whileBodyBlock);
305
306    b->SetInsertPoint(whileBodyBlock);
307
308    //
309    // There are 3 sets of Phi nodes for the while loop.
310    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
311    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
312    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
313#ifdef ENABLE_BOUNDED_WHILE
314    // (4) The loop bound, if any.
315#endif
316
317    std::vector<std::pair<const Var *, PHINode *>> variants;
318
319    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
320    for (const auto var : escaped) {
321        auto f = mMarker.find(var);
322        if (LLVM_UNLIKELY(f == mMarker.end())) {
323            std::string tmp;
324            raw_string_ostream out(tmp);
325            out << "PHINode creation error: ";
326            var->print(out);
327            out << " is uninitialized prior to entering ";
328            whileStatement->print(out);
329            report_fatal_error(out.str());
330        }
331        Value * entryValue = f->second;
332        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
333        phi->addIncoming(entryValue, whileEntryBlock);
334        f->second = phi;
335        assert(mMarker[var] == phi);
336        variants.emplace_back(var, phi);
337    }
338#ifdef ENABLE_BOUNDED_WHILE
339    if (whileStatement->getBound()) {
340        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
341        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
342    }
343#endif
344
345    mCarryManager->enterLoopBody(b, whileEntryBlock);
346
347    addBranchCounter(b);
348
349    compileBlock(b, whileBody);
350
351    // After the whileBody has been compiled, we may be in a different basic block.
352
353    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
354
355
356#ifdef ENABLE_BOUNDED_WHILE
357    if (whileStatement->getBound()) {
358        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
359        bound_phi->addIncoming(new_bound, whileExitBlock);
360        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
361    }
362#endif
363
364    BasicBlock * const whileExitBlock = b->GetInsertBlock();
365
366    // and for any variant nodes in the loop body
367    for (const auto variant : variants) {
368        const Var * var; PHINode * incomingPhi;
369        std::tie(var, incomingPhi) = variant;
370        const auto f = mMarker.find(var);
371        if (LLVM_UNLIKELY(f == mMarker.end())) {
372            std::string tmp;
373            raw_string_ostream out(tmp);
374            out << "PHINode creation error: ";
375            var->print(out);
376            out << " is no longer assigned a value.";
377            report_fatal_error(out.str());
378        }
379
380        Value * const outgoingValue = f->second;
381
382        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
383            std::string tmp;
384            raw_string_ostream out(tmp);
385            out << "PHINode creation error: incoming type of ";
386            var->print(out);
387            out << " (";
388            incomingPhi->getType()->print(out);
389            out << ") differs from the outgoing type (";
390            outgoingValue->getType()->print(out);
391            out << ") within ";
392            whileStatement->print(out);
393            report_fatal_error(out.str());
394        }
395
396        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
397    }
398
399    // Terminate the while loop body with a conditional branch back.
400    Value * condition = compileExpression(b, whileStatement->getCondition());
401    if (condition->getType() == b->getBitBlockType()) {
402        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
403    }
404
405    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
406
407    whileEndBlock->moveAfter(whileExitBlock);
408
409    b->SetInsertPoint(whileEndBlock);
410
411    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
412
413    addBranchCounter(b);
414}
415
416void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
417
418    if (LLVM_UNLIKELY(isa<If>(stmt))) {
419        compileIf(b, cast<If>(stmt));
420    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
421        compileWhile(b, cast<While>(stmt));
422    } else {
423        const PabloAST * expr = stmt;
424        Value * value = nullptr;
425        if (isa<And>(stmt)) {
426            Value * const op0 = compileExpression(b, stmt->getOperand(0));
427            Value * const op1 = compileExpression(b, stmt->getOperand(1));
428            value = b->simd_and(op0, op1);
429        } else if (isa<Or>(stmt)) {
430            Value * const op0 = compileExpression(b, stmt->getOperand(0));
431            Value * const op1 = compileExpression(b, stmt->getOperand(1));
432            value = b->simd_or(op0, op1);
433        } else if (isa<Xor>(stmt)) {
434            Value * const op0 = compileExpression(b, stmt->getOperand(0));
435            Value * const op1 = compileExpression(b, stmt->getOperand(1));
436            value = b->simd_xor(op0, op1);
437        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
438            Value* ifMask = compileExpression(b, sel->getCondition());
439            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
440            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
441            value = b->simd_or(ifTrue, ifFalse);
442        } else if (isa<Not>(stmt)) {
443            value = b->simd_not(compileExpression(b, stmt->getOperand(0)));
444        } else if (isa<Advance>(stmt)) {
445            const Advance * const adv = cast<Advance>(stmt);
446            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
447            // manager so that it properly selects the correct carry bit.
448            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
449        } else if (isa<IndexedAdvance>(stmt)) {
450            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
451            Value * strm = compileExpression(b, adv->getExpression());
452            Value * index_strm = compileExpression(b, adv->getIndex());
453            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
454            // manager so that it properly selects the correct carry bit.
455            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
456        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
457            Value * const marker = compileExpression(b, mstar->getMarker());
458            Value * const cc = compileExpression(b, mstar->getCharClass());
459            Value * const marker_and_cc = b->simd_and(marker, cc);
460            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
461            value = b->simd_or(b->simd_xor(sum, cc), marker);
462        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
463            Value * const from = compileExpression(b, sthru->getScanFrom());
464            Value * const thru = compileExpression(b, sthru->getScanThru());
465            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
466            value = b->simd_and(sum, b->simd_not(thru));
467        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
468            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
469            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
470            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
471            value = b->simd_and(sum, to);
472        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
473            Value * const from = compileExpression(b, sthru->getScanFrom());
474            Value * const thru = compileExpression(b, sthru->getScanThru());
475            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
476            value = b->simd_and(sum, b->simd_not(thru));
477        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
478            Value * const from = compileExpression(b, sthru->getScanFrom());
479            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
480            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
481            value = b->simd_and(sum, to);
482        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
483            expr = cast<Assign>(stmt)->getVariable();
484            value = compileExpression(b, cast<Assign>(stmt)->getValue());
485            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
486                Value * const ptr = compileExpression(b, expr, false);
487                b->CreateAlignedStore(value, ptr, getAlignment(value));
488                value = ptr;
489            }
490        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
491            Value * EOFmask = b->getScalarField("EOFmask");
492            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
493        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
494            Value * EOFbit = b->getScalarField("EOFbit");
495            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
496        } else if (const Count * c = dyn_cast<Count>(stmt)) {
497            Value * EOFbit = b->getScalarField("EOFbit");
498            Value * EOFmask = b->getScalarField("EOFmask");
499            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));           
500            const auto f = mAccumulator.find(c);
501            if (LLVM_UNLIKELY(f == mAccumulator.end())) {
502                report_fatal_error("Unknown accumulator: " + c->getName().str());
503            }
504            Value * const ptr = b->getScalarFieldPtr(f->second);
505            const auto alignment = getPointerElementAlignment(ptr);
506            Value * const countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
507            const auto fieldWidth = b->getSizeTy()->getBitWidth();
508            Value * bitBlockCount = b->simd_popcount(b->getBitBlockWidth(), to_count);
509            value = b->CreateAdd(b->mvmd_extract(fieldWidth, bitBlockCount, 0), countSoFar, "countSoFar");
510            b->CreateAlignedStore(value, ptr, alignment);
511        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
512            PabloAST * stream = l->getExpression();
513            Value * index = nullptr;
514            if (LLVM_UNLIKELY(isa<Extract>(stream))) {               
515                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
516                stream = cast<Extract>(stream)->getArray();
517            } else {
518                index = b->getInt32(0);
519            }
520            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
521            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
522            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
523            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
524            if (bit_shift == 0) {  // Simple case with no intra-block shifting.
525                value = lookAhead;
526            } else { // Need to form shift result from two adjacent blocks.
527                Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
528                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr);
529                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
530                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
531                } else {
532                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
533                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
534                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
535                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
536                    value = b->CreateBitCast(result, b->getBitBlockType());
537                }
538            }
539        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
540            value = compileExpression(b, s->getValue());
541            Type * const ty = s->getType();
542            if (LLVM_LIKELY(ty->isVectorTy())) {
543                const auto fw = s->getFieldWidth()->value();
544                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
545                value = b->simd_fill(fw, value);
546            } else {
547                value = b->CreateZExtOrTrunc(value, ty);
548            }
549        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
550            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
551            const auto packWidth = p->getFieldWidth()->value();
552            assert (sourceWidth == packWidth);
553            Value * const base = compileExpression(b, p->getValue(), false);
554            const auto result_packs = sourceWidth/2;
555            if (LLVM_LIKELY(result_packs > 1)) {
556                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
557            }
558            Constant * const ZERO = b->getInt32(0);
559            for (unsigned i = 0; i < result_packs; ++i) {
560                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
561                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
562                Value * P = b->bitCast(b->hsimd_packh(packWidth, A, B));
563                if (LLVM_UNLIKELY(result_packs == 1)) {
564                    value = P;
565                    break;
566                }
567                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
568            }
569        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
570            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
571            const auto packWidth = p->getFieldWidth()->value();
572            assert (sourceWidth == packWidth);
573            Value * const base = compileExpression(b, p->getValue(), false);
574            const auto result_packs = sourceWidth/2;
575            if (LLVM_LIKELY(result_packs > 1)) {
576                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
577            }
578            Constant * const ZERO = b->getInt32(0);
579            for (unsigned i = 0; i < result_packs; ++i) {
580                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
581                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
582                Value * P = b->bitCast(b->hsimd_packl(packWidth, A, B));
583                if (LLVM_UNLIKELY(result_packs == 1)) {
584                    value = P;
585                    break;
586                }
587                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
588            }
589        } else {
590            std::string tmp;
591            raw_string_ostream out(tmp);
592            out << "PabloCompiler: ";
593            stmt->print(out);
594            out << " was not recognized by the compiler";
595            report_fatal_error(out.str());
596        }
597        assert (expr);
598        assert (value);
599        mMarker[expr] = value;
600        if (DebugOptionIsSet(DumpTrace)) {
601            std::string tmp;
602            raw_string_ostream name(tmp);
603            expr->print(name);
604            if (value->getType()->isVectorTy()) {
605                b->CallPrintRegister(name.str(), value);
606            } else if (value->getType()->isIntegerTy()) {
607                b->CallPrintInt(name.str(), value);
608            }
609        }
610    }
611}
612
613unsigned getIntegerBitWidth(const Type * ty) {
614    if (ty->isArrayTy()) {
615        assert (ty->getArrayNumElements() == 1);
616        ty = ty->getArrayElementType();
617    }
618    if (ty->isVectorTy()) {
619        assert (ty->getVectorNumElements() == 0);
620        ty = ty->getVectorElementType();
621    }
622    return ty->getIntegerBitWidth();
623}
624
625Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
626    const auto f = mMarker.find(expr);   
627    Value * value = nullptr;
628    if (LLVM_LIKELY(f != mMarker.end())) {
629        value = f->second;
630    } else {
631        if (isa<Integer>(expr)) {
632            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
633        } else if (isa<Zeroes>(expr)) {
634            value = b->allZeroes();
635        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
636            value = b->allOnes();
637        } else if (isa<Extract>(expr)) {
638            const Extract * const extract = cast<Extract>(expr);
639            const Var * const var = cast<Var>(extract->getArray());
640            Value * const index = compileExpression(b, extract->getIndex());
641            value = getPointerToVar(b, var, index);
642        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
643            const Var * const var = cast<Var>(expr);
644            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
645                value = b->getScalarFieldPtr(var->getName());
646            } else { // use before def error
647                std::string tmp;
648                raw_string_ostream out(tmp);
649                out << "PabloCompiler: ";
650                expr->print(out);
651                out << " is not a scalar value or was used before definition";
652                report_fatal_error(out.str());
653            }
654        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
655            const Operator * const op = cast<Operator>(expr);
656            const PabloAST * lh = op->getLH();
657            const PabloAST * rh = op->getRH();
658            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
659                const unsigned n = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
660                const unsigned m = b->getBitBlockWidth() / n;
661                IntegerType * const fw = b->getIntNTy(m);
662                VectorType * const vTy = VectorType::get(b->getIntNTy(n), m);
663
664                Value * baseLhv = nullptr;
665                Value * lhvStreamIndex = nullptr;
666                if (isa<Var>(lh)) {
667                    lhvStreamIndex = b->getInt32(0);
668                } else if (isa<Extract>(lh)) {
669                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
670                    lh = cast<Extract>(lh)->getArray();
671                } else {
672                    baseLhv = compileExpression(b, lh);
673                }
674
675                Value * baseRhv = nullptr;
676                Value * rhvStreamIndex = nullptr;
677                if (isa<Var>(rh)) {
678                    rhvStreamIndex = b->getInt32(0);
679                } else if (isa<Extract>(rh)) {
680                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
681                    rh = cast<Extract>(rh)->getArray();
682                } else {
683                    baseRhv = compileExpression(b, rh);
684                }
685
686                const TypeId typeId = op->getClassTypeId();
687
688                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
689
690                    value = b->CreateAlloca(vTy, b->getInt32(n));
691
692                    for (unsigned i = 0; i < n; ++i) {
693                        llvm::Constant * const index = b->getInt32(i);
694                        Value * lhv = nullptr;
695                        if (baseLhv) {
696                            lhv = baseLhv;
697                        } else {
698                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
699                            lhv = b->CreateBlockAlignedLoad(lhv);
700                        }
701                        lhv = b->CreateBitCast(lhv, vTy);
702
703                        Value * rhv = nullptr;
704                        if (baseRhv) {
705                            rhv = baseRhv;
706                        } else {
707                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
708                            rhv = b->CreateBlockAlignedLoad(rhv);
709                        }
710                        rhv = b->CreateBitCast(rhv, vTy);
711
712                        Value * result = nullptr;
713                        if (typeId == TypeId::Add) {
714                            result = b->CreateAdd(lhv, rhv);
715                        } else { // if (typeId == TypeId::Subtract) {
716                            result = b->CreateSub(lhv, rhv);
717                        }
718                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
719                    }
720
721                } else {
722
723                    value = UndefValue::get(VectorType::get(fw, n));
724
725                    for (unsigned i = 0; i < n; ++i) {
726                        llvm::Constant * const index = b->getInt32(i);
727                        Value * lhv = nullptr;
728                        if (baseLhv) {
729                            lhv = baseLhv;
730                        } else {
731                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
732                            lhv = b->CreateBlockAlignedLoad(lhv);
733                        }
734                        lhv = b->CreateBitCast(lhv, vTy);
735
736                        Value * rhv = nullptr;
737                        if (baseRhv) {
738                            rhv = baseRhv;
739                        } else {
740                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
741                            rhv = b->CreateBlockAlignedLoad(rhv);
742                        }
743                        rhv = b->CreateBitCast(rhv, vTy);
744
745                        Value * comp = nullptr;
746                        switch (typeId) {
747                            case TypeId::GreaterThanEquals:
748                            case TypeId::LessThan:
749                                comp = b->simd_ult(n, lhv, rhv);
750                                break;
751                            case TypeId::Equals:
752                            case TypeId::NotEquals:
753                                comp = b->simd_eq(n, lhv, rhv);
754                                break;
755                            case TypeId::LessThanEquals:
756                            case TypeId::GreaterThan:
757                                comp = b->simd_ugt(n, lhv, rhv);
758                                break;
759                            default: llvm_unreachable("invalid vector operator id");
760                        }
761                        Value * const mask = b->CreateZExtOrTrunc(b->hsimd_signmask(n, comp), fw);
762                        value = b->mvmd_insert(m, value, mask, i);
763                    }
764
765                    value = b->CreateBitCast(value, b->getBitBlockType());
766                    switch (typeId) {
767                        case TypeId::GreaterThanEquals:
768                        case TypeId::LessThanEquals:
769                        case TypeId::NotEquals:
770                            value = b->simd_not(value);
771                        default: break;
772                    }
773                }
774
775            } else {
776                Value * const lhv = compileExpression(b, lh);
777                Value * const rhv = compileExpression(b, rh);
778                switch (op->getClassTypeId()) {
779                    case TypeId::Add:
780                        value = b->CreateAdd(lhv, rhv); break;
781                    case TypeId::Subtract:
782                        value = b->CreateSub(lhv, rhv); break;
783                    case TypeId::LessThan:
784                        value = b->CreateICmpSLT(lhv, rhv); break;
785                    case TypeId::LessThanEquals:
786                        value = b->CreateICmpSLE(lhv, rhv); break;
787                    case TypeId::Equals:
788                        value = b->CreateICmpEQ(lhv, rhv); break;
789                    case TypeId::GreaterThanEquals:
790                        value = b->CreateICmpSGE(lhv, rhv); break;
791                    case TypeId::GreaterThan:
792                        value = b->CreateICmpSGT(lhv, rhv); break;
793                    case TypeId::NotEquals:
794                        value = b->CreateICmpNE(lhv, rhv); break;
795                    default: llvm_unreachable("invalid scalar operator id");
796                }
797            }
798        } else { // use before def error
799            std::string tmp;
800            raw_string_ostream out(tmp);
801            out << "PabloCompiler: ";
802            expr->print(out);
803            out << " was used before definition";
804            report_fatal_error(out.str());
805        }
806        assert (value);
807        // mMarker.insert({expr, value});
808    }
809    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
810        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
811    }
812    return value;
813}
814
815Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
816    assert (var && index1);
817    if (LLVM_LIKELY(var->isKernelParameter())) {
818        if (LLVM_UNLIKELY(var->isScalar())) {
819            std::string tmp;
820            raw_string_ostream out(tmp);
821            out << mKernel->getName();
822            out << ": cannot index scalar value ";
823            var->print(out);
824            report_fatal_error(out.str());
825        } else if (var->isReadOnly()) {
826            if (index2) {
827                return b->getInputStreamPackPtr(var->getName(), index1, index2);
828            } else {
829                return b->getInputStreamBlockPtr(var->getName(), index1);
830            }
831        } else if (var->isReadNone()) {
832            if (index2) {
833                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
834            } else {
835                return b->getOutputStreamBlockPtr(var->getName(), index1);
836            }
837        } else {
838            std::string tmp;
839            raw_string_ostream out(tmp);
840            out << mKernel->getName();
841            out << ": stream ";
842            var->print(out);
843            out << " cannot be read from or written to";
844            report_fatal_error(out.str());
845        }
846    } else {
847        Value * const ptr = compileExpression(b, var, false);
848        std::vector<Value *> offsets;
849        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
850        offsets.push_back(index1);
851        if (index2) offsets.push_back(index2);
852        return b->CreateGEP(ptr, offsets);
853    }
854}
855
856PabloCompiler::PabloCompiler(PabloKernel * const kernel)
857: mKernel(kernel)
858, mCarryManager(make_unique<CarryManager>())
859, mBranchCount(0) {
860    assert ("PabloKernel cannot be null!" && kernel);
861}
862
863PabloCompiler::~PabloCompiler() {
864}
865
866}
Note: See TracBrowser for help on using the repository browser.