source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_simplifier.cpp @ 5536

Last change on this file since 5536 was 5536, checked in by nmedfort, 22 months ago

Flatten If branches when the condition is trivially non zero

File size: 20.8 KB
Line 
1#include <pablo/optimizers/pablo_simplifier.hpp>
2#include <pablo/pablo_kernel.h>
3#include <pablo/codegenstate.h>
4#include <pablo/expression_map.hpp>
5#include <pablo/boolean.h>
6#include <pablo/pe_zeroes.h>
7#include <pablo/pe_ones.h>
8#include <pablo/arithmetic.h>
9#include <pablo/branch.h>
10#include <pablo/ps_assign.h>
11#include <pablo/pe_advance.h>
12#include <pablo/pe_scanthru.h>
13#include <pablo/pe_matchstar.h>
14#include <pablo/pe_var.h>
15#include <boost/container/flat_set.hpp>
16#ifndef NDEBUG
17#include <pablo/analysis/pabloverifier.hpp>
18#endif
19#include <llvm/Support/raw_ostream.h>
20
21
22using namespace boost;
23using namespace boost::container;
24using namespace llvm;
25
26namespace pablo {
27
28using TypeId = PabloAST::ClassTypeId;
29
30/** ------------------------------------------------------------------------------------------------------------- *
31 * @brief fold
32 ** ------------------------------------------------------------------------------------------------------------- */
33PabloAST * triviallyFold(Statement * stmt, PabloBlock * const block) {
34    if (isa<Not>(stmt)) {
35        PabloAST * value = stmt->getOperand(0);
36        if (LLVM_UNLIKELY(isa<Not>(value))) {
37            return cast<Not>(value)->getOperand(0); // ¬¬A ⇔ A
38        } else if (LLVM_UNLIKELY(isa<Zeroes>(value))) {
39            return block->createOnes(stmt->getType()); // ¬0 ⇔ 1
40        }  else if (LLVM_UNLIKELY(isa<Ones>(value))) {
41            return block->createZeroes(stmt->getType()); // ¬1 ⇔ 0
42        }
43    } else if (isa<Advance>(stmt)) {
44        if (LLVM_UNLIKELY(isa<Zeroes>(stmt->getOperand(0)))) {
45            return block->createZeroes(stmt->getType());
46        }
47    } else if (isa<Add>(stmt) || isa<Subtract>(stmt)) {
48       if (LLVM_UNLIKELY(isa<Integer>(stmt->getOperand(0)) && isa<Integer>(stmt->getOperand(1)))) {
49           const Integer * const int0 = cast<Integer>(stmt->getOperand(0));
50           const Integer * const int1 = cast<Integer>(stmt->getOperand(1));
51           Integer::IntTy result = 0;
52           if (isa<Add>(stmt)) {
53               result = int0->value() + int1->value();
54           } else {
55               result = int0->value() - int1->value();
56           }
57           return block->getInteger(result);
58       }
59    } else {
60        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
61            if (LLVM_UNLIKELY(isa<Zeroes>(stmt->getOperand(i)))) {
62                switch (stmt->getClassTypeId()) {
63                    case TypeId::Sel:
64                        block->setInsertPoint(stmt->getPrevNode());
65                        switch (i) {
66                            case 0: return stmt->getOperand(2);
67                            case 1: return block->createAnd(block->createNot(stmt->getOperand(0)), stmt->getOperand(2));
68                            case 2: return block->createAnd(stmt->getOperand(0), stmt->getOperand(1));
69                        }
70                    case TypeId::ScanThru:
71                    case TypeId::MatchStar:
72                        return stmt->getOperand(0);
73                    default: break;
74                }
75            } else if (LLVM_UNLIKELY(isa<Ones>(stmt->getOperand(i)))) {
76                switch (stmt->getClassTypeId()) {
77                    case TypeId::Sel:
78                        block->setInsertPoint(stmt->getPrevNode());
79                        switch (i) {
80                            case 0: return stmt->getOperand(1);
81                            case 1: return block->createOr(stmt->getOperand(0), stmt->getOperand(2));
82                            case 2: return block->createOr(block->createNot(stmt->getOperand(0)), stmt->getOperand(1));
83                        }
84                    case TypeId::ScanThru:
85                        if (LLVM_UNLIKELY(i == 1)) {
86                            return block->createZeroes(stmt->getType());
87                        }
88                        break;
89                    case TypeId::MatchStar:
90                        if (LLVM_UNLIKELY(i == 0)) {
91                            return block->createOnes(stmt->getType());
92                        }
93                        break;
94                    default: break;
95                }
96            }
97        }       
98    }
99    return nullptr;
100}
101
102/** ------------------------------------------------------------------------------------------------------------- *
103 * @brief VariableTable
104 ** ------------------------------------------------------------------------------------------------------------- */
105struct VariableTable {
106
107    VariableTable(VariableTable * predecessor = nullptr)
108    : mPredecessor(predecessor) {
109
110    }
111
112    PabloAST * get(PabloAST * const var) const {
113        const auto f = mMap.find(var);
114        if (f == mMap.end()) {
115            return (mPredecessor) ? mPredecessor->get(var) : nullptr;
116        }
117        return f->second;
118    }
119
120    void put(PabloAST * const var, PabloAST * value) {
121        const auto f = mMap.find(var);
122        if (LLVM_LIKELY(f == mMap.end())) {
123            mMap.emplace(var, value);
124        } else {
125            f->second = value;
126        }
127        assert (get(var) == value);
128    }
129
130    bool isNonZero(const PabloAST * const var) const {
131        if (mNonZero.count(var) != 0) {
132            return true;
133        } else if (mPredecessor) {
134            return mPredecessor->isNonZero(var);
135        }
136        return false;
137    }
138
139    void addNonZero(const PabloAST * const var) {
140        mNonZero.insert(var);
141    }
142
143private:
144    VariableTable * const mPredecessor;
145    flat_map<PabloAST *, PabloAST *> mMap;
146    flat_set<const PabloAST *> mNonZero;
147};
148
149/** ------------------------------------------------------------------------------------------------------------- *
150 * @brief isTrivial
151 *
152 * If this inner block is composed of only Boolean logic and Assign statements and there are fewer than 3
153 * statements, just add the statements in the inner block to the current block
154 ** ------------------------------------------------------------------------------------------------------------- */
155inline bool isTrivial(const PabloBlock * const block) {
156    unsigned computations = 0;
157    for (const Statement * stmt : *block) {
158        switch (stmt->getClassTypeId()) {
159            case TypeId::And:
160            case TypeId::Or:
161            case TypeId::Xor:
162                if (++computations > 3) {
163                    return false;
164                }
165            case TypeId::Not:
166            case TypeId::Assign:
167                break;
168            default:
169                return false;
170        }
171    }
172    return true;
173}
174
175/** ------------------------------------------------------------------------------------------------------------- *
176 * @brief flatten
177 ** ------------------------------------------------------------------------------------------------------------- */
178Statement * flatten(Branch * const br) {
179    Statement * stmt = br;
180    Statement * nested = br->getBody()->front();
181    while (nested) {
182        Statement * next = nested->removeFromParent();
183        nested->insertAfter(stmt);
184        stmt = nested;
185        nested = next;
186    }
187    return br->eraseFromParent();
188}
189
190/** ------------------------------------------------------------------------------------------------------------- *
191 * @brief redundancyElimination
192 *
193 * Note: Do not recursively delete statements in this function. The ExpressionTable could use deleted statements
194 * as replacements. Let the DCE remove the unnecessary statements with the finalized Def-Use information.
195 ** ------------------------------------------------------------------------------------------------------------- */
196void redundancyElimination(PabloBlock * const block, ExpressionTable * const et, VariableTable * const vt) {
197    VariableTable variables(vt);
198
199    // When processing a While body, we cannot use its initial value from the outer
200    // body since the Var will likely be assigned a different value in the current
201    // body that should be used on the subsequent iteration of the loop.
202    if (Branch * br = block->getBranch()) {
203        assert ("block has a branch but the expression and variable tables were not supplied" && et && vt);
204        variables.addNonZero(br->getCondition());
205        if (LLVM_UNLIKELY(isa<While>(br))) {
206            for (Var * var : cast<While>(br)->getEscaped()) {
207                variables.put(var, var);
208            }
209        }
210    }
211
212    ExpressionTable expressions(et);
213
214    Statement * stmt = block->front();
215    while (stmt) {
216
217        if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
218            Assign * const assign = cast<Assign>(stmt);
219            PabloAST * const var = assign->getVariable();
220            PabloAST * value = assign->getValue();
221            while (LLVM_UNLIKELY(isa<Var>(value))) {
222                PabloAST * next = variables.get(cast<Var>(value));
223                if (LLVM_LIKELY(next == nullptr || next == value)) {
224                    break;
225                }
226                value = next;
227                assign->setValue(value);
228            }
229            if (LLVM_UNLIKELY(variables.get(var) == value)) {
230                stmt = stmt->eraseFromParent();
231                continue;
232            }
233            variables.put(var, value);
234        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
235
236            Branch * const br = cast<Branch>(stmt);
237
238            // Test whether we can ever take this branch
239            PabloAST * cond = br->getCondition();
240            if (isa<Var>(cond)) {
241                PabloAST * const value = variables.get(cast<Var>(cond));
242                if (value) {
243                    cond = value;
244                    // TODO: verify this works for a nested If node within a While body.
245                    if (isa<If>(br)) {
246                        br->setCondition(cond);
247                    }
248                }
249            }
250
251            if (LLVM_UNLIKELY(isa<Zeroes>(cond))) {
252                stmt = stmt->eraseFromParent();
253                continue;
254            }
255
256            if (LLVM_LIKELY(isa<If>(br))) {
257                if (LLVM_UNLIKELY(variables.isNonZero(br->getCondition()))) {
258                    stmt = flatten(br);
259                    continue;
260                }
261            }
262
263            // Process the Branch body
264            redundancyElimination(br->getBody(), &expressions, &variables);
265
266            if (LLVM_LIKELY(isa<If>(br))) {
267                // Check whether the cost of testing the condition and taking the branch with
268                // 100% correct prediction rate exceeds the cost of the body itself
269                if (LLVM_UNLIKELY(isTrivial(br->getBody()))) {
270                    stmt = flatten(br);
271                    continue;
272                }
273            }
274
275        } else {
276
277            // demote any uses of a Var whose value is in scope
278            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
279                PabloAST * op = stmt->getOperand(i);
280                if (LLVM_UNLIKELY(isa<Var>(op))) {
281                    PabloAST * const value = variables.get(cast<Var>(op));
282                    if (value && value != op) {
283                        stmt->setOperand(i, value);
284                    }
285                }
286            }
287
288            PabloAST * const folded = triviallyFold(stmt, block);
289            if (folded) {
290                stmt = stmt->replaceWith(folded);
291                continue;
292            }
293
294            // By recording which statements have already been seen, we can detect the redundant statements
295            // as any having the same type and operands. If so, we can replace its users with the prior statement.
296            // and erase this statement from the AST
297            const auto f = expressions.findOrAdd(stmt);
298            if (!f.second) {
299                stmt = stmt->replaceWith(f.first);
300                continue;
301            }
302
303            // Check whether this statement is trivially non-zero and if so, add it to our set of non-zero variables.
304            // This will allow us to flatten an If scope if its branch is always taken.
305            if (isa<Or>(stmt)) {
306                for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
307                    if (LLVM_UNLIKELY(variables.isNonZero(stmt->getOperand(i)))) {
308                        variables.addNonZero(stmt);
309                        break;
310                    }
311                }
312            } else if (isa<Advance>(stmt)) {
313                const Advance * const adv = cast<Advance>(stmt);
314                if (LLVM_LIKELY(adv->getAmount() < 32)) {
315                    if (LLVM_UNLIKELY(variables.isNonZero(adv->getExpression()))) {
316                        variables.addNonZero(adv);
317                    }
318                }
319            }
320        }
321
322        stmt = stmt->getNextNode();
323    }
324
325    // If this block has a branch statement leading into it, we can verify whether an escaped value
326    // was updated within this block and update the preceeding block's variable state appropriately.
327
328    if (Branch * const br = block->getBranch()) {
329
330        // When removing identical escaped values, we have to consider that the identical Vars could
331        // be assigned new differing values later in the outer body. Thus instead of replacing them
332        // directly, we map future uses of the duplicate Var to the initial one. The DCE pass will
333        // later mark any Assign statement as dead if the Var is never read.
334
335        /// TODO: this doesn't properly optimize the loop control variable(s) yet.
336
337        const auto escaped = br->getEscaped();
338        const auto n = escaped.size();
339        PabloAST * variable[n];
340        PabloAST * incoming[n];
341        PabloAST * outgoing[n];
342
343        for (unsigned i = 0; i < escaped.size(); ++i) {
344            PabloAST * var = escaped[i];
345            incoming[i] = vt->get(var);
346            outgoing[i] = variables.get(var);
347            if (LLVM_UNLIKELY(incoming[i] == outgoing[i])) {
348                var = incoming[i];
349            } else {
350                for (size_t j = 0; j != i; ++j) {
351                    if ((outgoing[j] == outgoing[i]) && (incoming[j] == incoming[i])) {
352                        var = variable[j];
353                        break;
354                    }
355                }
356            }
357            variable[i] = var;
358            vt->put(escaped[i], var);
359        }
360    }
361}
362
363/** ------------------------------------------------------------------------------------------------------------- *
364 * @brief deadCodeElimination
365 ** ------------------------------------------------------------------------------------------------------------- */
366void deadCodeElimination(PabloBlock * const block) {
367
368    flat_map<PabloAST *, Assign *> unread;
369
370    Statement * stmt = block->front();
371    while (stmt) {
372        if (unread.size() != 0) {
373            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
374                PabloAST * const op = stmt->getOperand(i);
375                if (LLVM_UNLIKELY(isa<Var>(op))) {
376                    unread.erase(op);
377                }
378            }
379        }
380        if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
381            Branch * const br = cast<Branch>(stmt);
382            deadCodeElimination(br->getBody());
383            if (LLVM_UNLIKELY(br->getEscaped().empty())) {
384                stmt = stmt->eraseFromParent(true);
385                continue;
386            }
387        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
388            // An Assign statement is locally dead whenever its variable is not read
389            // before being reassigned a value.
390            PabloAST * var = cast<Assign>(stmt)->getVariable();
391            auto f = unread.find(var);
392            if (f != unread.end()) {
393                auto prior = f->second;
394                prior->eraseFromParent(true);
395                f->second = cast<Assign>(stmt);
396            } else {
397                unread.emplace(var, cast<Assign>(stmt));
398            }
399        } else if (LLVM_UNLIKELY(stmt->getNumUses() == 0)) {
400            stmt = stmt->eraseFromParent(true);
401            continue;
402        }
403        stmt = stmt->getNextNode();
404    }
405}
406
407/** ------------------------------------------------------------------------------------------------------------- *
408 * @brief deadCodeElimination
409 ** ------------------------------------------------------------------------------------------------------------- */
410void deadCodeElimination(PabloKernel * kernel) {
411
412    deadCodeElimination(kernel->getEntryBlock());
413
414    for (unsigned i = 0; i < kernel->getNumOfVariables(); ++i) {
415        Var * var = kernel->getVariable(i);
416        bool unused = true;
417        for (PabloAST * user : var->users()) {
418            if (isa<Assign>(user)) {
419                if (cast<Assign>(user)->getValue() == var) {
420                    unused = false;
421                    break;
422                }
423            } else {
424                unused = false;
425                break;
426            }
427        }
428        if (LLVM_UNLIKELY(unused)) {
429            for (PabloAST * user : var->users()) {
430                cast<Assign>(user)->eraseFromParent(true);
431            }
432        }
433    }
434
435}
436
437/** ------------------------------------------------------------------------------------------------------------- *
438 * @brief strengthReduction
439 *
440 * Find and replace any Pablo operations with a less expensive equivalent operation whenever possible.
441 ** ------------------------------------------------------------------------------------------------------------- */
442void strengthReduction(PabloBlock * const block) {
443
444    Statement * stmt = block->front();
445    while (stmt) {
446        if (isa<Branch>(stmt)) {
447            strengthReduction(cast<Branch>(stmt)->getBody());
448        } else if (isa<Advance>(stmt)) {
449            Advance * adv = cast<Advance>(stmt);
450            if (LLVM_UNLIKELY(isa<Advance>(adv->getOperand(0)))) {
451                // Replace an Advance(Advance(x, n), m) with an Advance(x,n + m)
452                // Test whether this will generate a long advance and abort?
453                Advance * op = cast<Advance>(stmt->getOperand(0));
454                if (LLVM_UNLIKELY(op->getNumUses() == 1)) {
455                    adv->setOperand(0, op->getOperand(0));
456                    adv->setOperand(1, block->getInteger(adv->getAmount() + op->getAmount()));
457                    op->eraseFromParent(false);
458                }
459            }
460        } else if (LLVM_UNLIKELY(isa<ScanThru>(stmt))) {
461            ScanThru * scanThru = cast<ScanThru>(stmt);
462            if (LLVM_UNLIKELY(isa<Advance>(scanThru->getScanFrom()))) {
463                // Replace a ScanThru(Advance(x,n),y) with an ScanThru(Advance(x, n - 1), Advance(x, n - 1) | y), where Advance(x, 0) = x
464                Advance * adv = cast<Advance>(scanThru->getScanFrom());
465                if (LLVM_UNLIKELY(adv->getNumUses() == 1)) {
466                    PabloAST * stream = adv->getExpression();
467                    block->setInsertPoint(stmt);
468                    if (LLVM_UNLIKELY(adv->getAmount() != 1)) {
469                        stream = block->createAdvance(stream, block->getInteger(adv->getAmount() - 1));
470                    }
471                    stmt = scanThru->replaceWith(block->createAdvanceThenScanThru(stream, scanThru->getScanThru()));
472                    adv->eraseFromParent(false);
473                    continue;
474                }
475            } else if (LLVM_UNLIKELY(isa<And>(scanThru->getScanFrom()))) {
476                // Suppose B is an arbitrary bitstream and A = Advance(B, 1). ScanThru(B ∧ ¬A, B) will leave a marker on the position
477                // following the end of any run of 1-bits in B. But this is equivalent to computing A ∧ ¬B since A will have exactly
478                // one 1-bit past the end of any run of 1-bits in B.
479
480
481
482
483
484            }
485        } else if (LLVM_UNLIKELY(isa<ScanTo>(stmt))) {
486            ScanTo * scanTo = cast<ScanTo>(stmt);
487            if (LLVM_UNLIKELY(isa<Advance>(scanTo->getScanFrom()))) {
488                // Replace a ScanTo(Advance(x,n),y) with an ScanTo(Advance(x, n - 1), Advance(x, n - 1) | y), where Advance(x, 0) = x
489                Advance * adv = cast<Advance>(scanTo->getScanFrom());
490                if (LLVM_UNLIKELY(adv->getNumUses() == 1)) {
491                    PabloAST * stream = adv->getExpression();
492                    block->setInsertPoint(stmt);
493                    if (LLVM_UNLIKELY(adv->getAmount() != 1)) {
494                        stream = block->createAdvance(stream, block->getInteger(adv->getAmount() - 1));
495                    }
496                    stmt = scanTo->replaceWith(block->createAdvanceThenScanTo(stream, scanTo->getScanTo()));
497                    adv->eraseFromParent(false);
498                    continue;
499                }
500            }
501        }
502        stmt = stmt->getNextNode();
503    }
504}
505
506/** ------------------------------------------------------------------------------------------------------------- *
507 * @brief optimize
508 ** ------------------------------------------------------------------------------------------------------------- */
509bool Simplifier::optimize(PabloKernel * kernel) {
510    redundancyElimination(kernel->getEntryBlock(), nullptr, nullptr);
511    strengthReduction(kernel->getEntryBlock());
512    deadCodeElimination(kernel);
513    #ifndef NDEBUG
514    PabloVerifier::verify(kernel, "post-simplification");
515    #endif
516    return true;
517}
518
519}
Note: See TracBrowser for help on using the repository browser.