source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 5831

Last change on this file since 5831 was 5831, checked in by nmedfort, 16 months ago

Potential bug fix for 32-bit

File size: 37.9 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#ifdef USE_CARRYPACK_MANAGER
29#include <pablo/carrypack_manager.h>
30#else
31#include <pablo/carry_manager.h>
32#endif
33#include <kernels/kernel_builder.h>
34#include <kernels/streamset.h>
35#include <llvm/IR/Module.h>
36#include <llvm/Support/raw_os_ostream.h>
37
38using namespace llvm;
39
40namespace pablo {
41
42using TypeId = PabloAST::ClassTypeId;
43
44inline static unsigned getAlignment(const Type * const type) {
45    return type->getPrimitiveSizeInBits() / 8;
46}
47
48inline static unsigned getAlignment(const Value * const expr) {
49    return getAlignment(expr->getType());
50}
51
52inline static unsigned getPointerElementAlignment(const Value * const ptr) {
53    return getAlignment(ptr->getType()->getPointerElementType());
54}
55
56void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
57    mBranchCount = 0;
58    examineBlock(b, mKernel->getEntryBlock());
59    mCarryManager->initializeCarryData(b, mKernel);
60    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
61        const auto count = (mBranchCount * 2) + 1;
62        mKernel->addScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
63        mBasicBlock.reserve(count);
64    }
65}
66
67void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
68    mCarryManager->releaseCarryData(b);
69}
70
71void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
72    mCarryManager->clearCarryData(b);
73}
74
75void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
76    mCarryManager->initializeCodeGen(b);
77    PabloBlock * const entryBlock = mKernel->getEntryBlock(); assert (entryBlock);
78    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
79    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
80    mBranchCount = 0;
81    addBranchCounter(b);
82    compileBlock(b, entryBlock);
83    mCarryManager->finalizeCodeGen(b);
84}
85
86void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
87    for (const Statement * stmt : *block) {
88        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
89            const Lookahead * const la = cast<Lookahead>(stmt);
90            PabloAST * input = la->getExpression();
91            if (isa<Extract>(input)) {
92                input = cast<Extract>(input)->getArray();
93            }
94            bool notFound = true;
95            if (LLVM_LIKELY(isa<Var>(input))) {
96                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
97                    if (input == mKernel->getInput(i)) {
98                        const auto & binding = mKernel->getStreamInput(i);
99                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
100                            std::string tmp;
101                            raw_string_ostream out(tmp);
102                            input->print(out);
103                            out << " must have a lookahead attribute of at least " << la->getAmount();
104                            report_fatal_error(out.str());
105                        }
106                        notFound = false;
107                        break;
108                    }
109                }
110            }
111            if (LLVM_UNLIKELY(notFound)) {
112                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
113            }
114        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
115            ++mBranchCount;
116            examineBlock(b, cast<Branch>(stmt)->getBody());
117        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
118            mAccumulator.insert(std::make_pair(stmt, b->getInt32(mKernel->addUnnamedScalar(stmt->getType()))));
119        }
120    }   
121}
122
123void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
124    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {       
125        Value * ptr = b->getScalarFieldPtr("profile");
126        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
127        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
128        const auto alignment = getPointerElementAlignment(ptr);
129        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
130        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
131        b->CreateAlignedStore(value, ptr, alignment);
132        mBasicBlock.push_back(b->GetInsertBlock());
133    }
134}
135
136inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
137    for (const Statement * statement : *block) {
138        compileStatement(b, statement);
139    }
140}
141
142void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
143    //
144    //  The If-ElseZero stmt:
145    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
146    //  If the value of the predicate is nonzero, then determine the values of variables
147    //  <var>* by executing the given statements.  Otherwise, the value of the
148    //  variables are all zero.  Requirements: (a) no variable that is defined within
149    //  the body of the if may be accessed outside unless it is explicitly
150    //  listed in the variable list, (b) every variable in the defined list receives
151    //  a value within the body, and (c) the logical consequence of executing
152    //  the statements in the event that the predicate is zero is that the
153    //  values of all defined variables indeed work out to be 0.
154    //
155    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
156    //  is inserted for each variable in the defined variable list.  It receives
157    //  a zero value from the ifentry block and the defined value from the if
158    //  body.
159    //
160
161    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
162    ++mBranchCount;
163    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
164    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
165   
166    std::vector<std::pair<const Var *, Value *>> incoming;
167
168    for (const Var * var : ifStatement->getEscaped()) {
169        if (LLVM_UNLIKELY(var->isKernelParameter())) {
170            Value * marker = nullptr;
171            if (var->isScalar()) {
172                marker = b->getScalarFieldPtr(var->getName());
173            } else if (var->isReadOnly()) {
174                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
175            } else if (var->isReadNone()) {
176                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
177            }
178            mMarker[var] = marker;
179        } else {
180            auto f = mMarker.find(var);
181            if (LLVM_UNLIKELY(f == mMarker.end())) {
182                std::string tmp;
183                raw_string_ostream out(tmp);
184                var->print(out);
185                out << " is uninitialized prior to entering ";
186                ifStatement->print(out);
187                report_fatal_error(out.str());
188            }
189            incoming.emplace_back(var, f->second);
190        }
191    }
192
193    const PabloBlock * ifBody = ifStatement->getBody();
194   
195    mCarryManager->enterIfScope(b, ifBody);
196
197    Value * condition = compileExpression(b, ifStatement->getCondition());
198    if (condition->getType() == b->getBitBlockType()) {
199        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
200    }
201   
202    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
203   
204    // Entry processing is complete, now handle the body of the if.
205    b->SetInsertPoint(ifBodyBlock);
206
207    mCarryManager->enterIfBody(b, ifEntryBlock);
208
209    addBranchCounter(b);
210
211    compileBlock(b, ifBody);
212
213    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
214
215    BasicBlock * ifExitBlock = b->GetInsertBlock();
216
217    b->CreateBr(ifEndBlock);
218
219    ifEndBlock->moveAfter(ifExitBlock);
220
221    //End Block
222    b->SetInsertPoint(ifEndBlock);
223
224    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
225
226    for (const auto i : incoming) {
227        const Var * var; Value * incoming;
228        std::tie(var, incoming) = i;
229
230        auto f = mMarker.find(var);
231        if (LLVM_UNLIKELY(f == mMarker.end())) {
232            std::string tmp;
233            raw_string_ostream out(tmp);
234            out << "PHINode creation error: ";
235            var->print(out);
236            out << " was not assigned an outgoing value.";
237            report_fatal_error(out.str());
238        }
239
240        Value * const outgoing = f->second;
241        if (LLVM_UNLIKELY(incoming == outgoing)) {
242            continue;
243        }
244
245        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
246            std::string tmp;
247            raw_string_ostream out(tmp);
248            out << "PHINode creation error: incoming type of ";
249            var->print(out);
250            out << " (";
251            incoming->getType()->print(out);
252            out << ") differs from the outgoing type (";
253            outgoing->getType()->print(out);
254            out << ") within ";
255            ifStatement->print(out);
256            report_fatal_error(out.str());
257        }
258
259        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
260        phi->addIncoming(incoming, ifEntryBlock);
261        phi->addIncoming(outgoing, ifExitBlock);
262        f->second = phi;
263    }
264
265    addBranchCounter(b);
266}
267
268void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
269
270    const PabloBlock * const whileBody = whileStatement->getBody();
271
272    BasicBlock * whileEntryBlock = b->GetInsertBlock();
273
274    const auto escaped = whileStatement->getEscaped();
275
276#ifdef ENABLE_BOUNDED_WHILE
277    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
278#endif
279    // On entry to the while structure, proceed to execute the first iteration
280    // of the loop body unconditionally. The while condition is tested at the end of
281    // the loop.
282
283    for (const Var * var : escaped) {
284        if (LLVM_UNLIKELY(var->isKernelParameter())) {
285            Value * marker = nullptr;
286            if (var->isScalar()) {
287                marker = b->getScalarFieldPtr(var->getName());
288            } else if (var->isReadOnly()) {
289                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
290            } else if (var->isReadNone()) {
291                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
292            }
293            mMarker[var] = marker;
294        }
295    }
296
297    mCarryManager->enterLoopScope(b, whileBody);
298
299    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
300    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
301    ++mBranchCount;
302
303    b->CreateBr(whileBodyBlock);
304
305    b->SetInsertPoint(whileBodyBlock);
306
307    //
308    // There are 3 sets of Phi nodes for the while loop.
309    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
310    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
311    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
312#ifdef ENABLE_BOUNDED_WHILE
313    // (4) The loop bound, if any.
314#endif
315
316    std::vector<std::pair<const Var *, PHINode *>> variants;
317
318    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
319    for (const auto var : escaped) {
320        auto f = mMarker.find(var);
321        if (LLVM_UNLIKELY(f == mMarker.end())) {
322            std::string tmp;
323            raw_string_ostream out(tmp);
324            out << "PHINode creation error: ";
325            var->print(out);
326            out << " is uninitialized prior to entering ";
327            whileStatement->print(out);
328            report_fatal_error(out.str());
329        }
330        Value * entryValue = f->second;
331        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
332        phi->addIncoming(entryValue, whileEntryBlock);
333        f->second = phi;
334        assert(mMarker[var] == phi);
335        variants.emplace_back(var, phi);
336    }
337#ifdef ENABLE_BOUNDED_WHILE
338    if (whileStatement->getBound()) {
339        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
340        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
341    }
342#endif
343
344    mCarryManager->enterLoopBody(b, whileEntryBlock);
345
346    addBranchCounter(b);
347
348    compileBlock(b, whileBody);
349
350    // After the whileBody has been compiled, we may be in a different basic block.
351
352    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
353
354
355#ifdef ENABLE_BOUNDED_WHILE
356    if (whileStatement->getBound()) {
357        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
358        bound_phi->addIncoming(new_bound, whileExitBlock);
359        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
360    }
361#endif
362
363    BasicBlock * const whileExitBlock = b->GetInsertBlock();
364
365    // and for any variant nodes in the loop body
366    for (const auto variant : variants) {
367        const Var * var; PHINode * incomingPhi;
368        std::tie(var, incomingPhi) = variant;
369        const auto f = mMarker.find(var);
370        if (LLVM_UNLIKELY(f == mMarker.end())) {
371            std::string tmp;
372            raw_string_ostream out(tmp);
373            out << "PHINode creation error: ";
374            var->print(out);
375            out << " is no longer assigned a value.";
376            report_fatal_error(out.str());
377        }
378
379        Value * const outgoingValue = f->second;
380
381        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
382            std::string tmp;
383            raw_string_ostream out(tmp);
384            out << "PHINode creation error: incoming type of ";
385            var->print(out);
386            out << " (";
387            incomingPhi->getType()->print(out);
388            out << ") differs from the outgoing type (";
389            outgoingValue->getType()->print(out);
390            out << ") within ";
391            whileStatement->print(out);
392            report_fatal_error(out.str());
393        }
394
395        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
396    }
397
398    // Terminate the while loop body with a conditional branch back.
399    Value * condition = compileExpression(b, whileStatement->getCondition());
400    if (condition->getType() == b->getBitBlockType()) {
401        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
402    }
403
404    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
405
406    whileEndBlock->moveAfter(whileExitBlock);
407
408    b->SetInsertPoint(whileEndBlock);
409
410    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
411
412    addBranchCounter(b);
413}
414
415void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
416
417    if (LLVM_UNLIKELY(isa<If>(stmt))) {
418        compileIf(b, cast<If>(stmt));
419    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
420        compileWhile(b, cast<While>(stmt));
421    } else {
422        const PabloAST * expr = stmt;
423        Value * value = nullptr;
424        if (isa<And>(stmt)) {
425            value = compileExpression(b, stmt->getOperand(0));
426            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
427                value = b->simd_and(value, compileExpression(b, stmt->getOperand(1)));
428            }
429        } else if (isa<Or>(stmt)) {
430            value = compileExpression(b, stmt->getOperand(0));
431            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
432                value = b->simd_or(value, compileExpression(b, stmt->getOperand(1)));
433            }
434        } else if (isa<Xor>(stmt)) {
435            value = compileExpression(b, stmt->getOperand(0));
436            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
437                value = b->simd_xor(value, compileExpression(b, stmt->getOperand(1)));
438            }
439        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
440            Value* ifMask = compileExpression(b, sel->getCondition());
441            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
442            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
443            value = b->simd_or(ifTrue, ifFalse);
444        } else if (isa<Not>(stmt)) {
445            value = b->simd_not(compileExpression(b, stmt->getOperand(0)));
446        } else if (isa<Advance>(stmt)) {
447            const Advance * const adv = cast<Advance>(stmt);
448            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
449            // manager so that it properly selects the correct carry bit.
450            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
451        } else if (isa<IndexedAdvance>(stmt)) {
452            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
453            Value * strm = compileExpression(b, adv->getExpression());
454            Value * index_strm = compileExpression(b, adv->getIndex());
455            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
456            // manager so that it properly selects the correct carry bit.
457            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
458        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
459            Value * const marker = compileExpression(b, mstar->getMarker());
460            Value * const cc = compileExpression(b, mstar->getCharClass());
461            Value * const marker_and_cc = b->simd_and(marker, cc);
462            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
463            value = b->simd_or(b->simd_xor(sum, cc), marker);
464        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
465            Value * const from = compileExpression(b, sthru->getScanFrom());
466            Value * const thru = compileExpression(b, sthru->getScanThru());
467            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
468            value = b->simd_and(sum, b->simd_not(thru));
469        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
470            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
471            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
472            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
473            value = b->simd_and(sum, to);
474        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
475            Value * const from = compileExpression(b, sthru->getScanFrom());
476            Value * const thru = compileExpression(b, sthru->getScanThru());
477            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
478            value = b->simd_and(sum, b->simd_not(thru));
479        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
480            Value * const from = compileExpression(b, sthru->getScanFrom());
481            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
482            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
483            value = b->simd_and(sum, to);
484        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
485            expr = cast<Assign>(stmt)->getVariable();
486            value = compileExpression(b, cast<Assign>(stmt)->getValue());
487            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
488                Value * const ptr = compileExpression(b, expr, false);
489                b->CreateAlignedStore(value, ptr, getAlignment(value));
490                value = ptr;
491            }
492        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
493            Value * EOFmask = b->getScalarField("EOFmask");
494            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
495        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
496            Value * EOFbit = b->getScalarField("EOFbit");
497            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
498        } else if (const Count * c = dyn_cast<Count>(stmt)) {
499            Value * EOFbit = b->getScalarField("EOFbit");
500            Value * EOFmask = b->getScalarField("EOFmask");
501            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));
502            const unsigned counterSize = b->getSizeTy()->getBitWidth();
503            const auto f = mAccumulator.find(c);
504            if (LLVM_UNLIKELY(f == mAccumulator.end())) {
505                report_fatal_error("Unknown accumulator: " + c->getName().str());
506            }
507            Value * ptr = b->getScalarFieldPtr(f->second);
508            const auto alignment = getPointerElementAlignment(ptr);
509            Value * countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
510            auto fields = (b->getBitBlockWidth() / counterSize);
511            Value * fieldCounts = b->simd_popcount(counterSize, to_count);
512            while (fields > 1) {
513                fields = fields/2;
514                fieldCounts = b->CreateAdd(fieldCounts, b->mvmd_srli(counterSize, fieldCounts, fields));
515            }
516            value = b->CreateAdd(b->mvmd_extract(counterSize, fieldCounts, 0), countSoFar, "countSoFar");
517            b->CreateAlignedStore(value, ptr, alignment);
518        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
519            PabloAST * stream = l->getExpression();
520            Value * index = nullptr;
521            if (LLVM_UNLIKELY(isa<Extract>(stream))) {               
522                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
523                stream = cast<Extract>(stream)->getArray();
524            } else {
525                index = b->getInt32(0);
526            }
527            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
528            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
529            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
530            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
531            if (bit_shift == 0) {  // Simple case with no intra-block shifting.
532                value = lookAhead;
533            } else { // Need to form shift result from two adjacent blocks.
534                Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
535                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr);
536                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
537                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
538                } else {
539                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
540                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
541                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
542                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
543                    value = b->CreateBitCast(result, b->getBitBlockType());
544                }
545            }
546        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
547            value = compileExpression(b, s->getValue());
548            Type * const ty = s->getType();
549            if (LLVM_LIKELY(ty->isVectorTy())) {
550                const auto fw = s->getFieldWidth()->value();
551                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
552                value = b->simd_fill(fw, value);
553            } else {
554                value = b->CreateZExtOrTrunc(value, ty);
555            }
556        #if 0
557        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
558            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
559            const auto packedWidth = p->getFieldWidth()->value();
560            Value * const base = compileExpression(b, p->getValue(), false);
561            const auto packs = sourceWidth / 2;
562            if (LLVM_LIKELY(packs > 1)) {
563                value = b->CreateAlloca(b->getBitBlockType(), b->getInt32(packs));
564            }
565            Constant * const ZERO = b->getInt32(0);
566            for (unsigned i = 0; i < packs; ++i) {
567                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
568                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
569                Value * P = b->hsimd_packh(packedWidth, A, B);
570                if (LLVM_UNLIKELY(packs == 1)) {
571                    value = P;
572                    break;
573                }
574                b->CreateStore(P, b->CreateGEP(value, b->getInt32(i)));
575            }
576        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
577            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
578            const auto packedWidth = p->getFieldWidth()->value();
579            Value * const base = compileExpression(b, p->getValue(), false);
580            const auto count = sourceWidth / 2;
581            if (LLVM_LIKELY(count > 1)) {
582                value = b->CreateAlloca(b->getBitBlockType(), b->getInt32(count));
583            }
584            Constant * const ZERO = b->getInt32(0);
585            for (unsigned i = 0; i < count; ++i) {
586                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
587                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
588                Value * P = b->hsimd_packl(packedWidth, A, B);
589                if (LLVM_UNLIKELY(count == 1)) {
590                    value = P;
591                    break;
592                }
593                b->CreateStore(P, b->CreateGEP(value, b->getInt32(i)));
594            }
595        #endif
596        } else {
597            std::string tmp;
598            raw_string_ostream out(tmp);
599            out << "PabloCompiler: ";
600            stmt->print(out);
601            out << " was not recognized by the compiler";
602            report_fatal_error(out.str());
603        }
604        assert (expr);
605        assert (value);
606        mMarker[expr] = value;
607        if (DebugOptionIsSet(DumpTrace)) {
608            std::string tmp;
609            raw_string_ostream name(tmp);
610            expr->print(name);
611            if (value->getType()->isVectorTy()) {
612                b->CallPrintRegister(name.str(), value);
613            } else if (value->getType()->isIntegerTy()) {
614                b->CallPrintInt(name.str(), value);
615            }
616        }
617    }
618}
619
620unsigned getIntegerBitWidth(const Type * ty) {
621    if (ty->isArrayTy()) {
622        assert (ty->getArrayNumElements() == 1);
623        ty = ty->getArrayElementType();
624    }
625    if (ty->isVectorTy()) {
626        assert (ty->getVectorNumElements() == 0);
627        ty = ty->getVectorElementType();
628    }
629    return ty->getIntegerBitWidth();
630}
631
632Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
633    const auto f = mMarker.find(expr);   
634    Value * value = nullptr;
635    if (LLVM_LIKELY(f != mMarker.end())) {
636        value = f->second;
637    } else {
638        if (isa<Integer>(expr)) {
639            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
640        } else if (isa<Zeroes>(expr)) {
641            value = b->allZeroes();
642        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
643            value = b->allOnes();
644        } else if (isa<Extract>(expr)) {
645            const Extract * const extract = cast<Extract>(expr);
646            const Var * const var = cast<Var>(extract->getArray());
647            Value * const index = compileExpression(b, extract->getIndex());
648            value = getPointerToVar(b, var, index);
649        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
650            const Var * const var = cast<Var>(expr);
651            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
652                value = b->getScalarFieldPtr(var->getName());
653            } else { // use before def error
654                std::string tmp;
655                raw_string_ostream out(tmp);
656                out << "PabloCompiler: ";
657                expr->print(out);
658                out << " is not a scalar value or was used before definition";
659                report_fatal_error(out.str());
660            }
661        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
662            const Operator * const op = cast<Operator>(expr);
663            const PabloAST * lh = op->getLH();
664            const PabloAST * rh = op->getRH();
665            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
666                const unsigned n = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
667                const unsigned m = b->getBitBlockWidth() / n;
668                IntegerType * const fw = b->getIntNTy(m);
669                VectorType * const vTy = VectorType::get(b->getIntNTy(n), m);
670
671                Value * baseLhv = nullptr;
672                Value * lhvStreamIndex = nullptr;
673                if (isa<Var>(lh)) {
674                    lhvStreamIndex = b->getInt32(0);
675                } else if (isa<Extract>(lh)) {
676                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
677                } else {
678                    baseLhv = compileExpression(b, lh);
679                }
680
681                Value * baseRhv = nullptr;
682                Value * rhvStreamIndex = nullptr;
683                if (isa<Var>(rh)) {
684                    rhvStreamIndex = b->getInt32(0);
685                } else if (isa<Extract>(lh)) {
686                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
687                } else {
688                    baseRhv = compileExpression(b, rh);
689                }
690
691                const TypeId typeId = op->getClassTypeId();
692
693                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
694
695                    value = b->CreateAlloca(vTy, b->getInt32(n));
696
697                    for (unsigned i = 0; i < n; ++i) {
698                        llvm::Constant * const index = b->getInt32(i);
699                        Value * lhv = nullptr;
700                        if (baseLhv) {
701                            lhv = baseLhv;
702                        } else {
703                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
704                            lhv = b->CreateBlockAlignedLoad(lhv);
705                        }
706                        lhv = b->CreateBitCast(lhv, vTy);
707
708                        Value * rhv = nullptr;
709                        if (baseRhv) {
710                            rhv = baseRhv;
711                        } else {
712                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
713                            rhv = b->CreateBlockAlignedLoad(rhv);
714                        }
715                        rhv = b->CreateBitCast(rhv, vTy);
716
717                        Value * result = nullptr;
718                        if (typeId == TypeId::Add) {
719                            result = b->CreateAdd(lhv, rhv);
720                        } else { // if (typeId == TypeId::Subtract) {
721                            result = b->CreateSub(lhv, rhv);
722                        }
723                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
724                    }
725
726                } else {
727
728                    value = UndefValue::get(VectorType::get(fw, n));
729
730                    for (unsigned i = 0; i < n; ++i) {
731                        llvm::Constant * const index = b->getInt32(i);
732                        Value * lhv = nullptr;
733                        if (baseLhv) {
734                            lhv = baseLhv;
735                        } else {
736                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
737                            lhv = b->CreateBlockAlignedLoad(lhv);
738                        }
739                        lhv = b->CreateBitCast(lhv, vTy);
740
741                        Value * rhv = nullptr;
742                        if (baseRhv) {
743                            rhv = baseRhv;
744                        } else {
745                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
746                            rhv = b->CreateBlockAlignedLoad(rhv);
747                        }
748                        rhv = b->CreateBitCast(rhv, vTy);
749
750                        Value * comp = nullptr;
751                        switch (typeId) {
752                            case TypeId::GreaterThanEquals:
753                            case TypeId::LessThan:
754                                comp = b->simd_ult(n, lhv, rhv);
755                                break;
756                            case TypeId::Equals:
757                            case TypeId::NotEquals:
758                                comp = b->simd_eq(n, lhv, rhv);
759                                break;
760                            case TypeId::LessThanEquals:
761                            case TypeId::GreaterThan:
762                                comp = b->simd_ugt(n, lhv, rhv);
763                                break;
764                            default: llvm_unreachable("invalid vector operator id");
765                        }
766                        Value * const mask = b->CreateZExtOrTrunc(b->hsimd_signmask(n, comp), fw);
767                        value = b->mvmd_insert(m, value, mask, i);
768                    }
769
770                    value = b->CreateBitCast(value, b->getBitBlockType());
771                    switch (typeId) {
772                        case TypeId::GreaterThanEquals:
773                        case TypeId::LessThanEquals:
774                        case TypeId::NotEquals:
775                            value = b->simd_not(value);
776                        default: break;
777                    }
778                }
779
780            } else {
781                Value * const lhv = compileExpression(b, lh);
782                Value * const rhv = compileExpression(b, rh);
783                switch (op->getClassTypeId()) {
784                    case TypeId::Add:
785                        value = b->CreateAdd(lhv, rhv); break;
786                    case TypeId::Subtract:
787                        value = b->CreateSub(lhv, rhv); break;
788                    case TypeId::LessThan:
789                        value = b->CreateICmpSLT(lhv, rhv); break;
790                    case TypeId::LessThanEquals:
791                        value = b->CreateICmpSLE(lhv, rhv); break;
792                    case TypeId::Equals:
793                        value = b->CreateICmpEQ(lhv, rhv); break;
794                    case TypeId::GreaterThanEquals:
795                        value = b->CreateICmpSGE(lhv, rhv); break;
796                    case TypeId::GreaterThan:
797                        value = b->CreateICmpSGT(lhv, rhv); break;
798                    case TypeId::NotEquals:
799                        value = b->CreateICmpNE(lhv, rhv); break;
800                    default: llvm_unreachable("invalid scalar operator id");
801                }
802            }
803        } else { // use before def error
804            std::string tmp;
805            raw_string_ostream out(tmp);
806            out << "PabloCompiler: ";
807            expr->print(out);
808            out << " was used before definition";
809            report_fatal_error(out.str());
810        }
811        assert (value);
812        // mMarker.insert({expr, value});
813    }
814    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
815        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
816    }
817    return value;
818}
819
820Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
821    assert (var && index1 && (index2 == nullptr || index1->getType() == index2->getType()));
822    if (LLVM_LIKELY(var->isKernelParameter())) {
823        if (LLVM_UNLIKELY(var->isScalar())) {
824            std::string tmp;
825            raw_string_ostream out(tmp);
826            out << mKernel->getName();
827            out << ": cannot index scalar value ";
828            var->print(out);
829            report_fatal_error(out.str());
830        } else if (var->isReadOnly()) {
831            if (index2) {
832                return b->getInputStreamPackPtr(var->getName(), index1, index2);
833            } else {
834                return b->getInputStreamBlockPtr(var->getName(), index1);
835            }
836        } else if (var->isReadNone()) {
837            if (index2) {
838                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
839            } else {
840                return b->getOutputStreamBlockPtr(var->getName(), index1);
841            }
842        } else {
843            std::string tmp;
844            raw_string_ostream out(tmp);
845            out << mKernel->getName();
846            out << ": stream ";
847            var->print(out);
848            out << " cannot be read from or written to";
849            report_fatal_error(out.str());
850        }
851    } else {
852        Value * const ptr = compileExpression(b, var, false);
853        std::vector<Value *> offsets;
854        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
855        offsets.push_back(index1);
856        if (index2) offsets.push_back(index2);
857        return b->CreateGEP(ptr, offsets);
858    }
859}
860
861PabloCompiler::PabloCompiler(PabloKernel * const kernel)
862: mKernel(kernel)
863, mCarryManager(new CarryManager)
864, mBranchCount(0) {
865    assert ("PabloKernel cannot be null!" && kernel);
866}
867
868PabloCompiler::~PabloCompiler() {
869    delete mCarryManager;
870}
871
872}
Note: See TracBrowser for help on using the repository browser.