source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 5828

Last change on this file since 5828 was 5828, checked in by nmedfort, 12 months ago

Pablo support for byte comparisions; LineFeed? kernel processes byte streams directly. Some clean up of PabloBuilder? functionality.

File size: 37.8 KB
Line 
1/*
2 *  Copyright (c) 2014-16 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "pablo_compiler.h"
8#include <pablo/pablo_kernel.h>
9#include <pablo/pablo_toolchain.h>
10#include <pablo/codegenstate.h>
11#include <pablo/boolean.h>
12#include <pablo/arithmetic.h>
13#include <pablo/branch.h>
14#include <pablo/pe_advance.h>
15#include <pablo/pe_lookahead.h>
16#include <pablo/pe_matchstar.h>
17#include <pablo/pe_scanthru.h>
18#include <pablo/pe_infile.h>
19#include <pablo/pe_count.h>
20#include <pablo/pe_integer.h>
21#include <pablo/pe_string.h>
22#include <pablo/pe_zeroes.h>
23#include <pablo/pe_ones.h>
24#include <pablo/pe_repeat.h>
25#include <pablo/pe_pack.h>
26#include <pablo/pe_var.h>
27#include <pablo/ps_assign.h>
28#ifdef USE_CARRYPACK_MANAGER
29#include <pablo/carrypack_manager.h>
30#else
31#include <pablo/carry_manager.h>
32#endif
33#include <kernels/kernel_builder.h>
34#include <kernels/streamset.h>
35#include <llvm/IR/Module.h>
36#include <llvm/Support/raw_os_ostream.h>
37
38using namespace llvm;
39
40namespace pablo {
41
42using TypeId = PabloAST::ClassTypeId;
43
44inline static unsigned getAlignment(const Value * const type) {
45    return type->getType()->getPrimitiveSizeInBits() / 8;
46}
47
48inline static unsigned getPointerElementAlignment(const Value * const ptr) {
49    return ptr->getType()->getPointerElementType()->getPrimitiveSizeInBits() / 8;
50}
51
52void PabloCompiler::initializeKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
53    mBranchCount = 0;
54    examineBlock(b, mKernel->getEntryBlock());
55    mCarryManager->initializeCarryData(b, mKernel);
56    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {
57        const auto count = (mBranchCount * 2) + 1;
58        mKernel->addScalar(ArrayType::get(mKernel->getSizeTy(), count), "profile");
59        mBasicBlock.reserve(count);
60    }
61}
62
63void PabloCompiler::releaseKernelData(const std::unique_ptr<kernel::KernelBuilder> & b) {
64    mCarryManager->releaseCarryData(b);
65}
66
67void PabloCompiler::clearCarryData(const std::unique_ptr<kernel::KernelBuilder> & b) {
68    mCarryManager->clearCarryData(b);
69}
70
71void PabloCompiler::compile(const std::unique_ptr<kernel::KernelBuilder> & b) {
72    mCarryManager->initializeCodeGen(b);
73    PabloBlock * const entryBlock = mKernel->getEntryBlock(); assert (entryBlock);
74    mMarker.emplace(entryBlock->createZeroes(), b->allZeroes());
75    mMarker.emplace(entryBlock->createOnes(), b->allOnes());
76    mBranchCount = 0;
77    addBranchCounter(b);
78    compileBlock(b, entryBlock);
79    mCarryManager->finalizeCodeGen(b);
80}
81
82void PabloCompiler::examineBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
83    for (const Statement * stmt : *block) {
84        if (LLVM_UNLIKELY(isa<Lookahead>(stmt))) {
85            const Lookahead * const la = cast<Lookahead>(stmt);
86            PabloAST * input = la->getExpression();
87            if (isa<Extract>(input)) {
88                input = cast<Extract>(input)->getArray();
89            }
90            bool notFound = true;
91            if (LLVM_LIKELY(isa<Var>(input))) {
92                for (unsigned i = 0; i < mKernel->getNumOfInputs(); ++i) {
93                    if (input == mKernel->getInput(i)) {
94                        const auto & binding = mKernel->getStreamInput(i);
95                        if (LLVM_UNLIKELY(!binding.hasLookahead() || binding.getLookahead() < la->getAmount())) {
96                            std::string tmp;
97                            raw_string_ostream out(tmp);
98                            input->print(out);
99                            out << " must have a lookahead attribute of at least " << la->getAmount();
100                            report_fatal_error(out.str());
101                        }
102                        notFound = false;
103                        break;
104                    }
105                }
106            }
107            if (LLVM_UNLIKELY(notFound)) {
108                report_fatal_error("Lookahead " + stmt->getName() + " can only be performed on an input streamset");
109            }
110        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
111            ++mBranchCount;
112            examineBlock(b, cast<Branch>(stmt)->getBody());
113        } else if (LLVM_UNLIKELY(isa<Count>(stmt))) {
114            mAccumulator.insert(std::make_pair(stmt, b->getInt32(mKernel->addUnnamedScalar(stmt->getType()))));
115        }
116    }   
117}
118
119void PabloCompiler::addBranchCounter(const std::unique_ptr<kernel::KernelBuilder> & b) {
120    if (CompileOptionIsSet(PabloCompilationFlags::EnableProfiling)) {       
121        Value * ptr = b->getScalarFieldPtr("profile");
122        assert (mBasicBlock.size() < ptr->getType()->getPointerElementType()->getArrayNumElements());
123        ptr = b->CreateGEP(ptr, {b->getInt32(0), b->getInt32(mBasicBlock.size())});
124        const auto alignment = getPointerElementAlignment(ptr);
125        Value * value = b->CreateAlignedLoad(ptr, alignment, false, "branchCounter");
126        value = b->CreateAdd(value, ConstantInt::get(cast<IntegerType>(value->getType()), 1));
127        b->CreateAlignedStore(value, ptr, alignment);
128        mBasicBlock.push_back(b->GetInsertBlock());
129    }
130}
131
132inline void PabloCompiler::compileBlock(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloBlock * const block) {
133    for (const Statement * statement : *block) {
134        compileStatement(b, statement);
135    }
136}
137
138void PabloCompiler::compileIf(const std::unique_ptr<kernel::KernelBuilder> & b, const If * const ifStatement) {
139    //
140    //  The If-ElseZero stmt:
141    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
142    //  If the value of the predicate is nonzero, then determine the values of variables
143    //  <var>* by executing the given statements.  Otherwise, the value of the
144    //  variables are all zero.  Requirements: (a) no variable that is defined within
145    //  the body of the if may be accessed outside unless it is explicitly
146    //  listed in the variable list, (b) every variable in the defined list receives
147    //  a value within the body, and (c) the logical consequence of executing
148    //  the statements in the event that the predicate is zero is that the
149    //  values of all defined variables indeed work out to be 0.
150    //
151    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
152    //  is inserted for each variable in the defined variable list.  It receives
153    //  a zero value from the ifentry block and the defined value from the if
154    //  body.
155    //
156
157    BasicBlock * const ifEntryBlock = b->GetInsertBlock();
158    ++mBranchCount;
159    BasicBlock * const ifBodyBlock = b->CreateBasicBlock("if.body_" + std::to_string(mBranchCount));
160    BasicBlock * const ifEndBlock = b->CreateBasicBlock("if.end_" + std::to_string(mBranchCount));
161   
162    std::vector<std::pair<const Var *, Value *>> incoming;
163
164    for (const Var * var : ifStatement->getEscaped()) {
165        if (LLVM_UNLIKELY(var->isKernelParameter())) {
166            Value * marker = nullptr;
167            if (var->isScalar()) {
168                marker = b->getScalarFieldPtr(var->getName());
169            } else if (var->isReadOnly()) {
170                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
171            } else if (var->isReadNone()) {
172                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
173            }
174            mMarker[var] = marker;
175        } else {
176            auto f = mMarker.find(var);
177            if (LLVM_UNLIKELY(f == mMarker.end())) {
178                std::string tmp;
179                raw_string_ostream out(tmp);
180                var->print(out);
181                out << " is uninitialized prior to entering ";
182                ifStatement->print(out);
183                report_fatal_error(out.str());
184            }
185            incoming.emplace_back(var, f->second);
186        }
187    }
188
189    const PabloBlock * ifBody = ifStatement->getBody();
190   
191    mCarryManager->enterIfScope(b, ifBody);
192
193    Value * condition = compileExpression(b, ifStatement->getCondition());
194    if (condition->getType() == b->getBitBlockType()) {
195        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
196    }
197   
198    b->CreateCondBr(condition, ifBodyBlock, ifEndBlock);
199   
200    // Entry processing is complete, now handle the body of the if.
201    b->SetInsertPoint(ifBodyBlock);
202
203    mCarryManager->enterIfBody(b, ifEntryBlock);
204
205    addBranchCounter(b);
206
207    compileBlock(b, ifBody);
208
209    mCarryManager->leaveIfBody(b, b->GetInsertBlock());
210
211    BasicBlock * ifExitBlock = b->GetInsertBlock();
212
213    b->CreateBr(ifEndBlock);
214
215    ifEndBlock->moveAfter(ifExitBlock);
216
217    //End Block
218    b->SetInsertPoint(ifEndBlock);
219
220    mCarryManager->leaveIfScope(b, ifEntryBlock, ifExitBlock);
221
222    for (const auto i : incoming) {
223        const Var * var; Value * incoming;
224        std::tie(var, incoming) = i;
225
226        auto f = mMarker.find(var);
227        if (LLVM_UNLIKELY(f == mMarker.end())) {
228            std::string tmp;
229            raw_string_ostream out(tmp);
230            out << "PHINode creation error: ";
231            var->print(out);
232            out << " was not assigned an outgoing value.";
233            report_fatal_error(out.str());
234        }
235
236        Value * const outgoing = f->second;
237        if (LLVM_UNLIKELY(incoming == outgoing)) {
238            continue;
239        }
240
241        if (LLVM_UNLIKELY(incoming->getType() != outgoing->getType())) {
242            std::string tmp;
243            raw_string_ostream out(tmp);
244            out << "PHINode creation error: incoming type of ";
245            var->print(out);
246            out << " (";
247            incoming->getType()->print(out);
248            out << ") differs from the outgoing type (";
249            outgoing->getType()->print(out);
250            out << ") within ";
251            ifStatement->print(out);
252            report_fatal_error(out.str());
253        }
254
255        PHINode * phi = b->CreatePHI(incoming->getType(), 2, var->getName());
256        phi->addIncoming(incoming, ifEntryBlock);
257        phi->addIncoming(outgoing, ifExitBlock);
258        f->second = phi;
259    }
260
261    addBranchCounter(b);
262}
263
264void PabloCompiler::compileWhile(const std::unique_ptr<kernel::KernelBuilder> & b, const While * const whileStatement) {
265
266    const PabloBlock * const whileBody = whileStatement->getBody();
267
268    BasicBlock * whileEntryBlock = b->GetInsertBlock();
269
270    const auto escaped = whileStatement->getEscaped();
271
272#ifdef ENABLE_BOUNDED_WHILE
273    PHINode * bound_phi = nullptr;  // Needed for bounded while loops.
274#endif
275    // On entry to the while structure, proceed to execute the first iteration
276    // of the loop body unconditionally. The while condition is tested at the end of
277    // the loop.
278
279    for (const Var * var : escaped) {
280        if (LLVM_UNLIKELY(var->isKernelParameter())) {
281            Value * marker = nullptr;
282            if (var->isScalar()) {
283                marker = b->getScalarFieldPtr(var->getName());
284            } else if (var->isReadOnly()) {
285                marker = b->getInputStreamBlockPtr(var->getName(), b->getInt32(0));
286            } else if (var->isReadNone()) {
287                marker = b->getOutputStreamBlockPtr(var->getName(), b->getInt32(0));
288            }
289            mMarker[var] = marker;
290        }
291    }
292
293    mCarryManager->enterLoopScope(b, whileBody);
294
295    BasicBlock * whileBodyBlock = b->CreateBasicBlock("while.body_" + std::to_string(mBranchCount));
296    BasicBlock * whileEndBlock = b->CreateBasicBlock("while.end_" + std::to_string(mBranchCount));
297    ++mBranchCount;
298
299    b->CreateBr(whileBodyBlock);
300
301    b->SetInsertPoint(whileBodyBlock);
302
303    //
304    // There are 3 sets of Phi nodes for the while loop.
305    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
306    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
307    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
308#ifdef ENABLE_BOUNDED_WHILE
309    // (4) The loop bound, if any.
310#endif
311
312    std::vector<std::pair<const Var *, PHINode *>> variants;
313
314    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
315    for (const auto var : escaped) {
316        auto f = mMarker.find(var);
317        if (LLVM_UNLIKELY(f == mMarker.end())) {
318            std::string tmp;
319            raw_string_ostream out(tmp);
320            out << "PHINode creation error: ";
321            var->print(out);
322            out << " is uninitialized prior to entering ";
323            whileStatement->print(out);
324            report_fatal_error(out.str());
325        }
326        Value * entryValue = f->second;
327        PHINode * phi = b->CreatePHI(entryValue->getType(), 2, var->getName());
328        phi->addIncoming(entryValue, whileEntryBlock);
329        f->second = phi;
330        assert(mMarker[var] == phi);
331        variants.emplace_back(var, phi);
332    }
333#ifdef ENABLE_BOUNDED_WHILE
334    if (whileStatement->getBound()) {
335        bound_phi = b->CreatePHI(b->getSizeTy(), 2, "while_bound");
336        bound_phi->addIncoming(b->getSize(whileStatement->getBound()), whileEntryBlock);
337    }
338#endif
339
340    mCarryManager->enterLoopBody(b, whileEntryBlock);
341
342    addBranchCounter(b);
343
344    compileBlock(b, whileBody);
345
346    // After the whileBody has been compiled, we may be in a different basic block.
347
348    mCarryManager->leaveLoopBody(b, b->GetInsertBlock());
349
350
351#ifdef ENABLE_BOUNDED_WHILE
352    if (whileStatement->getBound()) {
353        Value * new_bound = b->CreateSub(bound_phi, b->getSize(1));
354        bound_phi->addIncoming(new_bound, whileExitBlock);
355        condition = b->CreateAnd(condition, b->CreateICmpUGT(new_bound, ConstantInt::getNullValue(b->getSizeTy())));
356    }
357#endif
358
359    BasicBlock * const whileExitBlock = b->GetInsertBlock();
360
361    // and for any variant nodes in the loop body
362    for (const auto variant : variants) {
363        const Var * var; PHINode * incomingPhi;
364        std::tie(var, incomingPhi) = variant;
365        const auto f = mMarker.find(var);
366        if (LLVM_UNLIKELY(f == mMarker.end())) {
367            std::string tmp;
368            raw_string_ostream out(tmp);
369            out << "PHINode creation error: ";
370            var->print(out);
371            out << " is no longer assigned a value.";
372            report_fatal_error(out.str());
373        }
374
375        Value * const outgoingValue = f->second;
376
377        if (LLVM_UNLIKELY(incomingPhi->getType() != outgoingValue->getType())) {
378            std::string tmp;
379            raw_string_ostream out(tmp);
380            out << "PHINode creation error: incoming type of ";
381            var->print(out);
382            out << " (";
383            incomingPhi->getType()->print(out);
384            out << ") differs from the outgoing type (";
385            outgoingValue->getType()->print(out);
386            out << ") within ";
387            whileStatement->print(out);
388            report_fatal_error(out.str());
389        }
390
391        incomingPhi->addIncoming(outgoingValue, whileExitBlock);
392    }
393
394    // Terminate the while loop body with a conditional branch back.
395    Value * condition = compileExpression(b, whileStatement->getCondition());
396    if (condition->getType() == b->getBitBlockType()) {
397        condition = b->bitblock_any(mCarryManager->generateSummaryTest(b, condition));
398    }
399
400    b->CreateCondBr(condition, whileBodyBlock, whileEndBlock);
401
402    whileEndBlock->moveAfter(whileExitBlock);
403
404    b->SetInsertPoint(whileEndBlock);
405
406    mCarryManager->leaveLoopScope(b, whileEntryBlock, whileExitBlock);
407
408    addBranchCounter(b);
409}
410
411void PabloCompiler::compileStatement(const std::unique_ptr<kernel::KernelBuilder> & b, const Statement * const stmt) {
412
413    if (LLVM_UNLIKELY(isa<If>(stmt))) {
414        compileIf(b, cast<If>(stmt));
415    } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
416        compileWhile(b, cast<While>(stmt));
417    } else {
418        const PabloAST * expr = stmt;
419        Value * value = nullptr;
420        if (isa<And>(stmt)) {
421            value = compileExpression(b, stmt->getOperand(0));
422            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
423                value = b->simd_and(value, compileExpression(b, stmt->getOperand(1)));
424            }
425        } else if (isa<Or>(stmt)) {
426            value = compileExpression(b, stmt->getOperand(0));
427            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
428                value = b->simd_or(value, compileExpression(b, stmt->getOperand(1)));
429            }
430        } else if (isa<Xor>(stmt)) {
431            value = compileExpression(b, stmt->getOperand(0));
432            for (unsigned i = 1; i < stmt->getNumOperands(); ++i) {
433                value = b->simd_xor(value, compileExpression(b, stmt->getOperand(1)));
434            }
435        } else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
436            Value* ifMask = compileExpression(b, sel->getCondition());
437            Value* ifTrue = b->simd_and(ifMask, compileExpression(b, sel->getTrueExpr()));
438            Value* ifFalse = b->simd_and(b->simd_not(ifMask), compileExpression(b, sel->getFalseExpr()));
439            value = b->simd_or(ifTrue, ifFalse);
440        } else if (isa<Not>(stmt)) {
441            value = b->simd_not(compileExpression(b, stmt->getOperand(0)));
442        } else if (isa<Advance>(stmt)) {
443            const Advance * const adv = cast<Advance>(stmt);
444            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
445            // manager so that it properly selects the correct carry bit.
446            value = mCarryManager->advanceCarryInCarryOut(b, adv, compileExpression(b, adv->getExpression()));
447        } else if (isa<IndexedAdvance>(stmt)) {
448            const IndexedAdvance * const adv = cast<IndexedAdvance>(stmt);
449            Value * strm = compileExpression(b, adv->getExpression());
450            Value * index_strm = compileExpression(b, adv->getIndex());
451            // If our expr is an Extract op on a mutable Var then we need to pass the index value to the carry
452            // manager so that it properly selects the correct carry bit.
453            value = mCarryManager->indexedAdvanceCarryInCarryOut(b, adv, strm, index_strm);
454        } else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
455            Value * const marker = compileExpression(b, mstar->getMarker());
456            Value * const cc = compileExpression(b, mstar->getCharClass());
457            Value * const marker_and_cc = b->simd_and(marker, cc);
458            Value * const sum = mCarryManager->addCarryInCarryOut(b, mstar, marker_and_cc, cc);
459            value = b->simd_or(b->simd_xor(sum, cc), marker);
460        } else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
461            Value * const from = compileExpression(b, sthru->getScanFrom());
462            Value * const thru = compileExpression(b, sthru->getScanThru());
463            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, thru);
464            value = b->simd_and(sum, b->simd_not(thru));
465        } else if (const ScanTo * sthru = dyn_cast<ScanTo>(stmt)) {
466            Value * const marker_expr = compileExpression(b, sthru->getScanFrom());
467            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
468            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, marker_expr, b->simd_not(to));
469            value = b->simd_and(sum, to);
470        } else if (const AdvanceThenScanThru * sthru = dyn_cast<AdvanceThenScanThru>(stmt)) {
471            Value * const from = compileExpression(b, sthru->getScanFrom());
472            Value * const thru = compileExpression(b, sthru->getScanThru());
473            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, thru));
474            value = b->simd_and(sum, b->simd_not(thru));
475        } else if (const AdvanceThenScanTo * sthru = dyn_cast<AdvanceThenScanTo>(stmt)) {
476            Value * const from = compileExpression(b, sthru->getScanFrom());
477            Value * const to = b->simd_xor(compileExpression(b, sthru->getScanTo()), b->getScalarField("EOFmask"));
478            Value * const sum = mCarryManager->addCarryInCarryOut(b, sthru, from, b->simd_or(from, b->simd_not(to)));
479            value = b->simd_and(sum, to);
480        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
481            expr = cast<Assign>(stmt)->getVariable();
482            value = compileExpression(b, cast<Assign>(stmt)->getValue());
483            if (isa<Extract>(expr) || (isa<Var>(expr) && cast<Var>(expr)->isKernelParameter())) {
484                Value * const ptr = compileExpression(b, expr, false);
485                b->CreateAlignedStore(value, ptr, getAlignment(value));
486                value = ptr;
487            }
488        } else if (const InFile * e = dyn_cast<InFile>(stmt)) {
489            Value * EOFmask = b->getScalarField("EOFmask");
490            value = b->simd_and(compileExpression(b, e->getExpr()), b->simd_not(EOFmask));
491        } else if (const AtEOF * e = dyn_cast<AtEOF>(stmt)) {
492            Value * EOFbit = b->getScalarField("EOFbit");
493            value = b->simd_and(compileExpression(b, e->getExpr()), EOFbit);
494        } else if (const Count * c = dyn_cast<Count>(stmt)) {
495            Value * EOFbit = b->getScalarField("EOFbit");
496            Value * EOFmask = b->getScalarField("EOFmask");
497            Value * const to_count = b->simd_and(b->simd_or(b->simd_not(EOFmask), EOFbit), compileExpression(b, c->getExpr()));
498            const unsigned counterSize = b->getSizeTy()->getBitWidth();
499            const auto f = mAccumulator.find(c);
500            if (LLVM_UNLIKELY(f == mAccumulator.end())) {
501                report_fatal_error("Unknown accumulator: " + c->getName().str());
502            }
503            Value * ptr = b->getScalarFieldPtr(f->second);
504            const auto alignment = getPointerElementAlignment(ptr);
505            Value * countSoFar = b->CreateAlignedLoad(ptr, alignment, c->getName() + "_accumulator");
506            auto fields = (b->getBitBlockWidth() / counterSize);
507            Value * fieldCounts = b->simd_popcount(counterSize, to_count);
508            while (fields > 1) {
509                fields = fields/2;
510                fieldCounts = b->CreateAdd(fieldCounts, b->mvmd_srli(counterSize, fieldCounts, fields));
511            }
512            value = b->CreateAdd(b->mvmd_extract(counterSize, fieldCounts, 0), countSoFar, "countSoFar");
513            b->CreateAlignedStore(value, ptr, alignment);
514        } else if (const Lookahead * l = dyn_cast<Lookahead>(stmt)) {
515            PabloAST * stream = l->getExpression();
516            Value * index = nullptr;
517            if (LLVM_UNLIKELY(isa<Extract>(stream))) {               
518                index = compileExpression(b, cast<Extract>(stream)->getIndex(), true);
519                stream = cast<Extract>(stream)->getArray();
520            } else {
521                index = b->getInt32(0);
522            }
523            const auto bit_shift = (l->getAmount() % b->getBitBlockWidth());
524            const auto block_shift = (l->getAmount() / b->getBitBlockWidth());
525            Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift));
526            Value * lookAhead = b->CreateBlockAlignedLoad(ptr);
527            if (bit_shift == 0) {  // Simple case with no intra-block shifting.
528                value = lookAhead;
529            } else { // Need to form shift result from two adjacent blocks.
530                Value * ptr = b->getInputStreamBlockPtr(cast<Var>(stream)->getName(), index, b->getSize(block_shift + 1));
531                Value * lookAhead1 = b->CreateBlockAlignedLoad(ptr);
532                if (LLVM_UNLIKELY((bit_shift % 8) == 0)) { // Use a single whole-byte shift, if possible.
533                    value = b->mvmd_dslli(8, lookAhead1, lookAhead, (bit_shift / 8));
534                } else {
535                    Type  * const streamType = b->getIntNTy(b->getBitBlockWidth());
536                    Value * b1 = b->CreateBitCast(lookAhead1, streamType);
537                    Value * b0 = b->CreateBitCast(lookAhead, streamType);
538                    Value * result = b->CreateOr(b->CreateShl(b1, b->getBitBlockWidth() - bit_shift), b->CreateLShr(b0, bit_shift));
539                    value = b->CreateBitCast(result, b->getBitBlockType());
540                }
541            }
542        } else if (const Repeat * const s = dyn_cast<Repeat>(stmt)) {
543            value = compileExpression(b, s->getValue());
544            Type * const ty = s->getType();
545            if (LLVM_LIKELY(ty->isVectorTy())) {
546                const auto fw = s->getFieldWidth()->value();
547                value = b->CreateZExtOrTrunc(value, b->getIntNTy(fw));
548                value = b->simd_fill(fw, value);
549            } else {
550                value = b->CreateZExtOrTrunc(value, ty);
551            }
552        #if 0
553        } else if (const PackH * const p = dyn_cast<PackH>(stmt)) {
554            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
555            const auto packedWidth = p->getFieldWidth()->value();
556            Value * const base = compileExpression(b, p->getValue(), false);
557            const auto packs = sourceWidth / 2;
558            if (LLVM_LIKELY(packs > 1)) {
559                value = b->CreateAlloca(b->getBitBlockType(), b->getInt32(packs));
560            }
561            Constant * const ZERO = b->getInt32(0);
562            for (unsigned i = 0; i < packs; ++i) {
563                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
564                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
565                Value * P = b->hsimd_packh(packedWidth, A, B);
566                if (LLVM_UNLIKELY(packs == 1)) {
567                    value = P;
568                    break;
569                }
570                b->CreateStore(P, b->CreateGEP(value, b->getInt32(i)));
571            }
572        } else if (const PackL * const p = dyn_cast<PackL>(stmt)) {
573            const auto sourceWidth = p->getValue()->getType()->getVectorElementType()->getIntegerBitWidth();
574            const auto packedWidth = p->getFieldWidth()->value();
575            Value * const base = compileExpression(b, p->getValue(), false);
576            const auto count = sourceWidth / 2;
577            if (LLVM_LIKELY(count > 1)) {
578                value = b->CreateAlloca(b->getBitBlockType(), b->getInt32(count));
579            }
580            Constant * const ZERO = b->getInt32(0);
581            for (unsigned i = 0; i < count; ++i) {
582                Value * A = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2)}));
583                Value * B = b->CreateLoad(b->CreateGEP(base, {ZERO, b->getInt32(i * 2 + 1)}));
584                Value * P = b->hsimd_packl(packedWidth, A, B);
585                if (LLVM_UNLIKELY(count == 1)) {
586                    value = P;
587                    break;
588                }
589                b->CreateStore(P, b->CreateGEP(value, b->getInt32(i)));
590            }
591        #endif
592        } else {
593            std::string tmp;
594            raw_string_ostream out(tmp);
595            out << "PabloCompiler: ";
596            stmt->print(out);
597            out << " was not recognized by the compiler";
598            report_fatal_error(out.str());
599        }
600        assert (expr);
601        assert (value);
602        mMarker[expr] = value;
603        if (DebugOptionIsSet(DumpTrace)) {
604            std::string tmp;
605            raw_string_ostream name(tmp);
606            expr->print(name);
607            if (value->getType()->isVectorTy()) {
608                b->CallPrintRegister(name.str(), value);
609            } else if (value->getType()->isIntegerTy()) {
610                b->CallPrintInt(name.str(), value);
611            }
612        }
613    }
614}
615
616unsigned getIntegerBitWidth(const Type * ty) {
617    if (ty->isArrayTy()) {
618        assert (ty->getArrayNumElements() == 1);
619        ty = ty->getArrayElementType();
620    }
621    if (ty->isVectorTy()) {
622        assert (ty->getVectorNumElements() == 0);
623        ty = ty->getVectorElementType();
624    }
625    return ty->getIntegerBitWidth();
626}
627
628Value * PabloCompiler::compileExpression(const std::unique_ptr<kernel::KernelBuilder> & b, const PabloAST * const expr, const bool ensureLoaded) {
629    const auto f = mMarker.find(expr);   
630    Value * value = nullptr;
631    if (LLVM_LIKELY(f != mMarker.end())) {
632        value = f->second;
633    } else {
634        if (isa<Integer>(expr)) {
635            value = ConstantInt::get(cast<Integer>(expr)->getType(), cast<Integer>(expr)->value());
636        } else if (isa<Zeroes>(expr)) {
637            value = b->allZeroes();
638        } else if (LLVM_UNLIKELY(isa<Ones>(expr))) {
639            value = b->allOnes();
640        } else if (isa<Extract>(expr)) {
641            const Extract * const extract = cast<Extract>(expr);
642            const Var * const var = cast<Var>(extract->getArray());
643            Value * const index = compileExpression(b, extract->getIndex());
644            value = getPointerToVar(b, var, index);
645        } else if (LLVM_UNLIKELY(isa<Var>(expr))) {
646            const Var * const var = cast<Var>(expr);
647            if (LLVM_LIKELY(var->isKernelParameter() && var->isScalar())) {
648                value = b->getScalarFieldPtr(var->getName());
649            } else { // use before def error
650                std::string tmp;
651                raw_string_ostream out(tmp);
652                out << "PabloCompiler: ";
653                expr->print(out);
654                out << " is not a scalar value or was used before definition";
655                report_fatal_error(out.str());
656            }
657        } else if (LLVM_UNLIKELY(isa<Operator>(expr))) {
658            const Operator * const op = cast<Operator>(expr);
659            const PabloAST * lh = op->getLH();
660            const PabloAST * rh = op->getRH();
661            if ((isa<Var>(lh) || isa<Extract>(lh)) || (isa<Var>(rh) || isa<Extract>(rh))) {
662                const unsigned n = std::min(getIntegerBitWidth(lh->getType()), getIntegerBitWidth(rh->getType()));
663                const unsigned m = b->getBitBlockWidth() / n;
664                IntegerType * const fw = b->getIntNTy(m);
665                VectorType * const vTy = VectorType::get(b->getIntNTy(n), m);
666
667                Value * baseLhv = nullptr;
668                Value * lhvStreamIndex = nullptr;
669                if (isa<Var>(lh)) {
670                    lhvStreamIndex = b->getInt32(0);
671                } else if (isa<Extract>(lh)) {
672                    lhvStreamIndex = compileExpression(b, cast<Extract>(lh)->getIndex());
673                } else {
674                    baseLhv = compileExpression(b, lh, false);
675                }
676
677                Value * baseRhv = nullptr;
678                Value * rhvStreamIndex = nullptr;
679                if (isa<Var>(rh)) {
680                    rhvStreamIndex = b->getInt32(0);
681                } else if (isa<Extract>(lh)) {
682                    rhvStreamIndex = compileExpression(b, cast<Extract>(rh)->getIndex());
683                } else {
684                    baseRhv = compileExpression(b, rh, false);
685                }
686
687                const TypeId typeId = op->getClassTypeId();
688
689                if (LLVM_UNLIKELY(typeId == TypeId::Add || typeId == TypeId::Subtract)) {
690
691
692
693                    value = b->CreateAlloca(vTy, b->getInt32(n));
694
695                    for (unsigned i = 0; i < n; ++i) {
696                        llvm::Constant * const index = b->getInt32(i);
697                        Value * lhv = nullptr;
698                        if (baseLhv) {
699                            lhv = baseLhv;
700                        } else {
701                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
702                            lhv = b->CreateAlignedLoad(lhv, getAlignment(lhv));
703                        }
704                        lhv = b->CreateBitCast(lhv, vTy);
705
706                        Value * rhv = nullptr;
707                        if (baseRhv) {
708                            rhv = baseRhv;
709                        } else {
710                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
711                            rhv = b->CreateAlignedLoad(rhv, getAlignment(rhv));
712                        }
713                        rhv = b->CreateBitCast(rhv, vTy);
714
715                        Value * result = nullptr;
716                        if (typeId == TypeId::Add) {
717                            result = b->CreateAdd(lhv, rhv);
718                        } else {
719                            result = b->CreateSub(lhv, rhv);
720                        }
721                        b->CreateAlignedStore(result, b->CreateGEP(value, {b->getInt32(0), b->getInt32(i)}), getAlignment(result));
722                    }
723
724
725
726                } else {
727
728                    value = UndefValue::get(VectorType::get(fw, n));
729
730                    for (unsigned i = 0; i < n; ++i) {
731                        llvm::Constant * const index = b->getInt32(i);
732                        Value * lhv = nullptr;
733                        if (baseLhv) {
734                            lhv = baseLhv;
735                        } else {
736                            lhv = getPointerToVar(b, cast<Var>(lh), lhvStreamIndex, index);
737                            lhv = b->CreateAlignedLoad(lhv, getAlignment(lhv));
738                        }
739                        lhv = b->CreateBitCast(lhv, vTy);
740
741                        Value * rhv = nullptr;
742                        if (baseRhv) {
743                            rhv = baseRhv;
744                        } else {
745                            rhv = getPointerToVar(b, cast<Var>(rh), rhvStreamIndex, index);
746                            rhv = b->CreateAlignedLoad(rhv, getAlignment(rhv));
747                        }
748                        rhv = b->CreateBitCast(rhv, vTy);
749
750                        Value * comp = nullptr;
751                        switch (typeId) {
752                            case TypeId::GreaterThanEquals:
753                            case TypeId::LessThan:
754                                comp = b->simd_ult(n, lhv, rhv);
755                                break;
756                            case TypeId::Equals:
757                            case TypeId::NotEquals:
758                                comp = b->simd_eq(n, lhv, rhv);
759                                break;
760                            case TypeId::LessThanEquals:
761                            case TypeId::GreaterThan:
762                                comp = b->simd_ugt(n, lhv, rhv);
763                                break;
764                            default: llvm_unreachable("invalid vector operator id");
765                        }
766                        Value * const mask = b->CreateBitCast(b->hsimd_signmask(n, comp), fw);
767                        value = b->mvmd_insert(m, value, mask, i);
768                    }
769
770                    value = b->CreateBitCast(value, b->getBitBlockType());
771                    switch (typeId) {
772                        case TypeId::GreaterThanEquals:
773                        case TypeId::LessThanEquals:
774                        case TypeId::NotEquals:
775                            value = b->simd_not(value);
776                        default: break;
777                    }
778                }
779
780            } else {
781                Value * const lhv = compileExpression(b, lh);
782                Value * const rhv = compileExpression(b, rh);
783                switch (op->getClassTypeId()) {
784                    case TypeId::Add:
785                        value = b->CreateAdd(lhv, rhv); break;
786                    case TypeId::Subtract:
787                        value = b->CreateSub(lhv, rhv); break;
788                    case TypeId::LessThan:
789                        value = b->CreateICmpSLT(lhv, rhv); break;
790                    case TypeId::LessThanEquals:
791                        value = b->CreateICmpSLE(lhv, rhv); break;
792                    case TypeId::Equals:
793                        value = b->CreateICmpEQ(lhv, rhv); break;
794                    case TypeId::GreaterThanEquals:
795                        value = b->CreateICmpSGE(lhv, rhv); break;
796                    case TypeId::GreaterThan:
797                        value = b->CreateICmpSGT(lhv, rhv); break;
798                    case TypeId::NotEquals:
799                        value = b->CreateICmpNE(lhv, rhv); break;
800                    default: llvm_unreachable("invalid scalar operator id");
801                }
802            }
803        } else { // use before def error
804            std::string tmp;
805            raw_string_ostream out(tmp);
806            out << "PabloCompiler: ";
807            expr->print(out);
808            out << " was used before definition";
809            report_fatal_error(out.str());
810        }
811        assert (value);
812        // mMarker.insert({expr, value});
813    }
814    if (LLVM_UNLIKELY(value->getType()->isPointerTy() && ensureLoaded)) {
815        value = b->CreateAlignedLoad(value, getPointerElementAlignment(value));
816    }
817    return value;
818}
819
820Value * PabloCompiler::getPointerToVar(const std::unique_ptr<kernel::KernelBuilder> & b, const Var * var, Value * index1, Value * index2)  {
821    assert (var && index1 && (index2 == nullptr || index1->getType() == index2->getType()));
822    if (LLVM_LIKELY(var->isKernelParameter())) {
823        if (LLVM_UNLIKELY(var->isScalar())) {
824            std::string tmp;
825            raw_string_ostream out(tmp);
826            out << mKernel->getName();
827            out << ": cannot index scalar value ";
828            var->print(out);
829            report_fatal_error(out.str());
830        } else if (var->isReadOnly()) {
831            if (index2) {
832                return b->getInputStreamPackPtr(var->getName(), index1, index2);
833            } else {
834                return b->getInputStreamBlockPtr(var->getName(), index1);
835            }
836        } else if (var->isReadNone()) {
837            if (index2) {
838                return b->getOutputStreamPackPtr(var->getName(), index1, index2);
839            } else {
840                return b->getOutputStreamBlockPtr(var->getName(), index1);
841            }
842        } else {
843            std::string tmp;
844            raw_string_ostream out(tmp);
845            out << mKernel->getName();
846            out << ": stream ";
847            var->print(out);
848            out << " cannot be read from or written to";
849            report_fatal_error(out.str());
850        }
851    } else {
852        Value * const ptr = compileExpression(b, var, false);
853        std::vector<Value *> offsets;
854        offsets.push_back(ConstantInt::getNullValue(index1->getType()));
855        offsets.push_back(index1);
856        if (index2) offsets.push_back(index2);
857        return b->CreateGEP(ptr, offsets);
858    }
859}
860
861PabloCompiler::PabloCompiler(PabloKernel * const kernel)
862: mKernel(kernel)
863, mCarryManager(new CarryManager)
864, mBranchCount(0) {
865    assert ("PabloKernel cannot be null!" && kernel);
866}
867
868PabloCompiler::~PabloCompiler() {
869    delete mCarryManager;
870}
871
872}
Note: See TracBrowser for help on using the repository browser.