source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 5837

Last change on this file since 5837 was 5837, checked in by cameron, 18 months ago

Pablo packh/packl and transposition with -enable-pablo-s2p

File size: 38.0 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#ifdef USE_CARRYPACK_MANAGER
29#include <pablo/carrypack_manager.h>
30#else
31#include <pablo/carry_manager.h>
32#endif
33#include <kernels/kernel_builder.h>
34#include <kernels/streamset.h>
35#include <llvm/IR/Module.h>
36#include <llvm/IR/Type.h>
37#include <llvm/Support/raw_os_ostream.h>
38
39using namespace llvm;
40
41namespace pablo {
42
43using TypeId = PabloAST::ClassTypeId;
44
45inline static unsigned getAlignment(const Type * const type) {
46    return type->getPrimitiveSizeInBits() / 8;
47}
48
49inline static unsigned getAlignment(const Value * const expr) {
50    return getAlignment(expr->getType());
51}
52
53inline static unsigned getPointerElementAlignment(const Value * const ptr) {
54    return getAlignment(ptr->getType()->getPointerElementType());
55}
56
57void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
58    mBranchCount = 0;
59    examineBlock(b, mKernel->getEntryScope());
60    mCarryManager->initializeCarryData(b, mKernel);
61    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
62        const auto count = (mBranchCount * 2) + 1;
63        mKernel->addScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
64        mBasicBlock.reserve(count);
65    }
66}
67
68void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
69    mCarryManager->releaseCarryData(b);
70}
71
72void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
73    mCarryManager->clearCarryData(b);
74}
75
76void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
77    mCarryManager->initializeCodeGen(b);
78    PabloBlock * const entryBlock = mKernel->getEntryScope(); assert (entryBlock);
79    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
80    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
81    mBranchCount = 0;
82    addBranchCounter(b);
83    compileBlock(b, entryBlock);
84    mCarryManager->finalizeCodeGen(b);
85}
86
87void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
88    for (const Statement * stmt : *block) {
89        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
90            const Lookahead * const la = cast<Lookahead>(stmt);
91            PabloAST * input = la->getExpression();
92            if (isa<Extract>(input)) {
93                input = cast<Extract>(input)->getArray();
94            }
95            bool notFound = true;
96            if (LLVM_LIKELY(isa<Var>(input))) {
97                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
98                    if (input == mKernel->getInput(i)) {
99                        const auto & binding = mKernel->getStreamInput(i);
100                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
101                            std::string tmp;
102                            raw_string_ostream out(tmp);
103                            input->print(out);
104                            out << " must have a lookahead attribute of at least " << la->getAmount();
105                            report_fatal_error(out.str());
106                        }
107                        notFound = false;
108                        break;
109                    }
110                }
111            }
112            if (LLVM_UNLIKELY(notFound)) {
113                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
114            }
115        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
116            ++mBranchCount;
117            examineBlock(b, cast<Branch>(stmt)->getBody());
118        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
119            mAccumulator.insert(std::make_pair(stmt, b->getInt32(mKernel->addUnnamedScalar(stmt->getType()))));
120        }
121    }   
122}
123
124void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
125    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {       
126        Value * ptr = b->getScalarFieldPtr("profile");
127        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
128        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
129        const auto alignment = getPointerElementAlignment(ptr);
130        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
131        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
132        b->CreateAlignedStore(value, ptr, alignment);
133        mBasicBlock.push_back(b->GetInsertBlock());
134    }
135}
136
137inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
138    for (const Statement * statement : *block) {
139        compileStatement(b, statement);
140    }
141}
142
143void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
144    //
145    //  The If-ElseZero stmt:
146    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
147    //  If the value of the predicate is nonzero, then determine the values of variables
148    //  <var>* by executing the given statements.  Otherwise, the value of the
149    //  variables are all zero.  Requirements: (a) no variable that is defined within
150    //  the body of the if may be accessed outside unless it is explicitly
151    //  listed in the variable list, (b) every variable in the defined list receives
152    //  a value within the body, and (c) the logical consequence of executing
153    //  the statements in the event that the predicate is zero is that the
154    //  values of all defined variables indeed work out to be 0.
155    //
156    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
157    //  is inserted for each variable in the defined variable list.  It receives
158    //  a zero value from the ifentry block and the defined value from the if
159    //  body.
160    //
161
162    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
163    ++mBranchCount;
164    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
165    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
166   
167    std::vector<std::pair<const Var *, Value *>> incoming;
168
169    for (const Var * var : ifStatement->getEscaped()) {
170        if (LLVM_UNLIKELY(var->isKernelParameter())) {
171            Value * marker = nullptr;
172            if (var->isScalar()) {
173                marker = b->getScalarFieldPtr(var->getName());
174            } else if (var->isReadOnly()) {
175                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
176            } else if (var->isReadNone()) {
177                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
178            }
179            mMarker[var] = marker;
180        } else {
181            auto f = mMarker.find(var);
182            if (LLVM_UNLIKELY(f == mMarker.end())) {
183                std::string tmp;
184                raw_string_ostream out(tmp);
185                var->print(out);
186                out << " is uninitialized prior to entering ";
187                ifStatement->print(out);
188                report_fatal_error(out.str());
189            }
190            incoming.emplace_back(var, f->second);
191        }
192    }
193
194    const PabloBlock * ifBody = ifStatement->getBody();
195   
196    mCarryManager->enterIfScope(b, ifBody);
197
198    Value * condition = compileExpression(b, ifStatement->getCondition());
199    if (condition->getType() == b->getBitBlockType()) {
200        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
201    }
202   
203    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
204   
205    // Entry processing is complete, now handle the body of the if.
206    b->SetInsertPoint(ifBodyBlock);
207
208    mCarryManager->enterIfBody(b, ifEntryBlock);
209
210    addBranchCounter(b);
211
212    compileBlock(b, ifBody);
213
214    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
215
216    BasicBlock * ifExitBlock = b->GetInsertBlock();
217
218    b->CreateBr(ifEndBlock);
219
220    ifEndBlock->moveAfter(ifExitBlock);
221
222    //End Block
223    b->SetInsertPoint(ifEndBlock);
224
225    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
226
227    for (const auto i : incoming) {
228        const Var * var; Value * incoming;
229        std::tie(var, incoming) = i;
230
231        auto f = mMarker.find(var);
232        if (LLVM_UNLIKELY(f == mMarker.end())) {
233            std::string tmp;
234            raw_string_ostream out(tmp);
235            out << "PHINode creation error: ";
236            var->print(out);
237            out << " was not assigned an outgoing value.";
238            report_fatal_error(out.str());
239        }
240
241        Value * const outgoing = f->second;
242        if (LLVM_UNLIKELY(incoming == outgoing)) {
243            continue;
244        }
245
246        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
247            std::string tmp;
248            raw_string_ostream out(tmp);
249            out << "PHINode creation error: incoming type of ";
250            var->print(out);
251            out << " (";
252            incoming->getType()->print(out);
253            out << ") differs from the outgoing type (";
254            outgoing->getType()->print(out);
255            out << ") within ";
256            ifStatement->print(out);
257            report_fatal_error(out.str());
258        }
259
260        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
261        phi->addIncoming(incoming, ifEntryBlock);
262        phi->addIncoming(outgoing, ifExitBlock);
263        f->second = phi;
264    }
265
266    addBranchCounter(b);
267}
268
269void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
270
271    const PabloBlock * const whileBody = whileStatement->getBody();
272
273    BasicBlock * whileEntryBlock = b->GetInsertBlock();
274
275    const auto escaped = whileStatement->getEscaped();
276
277#ifdef ENABLE_BOUNDED_WHILE
278    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
279#endif
280    // On entry to the while structure, proceed to execute the first iteration
281    // of the loop body unconditionally. The while condition is tested at the end of
282    // the loop.
283
284    for (const Var * var : escaped) {
285        if (LLVM_UNLIKELY(var->isKernelParameter())) {
286            Value * marker = nullptr;
287            if (var->isScalar()) {
288                marker = b->getScalarFieldPtr(var->getName());
289            } else if (var->isReadOnly()) {
290                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
291            } else if (var->isReadNone()) {
292                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
293            }
294            mMarker[var] = marker;
295        }
296    }
297
298    mCarryManager->enterLoopScope(b, whileBody);
299
300    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
301    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
302    ++mBranchCount;
303
304    b->CreateBr(whileBodyBlock);
305
306    b->SetInsertPoint(whileBodyBlock);
307
308    //
309    // There are 3 sets of Phi nodes for the while loop.
310    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
311    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
312    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
313#ifdef ENABLE_BOUNDED_WHILE
314    // (4) The loop bound, if any.
315#endif
316
317    std::vector<std::pair<const Var *, PHINode *>> variants;
318
319    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
320    for (const auto var : escaped) {
321        auto f = mMarker.find(var);
322        if (LLVM_UNLIKELY(f == mMarker.end())) {
323            std::string tmp;
324            raw_string_ostream out(tmp);
325            out << "PHINode creation error: ";
326            var->print(out);
327            out << " is uninitialized prior to entering ";
328            whileStatement->print(out);
329            report_fatal_error(out.str());
330        }
331        Value * entryValue = f->second;
332        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
333        phi->addIncoming(entryValue, whileEntryBlock);
334        f->second = phi;
335        assert(mMarker[var] == phi);
336        variants.emplace_back(var, phi);
337    }
338#ifdef ENABLE_BOUNDED_WHILE
339    if (whileStatement->getBound()) {
340        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
341        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
342    }
343#endif
344
345    mCarryManager->enterLoopBody(b, whileEntryBlock);
346
347    addBranchCounter(b);
348
349    compileBlock(b, whileBody);
350
351    // After the whileBody has been compiled, we may be in a different basic block.
352
353    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
354
355
356#ifdef ENABLE_BOUNDED_WHILE
357    if (whileStatement->getBound()) {
358        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
359        bound_phi->addIncoming(new_bound, whileExitBlock);
360        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
361    }
362#endif
363
364    BasicBlock * const whileExitBlock = b->GetInsertBlock();
365
366    // and for any variant nodes in the loop body
367    for (const auto variant : variants) {
368        const Var * var; PHINode * incomingPhi;
369        std::tie(var, incomingPhi) = variant;
370        const auto f = mMarker.find(var);
371        if (LLVM_UNLIKELY(f == mMarker.end())) {
372            std::string tmp;
373            raw_string_ostream out(tmp);
374            out << "PHINode creation error: ";
375            var->print(out);
376            out << " is no longer assigned a value.";
377            report_fatal_error(out.str());
378        }
379
380        Value * const outgoingValue = f->second;
381
382        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
383            std::string tmp;
384            raw_string_ostream out(tmp);
385            out << "PHINode creation error: incoming type of ";
386            var->print(out);
387            out << " (";
388            incomingPhi->getType()->print(out);
389            out << ") differs from the outgoing type (";
390            outgoingValue->getType()->print(out);
391            out << ") within ";
392            whileStatement->print(out);
393            report_fatal_error(out.str());
394        }
395
396        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
397    }
398
399    // Terminate the while loop body with a conditional branch back.
400    Value * condition = compileExpression(b, whileStatement->getCondition());
401    if (condition->getType() == b->getBitBlockType()) {
402        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
403    }
404
405    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
406
407    whileEndBlock->moveAfter(whileExitBlock);
408
409    b->SetInsertPoint(whileEndBlock);
410
411    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
412
413    addBranchCounter(b);
414}
415
416void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
417
418    if (LLVM_UNLIKELY(isa<If>(stmt))) {
419        compileIf(b, cast<If>(stmt));
420    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
421        compileWhile(b, cast<While>(stmt));
422    } else {
423        const PabloAST * expr = stmt;
424        Value * value = nullptr;
425        if (isa<And>(stmt)) {
426            Value * const op0 = compileExpression(b, stmt->getOperand(0));
427            Value * const op1 = compileExpression(b, stmt->getOperand(1));
428            value = b->simd_and(op0, op1);
429        } else if (isa<Or>(stmt)) {
430            Value * const op0 = compileExpression(b, stmt->getOperand(0));
431            Value * const op1 = compileExpression(b, stmt->getOperand(1));
432            value = b->simd_or(op0, op1);
433        } else if (isa<Xor>(stmt)) {
434            Value * const op0 = compileExpression(b, stmt->getOperand(0));
435            Value * const op1 = compileExpression(b, stmt->getOperand(1));
436            value = b->simd_xor(op0, op1);
437        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
438            Value* ifMask = compileExpression(b, sel->getCondition());
439            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
440            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
441            value = b->simd_or(ifTrue, ifFalse);
442        } else if (isa<Not>(stmt)) {
443            value = b->simd_not(compileExpression(b, stmt->getOperand(0)));
444        } else if (isa<Advance>(stmt)) {
445            const Advance * const adv = cast<Advance>(stmt);
446            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
447            // manager so that it properly selects the correct carry bit.
448            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
449        } else if (isa<IndexedAdvance>(stmt)) {
450            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
451            Value * strm = compileExpression(b, adv->getExpression());
452            Value * index_strm = compileExpression(b, adv->getIndex());
453            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
454            // manager so that it properly selects the correct carry bit.
455            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
456        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
457            Value * const marker = compileExpression(b, mstar->getMarker());
458            Value * const cc = compileExpression(b, mstar->getCharClass());
459            Value * const marker_and_cc = b->simd_and(marker, cc);
460            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
461            value = b->simd_or(b->simd_xor(sum, cc), marker);
462        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
463            Value * const from = compileExpression(b, sthru->getScanFrom());
464            Value * const thru = compileExpression(b, sthru->getScanThru());
465            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
466            value = b->simd_and(sum, b->simd_not(thru));
467        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
468            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
469            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
470            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
471            value = b->simd_and(sum, to);
472        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
473            Value * const from = compileExpression(b, sthru->getScanFrom());
474            Value * const thru = compileExpression(b, sthru->getScanThru());
475            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
476            value = b->simd_and(sum, b->simd_not(thru));
477        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
478            Value * const from = compileExpression(b, sthru->getScanFrom());
479            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
480            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
481            value = b->simd_and(sum, to);
482        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
483            expr = cast<Assign>(stmt)->getVariable();
484            value = compileExpression(b, cast<Assign>(stmt)->getValue());
485            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
486                Value * const ptr = compileExpression(b, expr, false);
487                b->CreateAlignedStore(value, ptr, getAlignment(value));
488                value = ptr;
489            }
490        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
491            Value * EOFmask = b->getScalarField("EOFmask");
492            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
493        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
494            Value * EOFbit = b->getScalarField("EOFbit");
495            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
496        } else if (const Count * c = dyn_cast<Count>(stmt)) {
497            Value * EOFbit = b->getScalarField("EOFbit");
498            Value * EOFmask = b->getScalarField("EOFmask");
499            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));           
500            const auto f = mAccumulator.find(c);
501            if (LLVM_UNLIKELY(f == mAccumulator.end())) {
502                report_fatal_error("Unknown accumulator: " + c->getName().str());
503            }
504            Value * const ptr = b->getScalarFieldPtr(f->second);
505            const auto alignment = getPointerElementAlignment(ptr);
506            Value * const countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
507            const auto fieldWidth = b->getSizeTy()->getBitWidth();
508            auto fields = (b->getBitBlockWidth() / fieldWidth);
509            Value * fieldCounts = b->simd_popcount(fieldWidth, to_count);
510            while (fields > 1) {
511                fields /= 2;
512                fieldCounts = b->CreateAdd(fieldCounts, b->mvmd_srli(fieldWidth, fieldCounts, fields));
513            }
514            value = b->CreateAdd(b->mvmd_extract(fieldWidth, fieldCounts, 0), countSoFar, "countSoFar");
515            b->CreateAlignedStore(value, ptr, alignment);
516        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
517            PabloAST * stream = l->getExpression();
518            Value * index = nullptr;
519            if (LLVM_UNLIKELY(isa<Extract>(stream))) {               
520                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
521                stream = cast<Extract>(stream)->getArray();
522            } else {
523                index = b->getInt32(0);
524            }
525            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
526            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
527            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
528            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
529            if (bit_shift == 0) {  // Simple case with no intra-block shifting.
530                value = lookAhead;
531            } else { // Need to form shift result from two adjacent blocks.
532                Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
533                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr);
534                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
535                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
536                } else {
537                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
538                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
539                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
540                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
541                    value = b->CreateBitCast(result, b->getBitBlockType());
542                }
543            }
544        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
545            value = compileExpression(b, s->getValue());
546            Type * const ty = s->getType();
547            if (LLVM_LIKELY(ty->isVectorTy())) {
548                const auto fw = s->getFieldWidth()->value();
549                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
550                value = b->simd_fill(fw, value);
551            } else {
552                value = b->CreateZExtOrTrunc(value, ty);
553            }
554        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
555            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
556            const auto packWidth = p->getFieldWidth()->value();
557            assert (sourceWidth == packWidth);
558            Value * const base = compileExpression(b, p->getValue(), false);
559            const auto result_packs = sourceWidth/2;
560            if (LLVM_LIKELY(result_packs > 1)) {
561                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
562            }
563            Constant * const ZERO = b->getInt32(0);
564            for (unsigned i = 0; i < result_packs; ++i) {
565                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
566                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
567                Value * P = b->bitCast(b->hsimd_packh(packWidth, A, B));
568                if (LLVM_UNLIKELY(result_packs == 1)) {
569                    value = P;
570                    break;
571                }
572                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
573            }
574        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
575            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
576            const auto packWidth = p->getFieldWidth()->value();
577            assert (sourceWidth == packWidth);
578            Value * const base = compileExpression(b, p->getValue(), false);
579            const auto result_packs = sourceWidth/2;
580            if (LLVM_LIKELY(result_packs > 1)) {
581                value = b->CreateAlloca(ArrayType::get(b->getBitBlockType(), result_packs));
582            }
583            Constant * const ZERO = b->getInt32(0);
584            for (unsigned i = 0; i < result_packs; ++i) {
585                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
586                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
587                Value * P = b->bitCast(b->hsimd_packl(packWidth, A, B));
588                if (LLVM_UNLIKELY(result_packs == 1)) {
589                    value = P;
590                    break;
591                }
592                b->CreateStore(P, b->CreateGEP(value, {ZERO, b->getInt32(i)}));
593            }
594        } else {
595            std::string tmp;
596            raw_string_ostream out(tmp);
597            out << "PabloCompiler: ";
598            stmt->print(out);
599            out << " was not recognized by the compiler";
600            report_fatal_error(out.str());
601        }
602        assert (expr);
603        assert (value);
604        mMarker[expr] = value;
605        if (DebugOptionIsSet(DumpTrace)) {
606            std::string tmp;
607            raw_string_ostream name(tmp);
608            expr->print(name);
609            if (value->getType()->isVectorTy()) {
610                b->CallPrintRegister(name.str(), value);
611            } else if (value->getType()->isIntegerTy()) {
612                b->CallPrintInt(name.str(), value);
613            }
614        }
615    }
616}
617
618unsigned getIntegerBitWidth(const Type * ty) {
619    if (ty->isArrayTy()) {
620        assert (ty->getArrayNumElements() == 1);
621        ty = ty->getArrayElementType();
622    }
623    if (ty->isVectorTy()) {
624        assert (ty->getVectorNumElements() == 0);
625        ty = ty->getVectorElementType();
626    }
627    return ty->getIntegerBitWidth();
628}
629
630Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
631    const auto f = mMarker.find(expr);   
632    Value * value = nullptr;
633    if (LLVM_LIKELY(f != mMarker.end())) {
634        value = f->second;
635    } else {
636        if (isa<Integer>(expr)) {
637            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
638        } else if (isa<Zeroes>(expr)) {
639            value = b->allZeroes();
640        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
641            value = b->allOnes();
642        } else if (isa<Extract>(expr)) {
643            const Extract * const extract = cast<Extract>(expr);
644            const Var * const var = cast<Var>(extract->getArray());
645            Value * const index = compileExpression(b, extract->getIndex());
646            value = getPointerToVar(b, var, index);
647        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
648            const Var * const var = cast<Var>(expr);
649            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
650                value = b->getScalarFieldPtr(var->getName());
651            } else { // use before def error
652                std::string tmp;
653                raw_string_ostream out(tmp);
654                out << "PabloCompiler: ";
655                expr->print(out);
656                out << " is not a scalar value or was used before definition";
657                report_fatal_error(out.str());
658            }
659        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
660            const Operator * const op = cast<Operator>(expr);
661            const PabloAST * lh = op->getLH();
662            const PabloAST * rh = op->getRH();
663            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
664                const unsigned n = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
665                const unsigned m = b->getBitBlockWidth() / n;
666                IntegerType * const fw = b->getIntNTy(m);
667                VectorType * const vTy = VectorType::get(b->getIntNTy(n), m);
668
669                Value * baseLhv = nullptr;
670                Value * lhvStreamIndex = nullptr;
671                if (isa<Var>(lh)) {
672                    lhvStreamIndex = b->getInt32(0);
673                } else if (isa<Extract>(lh)) {
674                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
675                } else {
676                    baseLhv = compileExpression(b, lh);
677                }
678
679                Value * baseRhv = nullptr;
680                Value * rhvStreamIndex = nullptr;
681                if (isa<Var>(rh)) {
682                    rhvStreamIndex = b->getInt32(0);
683                } else if (isa<Extract>(lh)) {
684                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
685                } else {
686                    baseRhv = compileExpression(b, rh);
687                }
688
689                const TypeId typeId = op->getClassTypeId();
690
691                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
692
693                    value = b->CreateAlloca(vTy, b->getInt32(n));
694
695                    for (unsigned i = 0; i < n; ++i) {
696                        llvm::Constant * const index = b->getInt32(i);
697                        Value * lhv = nullptr;
698                        if (baseLhv) {
699                            lhv = baseLhv;
700                        } else {
701                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
702                            lhv = b->CreateBlockAlignedLoad(lhv);
703                        }
704                        lhv = b->CreateBitCast(lhv, vTy);
705
706                        Value * rhv = nullptr;
707                        if (baseRhv) {
708                            rhv = baseRhv;
709                        } else {
710                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
711                            rhv = b->CreateBlockAlignedLoad(rhv);
712                        }
713                        rhv = b->CreateBitCast(rhv, vTy);
714
715                        Value * result = nullptr;
716                        if (typeId == TypeId::Add) {
717                            result = b->CreateAdd(lhv, rhv);
718                        } else { // if (typeId == TypeId::Subtract) {
719                            result = b->CreateSub(lhv, rhv);
720                        }
721                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
722                    }
723
724                } else {
725
726                    value = UndefValue::get(VectorType::get(fw, n));
727
728                    for (unsigned i = 0; i < n; ++i) {
729                        llvm::Constant * const index = b->getInt32(i);
730                        Value * lhv = nullptr;
731                        if (baseLhv) {
732                            lhv = baseLhv;
733                        } else {
734                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
735                            lhv = b->CreateBlockAlignedLoad(lhv);
736                        }
737                        lhv = b->CreateBitCast(lhv, vTy);
738
739                        Value * rhv = nullptr;
740                        if (baseRhv) {
741                            rhv = baseRhv;
742                        } else {
743                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
744                            rhv = b->CreateBlockAlignedLoad(rhv);
745                        }
746                        rhv = b->CreateBitCast(rhv, vTy);
747
748                        Value * comp = nullptr;
749                        switch (typeId) {
750                            case TypeId::GreaterThanEquals:
751                            case TypeId::LessThan:
752                                comp = b->simd_ult(n, lhv, rhv);
753                                break;
754                            case TypeId::Equals:
755                            case TypeId::NotEquals:
756                                comp = b->simd_eq(n, lhv, rhv);
757                                break;
758                            case TypeId::LessThanEquals:
759                            case TypeId::GreaterThan:
760                                comp = b->simd_ugt(n, lhv, rhv);
761                                break;
762                            default: llvm_unreachable("invalid vector operator id");
763                        }
764                        Value * const mask = b->CreateZExtOrTrunc(b->hsimd_signmask(n, comp), fw);
765                        value = b->mvmd_insert(m, value, mask, i);
766                    }
767
768                    value = b->CreateBitCast(value, b->getBitBlockType());
769                    switch (typeId) {
770                        case TypeId::GreaterThanEquals:
771                        case TypeId::LessThanEquals:
772                        case TypeId::NotEquals:
773                            value = b->simd_not(value);
774                        default: break;
775                    }
776                }
777
778            } else {
779                Value * const lhv = compileExpression(b, lh);
780                Value * const rhv = compileExpression(b, rh);
781                switch (op->getClassTypeId()) {
782                    case TypeId::Add:
783                        value = b->CreateAdd(lhv, rhv); break;
784                    case TypeId::Subtract:
785                        value = b->CreateSub(lhv, rhv); break;
786                    case TypeId::LessThan:
787                        value = b->CreateICmpSLT(lhv, rhv); break;
788                    case TypeId::LessThanEquals:
789                        value = b->CreateICmpSLE(lhv, rhv); break;
790                    case TypeId::Equals:
791                        value = b->CreateICmpEQ(lhv, rhv); break;
792                    case TypeId::GreaterThanEquals:
793                        value = b->CreateICmpSGE(lhv, rhv); break;
794                    case TypeId::GreaterThan:
795                        value = b->CreateICmpSGT(lhv, rhv); break;
796                    case TypeId::NotEquals:
797                        value = b->CreateICmpNE(lhv, rhv); break;
798                    default: llvm_unreachable("invalid scalar operator id");
799                }
800            }
801        } else { // use before def error
802            std::string tmp;
803            raw_string_ostream out(tmp);
804            out << "PabloCompiler: ";
805            expr->print(out);
806            out << " was used before definition";
807            report_fatal_error(out.str());
808        }
809        assert (value);
810        // mMarker.insert({expr, value});
811    }
812    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
813        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
814    }
815    return value;
816}
817
818Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
819    assert (var && index1 && (index2 == nullptr || index1->getType() == index2->getType()));
820    if (LLVM_LIKELY(var->isKernelParameter())) {
821        if (LLVM_UNLIKELY(var->isScalar())) {
822            std::string tmp;
823            raw_string_ostream out(tmp);
824            out << mKernel->getName();
825            out << ": cannot index scalar value ";
826            var->print(out);
827            report_fatal_error(out.str());
828        } else if (var->isReadOnly()) {
829            if (index2) {
830                return b->getInputStreamPackPtr(var->getName(), index1, index2);
831            } else {
832                return b->getInputStreamBlockPtr(var->getName(), index1);
833            }
834        } else if (var->isReadNone()) {
835            if (index2) {
836                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
837            } else {
838                return b->getOutputStreamBlockPtr(var->getName(), index1);
839            }
840        } else {
841            std::string tmp;
842            raw_string_ostream out(tmp);
843            out << mKernel->getName();
844            out << ": stream ";
845            var->print(out);
846            out << " cannot be read from or written to";
847            report_fatal_error(out.str());
848        }
849    } else {
850        Value * const ptr = compileExpression(b, var, false);
851        std::vector<Value *> offsets;
852        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
853        offsets.push_back(index1);
854        if (index2) offsets.push_back(index2);
855        return b->CreateGEP(ptr, offsets);
856    }
857}
858
859PabloCompiler::PabloCompiler(PabloKernel * const kernel)
860: mKernel(kernel)
861, mCarryManager(new CarryManager)
862, mBranchCount(0) {
863    assert ("PabloKernel cannot be null!" && kernel);
864}
865
866PabloCompiler::~PabloCompiler() {
867    delete mCarryManager;
868}
869
870}
Note: See TracBrowser for help on using the repository browser.