source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_simplifier.cpp @ 5607

Last change on this file since 5607 was 5607, checked in by nmedfort, 20 months ago

Bug fix for DistributivePass?. Minor change to Simplifier to prevent the first conditional assignment of a Var from being combined with its strictly dominating assigned value.

File size: 20.7 KB
Line 
1#include <pablo/optimizers/pablo_simplifier.hpp>
2#include <pablo/pablo_kernel.h>
3#include <pablo/codegenstate.h>
4#include <pablo/expression_map.hpp>
5#include <pablo/boolean.h>
6#include <pablo/pe_zeroes.h>
7#include <pablo/pe_ones.h>
8#include <pablo/arithmetic.h>
9#include <pablo/branch.h>
10#include <pablo/ps_assign.h>
11#include <pablo/pe_advance.h>
12#include <pablo/pe_scanthru.h>
13#include <pablo/pe_matchstar.h>
14#include <pablo/pe_var.h>
15#include <boost/container/flat_set.hpp>
16#ifndef NDEBUG
17#include <pablo/analysis/pabloverifier.hpp>
18#endif
19#include <llvm/Support/raw_ostream.h>
20
21
22using namespace boost;
23using namespace boost::container;
24using namespace llvm;
25
26namespace pablo {
27
28using TypeId = PabloAST::ClassTypeId;
29
30/** ------------------------------------------------------------------------------------------------------------- *
31 * @brief fold
32 ** ------------------------------------------------------------------------------------------------------------- */
33PabloAST * triviallyFold(Statement * stmt, PabloBlock * const block) {
34    if (isa<Not>(stmt)) {
35        PabloAST * value = stmt->getOperand(0);
36        if (LLVM_UNLIKELY(isa<Not>(value))) {
37            return cast<Not>(value)->getOperand(0); // ¬¬A ⇔ A
38        } else if (LLVM_UNLIKELY(isa<Zeroes>(value))) {
39            return block->createOnes(stmt->getType()); // ¬0 ⇔ 1
40        }  else if (LLVM_UNLIKELY(isa<Ones>(value))) {
41            return block->createZeroes(stmt->getType()); // ¬1 ⇔ 0
42        }
43    } else if (isa<Advance>(stmt)) {
44        if (LLVM_UNLIKELY(isa<Zeroes>(stmt->getOperand(0)))) {
45            return block->createZeroes(stmt->getType());
46        }
47    } else if (isa<Add>(stmt) || isa<Subtract>(stmt)) {
48       if (LLVM_UNLIKELY(isa<Integer>(stmt->getOperand(0)) && isa<Integer>(stmt->getOperand(1)))) {
49           const Integer * const int0 = cast<Integer>(stmt->getOperand(0));
50           const Integer * const int1 = cast<Integer>(stmt->getOperand(1));
51           Integer::IntTy result = 0;
52           if (isa<Add>(stmt)) {
53               result = int0->value() + int1->value();
54           } else {
55               result = int0->value() - int1->value();
56           }
57           return block->getInteger(result);
58       }
59    } else {
60        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
61            if (LLVM_UNLIKELY(isa<Zeroes>(stmt->getOperand(i)))) {
62                switch (stmt->getClassTypeId()) {
63                    case TypeId::Sel:
64                        block->setInsertPoint(stmt->getPrevNode());
65                        switch (i) {
66                            case 0: return stmt->getOperand(2);
67                            case 1: return block->createAnd(block->createNot(stmt->getOperand(0)), stmt->getOperand(2));
68                            case 2: return block->createAnd(stmt->getOperand(0), stmt->getOperand(1));
69                        }
70                    case TypeId::ScanThru:
71                    case TypeId::MatchStar:
72                        return stmt->getOperand(0);
73                    default: break;
74                }
75            } else if (LLVM_UNLIKELY(isa<Ones>(stmt->getOperand(i)))) {
76                switch (stmt->getClassTypeId()) {
77                    case TypeId::Sel:
78                        block->setInsertPoint(stmt->getPrevNode());
79                        switch (i) {
80                            case 0: return stmt->getOperand(1);
81                            case 1: return block->createOr(stmt->getOperand(0), stmt->getOperand(2));
82                            case 2: return block->createOr(block->createNot(stmt->getOperand(0)), stmt->getOperand(1));
83                        }
84                    case TypeId::ScanThru:
85                        if (LLVM_UNLIKELY(i == 1)) {
86                            return block->createZeroes(stmt->getType());
87                        }
88                        break;
89                    case TypeId::MatchStar:
90                        if (LLVM_UNLIKELY(i == 0)) {
91                            return block->createOnes(stmt->getType());
92                        }
93                        break;
94                    default: break;
95                }
96            }
97        }       
98    }
99    return nullptr;
100}
101
102/** ------------------------------------------------------------------------------------------------------------- *
103 * @brief VariableTable
104 ** ------------------------------------------------------------------------------------------------------------- */
105struct VariableTable {
106
107    VariableTable(VariableTable * predecessor = nullptr)
108    : mPredecessor(predecessor) {
109
110    }
111
112    PabloAST * get(PabloAST * const var) const {
113        const auto f = mMap.find(var);
114        if (f == mMap.end()) {
115            return (mPredecessor) ? mPredecessor->get(var) : nullptr;
116        }
117        return f->second;
118    }
119
120    void put(PabloAST * const var, PabloAST * value) {
121        const auto f = mMap.find(var);
122        if (LLVM_LIKELY(f == mMap.end())) {
123            mMap.emplace(var, value);
124        } else {
125            f->second = value;
126        }
127        assert (get(var) == value);
128    }
129
130    bool isNonZero(const PabloAST * const var) const {
131        if (mNonZero.count(var) != 0) {
132            return true;
133        } else if (mPredecessor) {
134            return mPredecessor->isNonZero(var);
135        }
136        return false;
137    }
138
139    void addNonZero(const PabloAST * const var) {
140        mNonZero.insert(var);
141    }
142
143private:
144    VariableTable * const mPredecessor;
145    flat_map<PabloAST *, PabloAST *> mMap;
146    flat_set<const PabloAST *> mNonZero;
147};
148
149/** ------------------------------------------------------------------------------------------------------------- *
150 * @brief isTrivial
151 *
152 * If this inner block is composed of only Boolean logic and Assign statements and there are fewer than 3
153 * statements, just add the statements in the inner block to the current block
154 ** ------------------------------------------------------------------------------------------------------------- */
155inline bool isTrivial(const PabloBlock * const block) {
156    unsigned computations = 0;
157    for (const Statement * stmt : *block) {
158        switch (stmt->getClassTypeId()) {
159            case TypeId::And:
160            case TypeId::Or:
161            case TypeId::Xor:
162                if (++computations > 3) {
163                    return false;
164                }
165            case TypeId::Not:
166            case TypeId::Assign:
167                break;
168            default:
169                return false;
170        }
171    }
172    return true;
173}
174
175/** ------------------------------------------------------------------------------------------------------------- *
176 * @brief flatten
177 ** ------------------------------------------------------------------------------------------------------------- */
178Statement * flatten(Branch * const br) {
179    Statement * stmt = br;
180    Statement * nested = br->getBody()->front();
181    while (nested) {
182        Statement * next = nested->removeFromParent();
183        nested->insertAfter(stmt);
184        stmt = nested;
185        nested = next;
186    }
187    return br->eraseFromParent();
188}
189
190/** ------------------------------------------------------------------------------------------------------------- *
191 * @brief redundancyElimination
192 *
193 * Note: Do not recursively delete statements in this function. The ExpressionTable could use deleted statements
194 * as replacements. Let the DCE remove the unnecessary statements with the finalized Def-Use information.
195 ** ------------------------------------------------------------------------------------------------------------- */
196void redundancyElimination(PabloBlock * const block, ExpressionTable * const et, VariableTable * const vt) {
197    VariableTable variables(vt);
198
199    // When processing a While body, we cannot use its initial value from the outer
200    // body since the Var will likely be assigned a different value in the current
201    // body that should be used on the subsequent iteration of the loop.
202    if (Branch * br = block->getBranch()) {
203        assert ("block has a branch but the expression and variable tables were not supplied" && et && vt);
204        variables.addNonZero(br->getCondition());
205        for (Var * var : br->getEscaped()) {
206            variables.put(var, var);
207        }
208    }
209
210    ExpressionTable expressions(et);
211
212    Statement * stmt = block->front();
213    while (stmt) {
214
215        if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
216            Assign * const assign = cast<Assign>(stmt);
217            PabloAST * const var = assign->getVariable();
218            PabloAST * value = assign->getValue();
219            while (LLVM_UNLIKELY(isa<Var>(value))) {
220                PabloAST * next = variables.get(cast<Var>(value));
221                if (LLVM_LIKELY(next == nullptr || next == value)) {
222                    break;
223                }
224                value = next;
225                assign->setValue(value);
226            }
227            if (LLVM_UNLIKELY(variables.get(var) == value)) {
228                stmt = stmt->eraseFromParent();
229                continue;
230            }
231            variables.put(var, value);
232        } else if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
233
234            Branch * const br = cast<Branch>(stmt);
235
236            // Test whether we can ever take this branch
237            PabloAST * cond = br->getCondition();
238            if (isa<Var>(cond)) {
239                PabloAST * const value = variables.get(cast<Var>(cond));
240                if (value) {
241                    cond = value;
242                    // TODO: verify this works for a nested If node within a While body.
243                    if (isa<If>(br)) {
244                        br->setCondition(cond);
245                    }
246                }
247            }
248
249            if (LLVM_UNLIKELY(isa<Zeroes>(cond))) {
250                stmt = stmt->eraseFromParent();
251                continue;
252            }
253
254            if (LLVM_LIKELY(isa<If>(br))) {
255                if (LLVM_UNLIKELY(variables.isNonZero(cond))) {
256                    stmt = flatten(br);
257                    continue;
258                }
259            }
260
261            // Process the Branch body
262            redundancyElimination(br->getBody(), &expressions, &variables);
263
264            if (LLVM_LIKELY(isa<If>(br))) {
265                // Check whether the cost of testing the condition and taking the branch with
266                // 100% correct prediction rate exceeds the cost of the body itself
267                if (LLVM_UNLIKELY(isTrivial(br->getBody()))) {
268                    stmt = flatten(br);
269                    continue;
270                }
271            }
272
273        } else {
274
275            // demote any uses of a Var whose value is in scope
276            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
277                PabloAST * op = stmt->getOperand(i);
278                if (LLVM_UNLIKELY(isa<Var>(op))) {
279                    PabloAST * const value = variables.get(cast<Var>(op));
280                    if (value && value != op) {
281                        stmt->setOperand(i, value);
282                    }
283                }
284            }
285
286            PabloAST * const folded = triviallyFold(stmt, block);
287            if (folded) {
288                stmt = stmt->replaceWith(folded);
289                continue;
290            }
291
292            // By recording which statements have already been seen, we can detect the redundant statements
293            // as any having the same type and operands. If so, we can replace its users with the prior statement.
294            // and erase this statement from the AST
295            const auto f = expressions.findOrAdd(stmt);
296            if (!f.second) {
297                stmt = stmt->replaceWith(f.first);
298                continue;
299            }
300
301            // Check whether this statement is trivially non-zero and if so, add it to our set of non-zero variables.
302            // This will allow us to flatten an If scope if its branch is always taken.
303            if (isa<Or>(stmt)) {
304                for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
305                    if (LLVM_UNLIKELY(variables.isNonZero(stmt->getOperand(i)))) {
306                        variables.addNonZero(stmt);
307                        break;
308                    }
309                }
310            } else if (isa<Advance>(stmt)) {
311                const Advance * const adv = cast<Advance>(stmt);
312                if (LLVM_LIKELY(adv->getAmount() < 32)) {
313                    if (LLVM_UNLIKELY(variables.isNonZero(adv->getExpression()))) {
314                        variables.addNonZero(adv);
315                    }
316                }
317            }
318        }
319
320        stmt = stmt->getNextNode();
321    }
322
323    // If this block has a branch statement leading into it, we can verify whether an escaped value
324    // was updated within this block and update the preceeding block's variable state appropriately.
325
326    if (Branch * const br = block->getBranch()) {
327
328        // When removing identical escaped values, we have to consider that the identical Vars could
329        // be assigned new differing values later in the outer body. Thus instead of replacing them
330        // directly, we map future uses of the duplicate Var to the initial one. The DCE pass will
331        // later mark any Assign statement as dead if the Var is never read.
332
333        /// TODO: this doesn't properly optimize the loop control variable(s) yet.
334
335        const auto escaped = br->getEscaped();
336        const auto n = escaped.size();
337        PabloAST * variable[n];
338        PabloAST * incoming[n];
339        PabloAST * outgoing[n];
340
341        for (unsigned i = 0; i < escaped.size(); ++i) {
342            PabloAST * var = escaped[i];
343            incoming[i] = vt->get(var);
344            outgoing[i] = variables.get(var);
345            if (LLVM_UNLIKELY(incoming[i] == outgoing[i])) {
346                var = incoming[i];
347            } else {
348                for (size_t j = 0; j != i; ++j) {
349                    if ((outgoing[j] == outgoing[i]) && (incoming[j] == incoming[i])) {
350                        var = variable[j];
351                        break;
352                    }
353                }
354            }
355            variable[i] = var;
356            vt->put(escaped[i], var);
357        }
358    }
359}
360
361/** ------------------------------------------------------------------------------------------------------------- *
362 * @brief deadCodeElimination
363 ** ------------------------------------------------------------------------------------------------------------- */
364void deadCodeElimination(PabloBlock * const block) {
365
366    flat_map<PabloAST *, Assign *> unread;
367
368    Statement * stmt = block->front();
369    while (stmt) {
370        if (unread.size() != 0) {
371            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
372                PabloAST * const op = stmt->getOperand(i);
373                if (LLVM_UNLIKELY(isa<Var>(op))) {
374                    unread.erase(op);
375                }
376            }
377        }
378        if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
379            Branch * const br = cast<Branch>(stmt);
380            deadCodeElimination(br->getBody());
381            if (LLVM_UNLIKELY(br->getEscaped().empty())) {
382                stmt = stmt->eraseFromParent(true);
383                continue;
384            }
385        } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
386            // An Assign statement is locally dead whenever its variable is not read
387            // before being reassigned a value.
388            PabloAST * var = cast<Assign>(stmt)->getVariable();
389            auto f = unread.find(var);
390            if (f != unread.end()) {
391                auto prior = f->second;
392                prior->eraseFromParent(true);
393                f->second = cast<Assign>(stmt);
394            } else {
395                unread.emplace(var, cast<Assign>(stmt));
396            }
397        } else if (LLVM_UNLIKELY(stmt->getNumUses() == 0)) {
398            stmt = stmt->eraseFromParent(true);
399            continue;
400        }
401        stmt = stmt->getNextNode();
402    }
403}
404
405/** ------------------------------------------------------------------------------------------------------------- *
406 * @brief deadCodeElimination
407 ** ------------------------------------------------------------------------------------------------------------- */
408void deadCodeElimination(PabloKernel * kernel) {
409
410    deadCodeElimination(kernel->getEntryBlock());
411
412    for (unsigned i = 0; i < kernel->getNumOfVariables(); ++i) {
413        Var * var = kernel->getVariable(i);
414        bool unused = true;
415        for (PabloAST * user : var->users()) {
416            if (isa<Assign>(user)) {
417                if (cast<Assign>(user)->getValue() == var) {
418                    unused = false;
419                    break;
420                }
421            } else {
422                unused = false;
423                break;
424            }
425        }
426        if (LLVM_UNLIKELY(unused)) {
427            for (PabloAST * user : var->users()) {
428                cast<Assign>(user)->eraseFromParent(true);
429            }
430        }
431    }
432
433}
434
435/** ------------------------------------------------------------------------------------------------------------- *
436 * @brief strengthReduction
437 *
438 * Find and replace any Pablo operations with a less expensive equivalent operation whenever possible.
439 ** ------------------------------------------------------------------------------------------------------------- */
440void strengthReduction(PabloBlock * const block) {
441
442    Statement * stmt = block->front();
443    while (stmt) {
444        if (isa<Branch>(stmt)) {
445            strengthReduction(cast<Branch>(stmt)->getBody());
446        } else if (isa<Advance>(stmt)) {
447            Advance * adv = cast<Advance>(stmt);
448            if (LLVM_UNLIKELY(isa<Advance>(adv->getOperand(0)))) {
449                // Replace an Advance(Advance(x, n), m) with an Advance(x,n + m)
450                // Test whether this will generate a long advance and abort?
451                Advance * op = cast<Advance>(stmt->getOperand(0));
452                if (LLVM_UNLIKELY(op->getNumUses() == 1)) {
453                    adv->setOperand(0, op->getOperand(0));
454                    adv->setOperand(1, block->getInteger(adv->getAmount() + op->getAmount()));
455                    op->eraseFromParent(false);
456                }
457            }
458        } else if (LLVM_UNLIKELY(isa<ScanThru>(stmt))) {
459            ScanThru * scanThru = cast<ScanThru>(stmt);
460            if (LLVM_UNLIKELY(isa<Advance>(scanThru->getScanFrom()))) {
461                // Replace a ScanThru(Advance(x,n),y) with an ScanThru(Advance(x, n - 1), Advance(x, n - 1) | y), where Advance(x, 0) = x
462                Advance * adv = cast<Advance>(scanThru->getScanFrom());
463                if (LLVM_UNLIKELY(adv->getNumUses() == 1)) {
464                    PabloAST * stream = adv->getExpression();
465                    block->setInsertPoint(stmt);
466                    if (LLVM_UNLIKELY(adv->getAmount() != 1)) {
467                        stream = block->createAdvance(stream, block->getInteger(adv->getAmount() - 1));
468                    }
469                    stmt = scanThru->replaceWith(block->createAdvanceThenScanThru(stream, scanThru->getScanThru()));
470                    adv->eraseFromParent(false);
471                    continue;
472                }
473            } else if (LLVM_UNLIKELY(isa<And>(scanThru->getScanFrom()))) {
474                // Suppose B is an arbitrary bitstream and A = Advance(B, 1). ScanThru(B ∧ ¬A, B) will leave a marker on the position
475                // following the end of any run of 1-bits in B. But this is equivalent to computing A ∧ ¬B since A will have exactly
476                // one 1-bit past the end of any run of 1-bits in B.
477
478
479
480
481
482            }
483        } else if (LLVM_UNLIKELY(isa<ScanTo>(stmt))) {
484            ScanTo * scanTo = cast<ScanTo>(stmt);
485            if (LLVM_UNLIKELY(isa<Advance>(scanTo->getScanFrom()))) {
486                // Replace a ScanTo(Advance(x,n),y) with an ScanTo(Advance(x, n - 1), Advance(x, n - 1) | y), where Advance(x, 0) = x
487                Advance * adv = cast<Advance>(scanTo->getScanFrom());
488                if (LLVM_UNLIKELY(adv->getNumUses() == 1)) {
489                    PabloAST * stream = adv->getExpression();
490                    block->setInsertPoint(stmt);
491                    if (LLVM_UNLIKELY(adv->getAmount() != 1)) {
492                        stream = block->createAdvance(stream, block->getInteger(adv->getAmount() - 1));
493                    }
494                    stmt = scanTo->replaceWith(block->createAdvanceThenScanTo(stream, scanTo->getScanTo()));
495                    adv->eraseFromParent(false);
496                    continue;
497                }
498            }
499        }
500        stmt = stmt->getNextNode();
501    }
502}
503
504/** ------------------------------------------------------------------------------------------------------------- *
505 * @brief optimize
506 ** ------------------------------------------------------------------------------------------------------------- */
507bool Simplifier::optimize(PabloKernel * kernel) {
508    redundancyElimination(kernel->getEntryBlock(), nullptr, nullptr);
509    strengthReduction(kernel->getEntryBlock());
510    deadCodeElimination(kernel);
511    #ifndef NDEBUG
512    PabloVerifier::verify(kernel, "post-simplification");
513    #endif
514    return true;
515}
516
517}
Note: See TracBrowser for help on using the repository browser.