source: icGREP/icgrep-devel/icgrep/pablo/optimizers/codemotionpass.cpp @ 5202

Last change on this file since 5202 was 5202, checked in by nmedfort, 2 years ago

Initial work on adding types to PabloAST and mutable Var objects.

File size: 7.4 KB
Line 
1#include "codemotionpass.h"
2#include <pablo/codegenstate.h>
3#include <pablo/analysis/pabloverifier.hpp>
4#include <boost/container/flat_set.hpp>
5#include <boost/container/flat_map.hpp>
6#include <boost/graph/adjacency_list.hpp>
7#include <boost/graph/topological_sort.hpp>
8#include <boost/circular_buffer.hpp>
9
10using namespace boost;
11using namespace boost::container;
12
13namespace pablo {
14
15/** ------------------------------------------------------------------------------------------------------------- *
16 * @brief optimize
17 ** ------------------------------------------------------------------------------------------------------------- */
18bool CodeMotionPass::optimize(PabloFunction & function) {
19    CodeMotionPass::movement(function.getEntryBlock());
20    #ifndef NDEBUG
21    PabloVerifier::verify(function, "post-code-motion");
22    #endif
23    return true;
24}
25
26/** ------------------------------------------------------------------------------------------------------------- *
27 * @brief movement
28 ** ------------------------------------------------------------------------------------------------------------- */
29void CodeMotionPass::movement(PabloBlock * const block) {
30    sink(block);
31    for (Statement * stmt : *block) {
32        if (isa<If>(stmt)) {
33            movement(cast<If>(stmt)->getBody());
34        } else if (isa<While>(stmt)) {
35            movement(cast<While>(stmt)->getBody());
36            // TODO: if we analyzed the probability of this loop being executed once, twice, or many times, we could
37            // determine whether hoisting will helpful or harmful to the expected run time.
38            hoistLoopInvariants(cast<While>(stmt));
39        }
40    }
41}
42
43/** ------------------------------------------------------------------------------------------------------------- *
44 * @brief depthTo
45 ** ------------------------------------------------------------------------------------------------------------- */
46inline static unsigned depthTo(const PabloBlock * scope, const PabloBlock * const root) {
47    unsigned depth = 0;
48    while (scope != root) {
49        ++depth;
50        assert (scope);
51        scope = scope->getPredecessor();
52    }
53    return depth;
54}
55
56/** ------------------------------------------------------------------------------------------------------------- *
57 * @brief findScopeUsages
58 ** ------------------------------------------------------------------------------------------------------------- */
59template <class ScopeSet>
60inline bool findScopeUsages(PabloAST * expr, ScopeSet & scopeSet, const PabloBlock * const block, const PabloBlock * const blocker) {
61    for (PabloAST * use : expr->users()) {
62        assert (isa<Statement>(use));
63        PabloBlock * const parent = cast<Statement>(use)->getParent();
64        if (LLVM_LIKELY(parent == block)) {
65            return false;
66        }
67        if (parent != blocker) {
68            scopeSet.insert(parent);
69        }
70    }
71    return true;
72}
73
74/** ------------------------------------------------------------------------------------------------------------- *
75 * @brief isAcceptableTarget
76 ** ------------------------------------------------------------------------------------------------------------- */
77inline bool CodeMotionPass::isAcceptableTarget(Statement * stmt, ScopeSet & scopeSet, const PabloBlock * const block) {
78    // Scan through this statement's users to see if they're all in a nested scope. If so,
79    // find the least common ancestor of the scope blocks. If it is not the current scope,
80    // then we can sink the instruction.
81    assert (scopeSet.empty());
82    if (isa<Branch>(stmt)) {
83        for (Var * def : cast<Branch>(stmt)->getEscaped()) {
84            if (!findScopeUsages(def, scopeSet, block, cast<Branch>(stmt)->getBody())) {
85                return false;
86            }
87        }
88    } else if (!isa<Assign>(stmt)) {
89        return findScopeUsages(stmt, scopeSet, block, nullptr);
90    }
91    return false;
92}
93
94/** ------------------------------------------------------------------------------------------------------------- *
95 * @brief sink
96 ** ------------------------------------------------------------------------------------------------------------- */
97inline void CodeMotionPass::sink(PabloBlock * const block) {
98    ScopeSet scopes;
99    Statement * stmt = block->back(); // note: reverse AST traversal
100    while (stmt) {
101        Statement * prevNode = stmt->getPrevNode();
102        if (isAcceptableTarget(stmt, scopes, block)) {
103            assert (scopes.size() > 0);
104            while (scopes.size() > 1) {
105                // Find the LCA of both scopes then add the LCA back to the list of scopes.
106                PabloBlock * scope1 = scopes.back(); scopes.pop_back();
107                unsigned depth1 = depthTo(scope1, block);
108
109                PabloBlock * scope2 = scopes.back(); scopes.pop_back();
110                unsigned depth2 = depthTo(scope2, block);
111
112                // If one of these scopes is nested deeper than the other, scan upwards through
113                // the scope tree until both scopes are at the same depth.
114                while (depth1 > depth2) {
115                    scope1 = scope1->getPredecessor();
116                    --depth1;
117                }
118                while (depth1 < depth2) {
119                    scope2 = scope2->getPredecessor();
120                    --depth2;
121                }
122
123                // Then iteratively step backwards until we find a matching set of scopes; this
124                // must be the LCA of our original scopes.
125                while (scope1 != scope2) {
126                    assert (scope1 && scope2);
127                    scope1 = scope1->getPredecessor();
128                    scope2 = scope2->getPredecessor();
129                }
130                assert (scope1);
131                // But if the LCA is the current block, we can't sink the statement.
132                if (scope1 == block) {
133                    goto abort;
134                }
135                scopes.push_back(scope1);
136            }
137            assert (scopes.size() == 1);
138            assert (isa<Branch>(stmt) ? (cast<Branch>(stmt)->getBody() != scopes.front()) : true);
139            stmt->insertBefore(scopes.front()->front());
140        }
141abort:  scopes.clear();
142        stmt = prevNode;
143    }
144}
145
146/** ------------------------------------------------------------------------------------------------------------- *
147 * @brief hoistWhileLoopInvariants
148 ** ------------------------------------------------------------------------------------------------------------- */
149void CodeMotionPass::hoistLoopInvariants(While * loop) {
150    flat_set<const PabloAST *> loopVariants;
151    for (Var * variant : loop->getEscaped()) {
152        loopVariants.insert(variant);
153    }
154    Statement * outerNode = loop->getPrevNode();
155    Statement * stmt = loop->getBody()->front();
156    while (stmt) {
157        if (isa<Branch>(stmt)) {
158            for (Var * var : cast<Branch>(stmt)->getEscaped()) {
159                loopVariants.insert(var);
160            }
161        } else {
162            bool invariant = true;
163            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
164                if (loopVariants.count(stmt->getOperand(i)) != 0) {
165                    invariant = false;
166                    break;
167                }
168            }
169            if (LLVM_UNLIKELY(invariant)) {
170                Statement * next = stmt->getNextNode();
171                stmt->insertAfter(outerNode);
172                outerNode = stmt;
173                stmt = next;
174            } else {
175                loopVariants.insert(stmt);
176                stmt = stmt->getNextNode();
177            }
178        }
179    }
180}
181
182}
Note: See TracBrowser for help on using the repository browser.