source: icGREP/icgrep-devel/icgrep/pablo/optimizers/codemotionpass.cpp @ 4854

Last change on this file since 4854 was 4854, checked in by nmedfort, 4 years ago

Made code sinking a full code motion pass.

File size: 9.0 KB
Line 
1#include "codemotionpass.h"
2#include <pablo/function.h>
3#include <pablo/ps_while.h>
4#include <pablo/analysis/pabloverifier.hpp>
5#ifdef USE_BOOST
6#include <boost/container/flat_set.hpp>
7#else
8#include <unordered_set>
9#endif
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13namespace pablo {
14
15#ifdef USE_BOOST
16using LoopVariants = boost::container::flat_set<const PabloAST *>;
17#else
18using LoopVariants = std::unordered_set<const PabloAST *>;
19#endif
20
21/** ------------------------------------------------------------------------------------------------------------- *
22 * @brief optimize
23 ** ------------------------------------------------------------------------------------------------------------- */
24bool CodeMotionPass::optimize(PabloFunction & function) {
25    CodeMotionPass::process(function.getEntryBlock());
26    #ifndef NDEBUG
27    PabloVerifier::verify(function, "post-sinking");
28    #endif
29    return true;
30}
31
32/** ------------------------------------------------------------------------------------------------------------- *
33 * @brief process
34 ** ------------------------------------------------------------------------------------------------------------- */
35void CodeMotionPass::process(PabloBlock & block) {
36    sink(block);
37    for (Statement * stmt : block) {
38        if (isa<If>(stmt)) {
39            process(cast<If>(stmt)->getBody());
40        } else if (isa<While>(stmt)) {
41            process(cast<While>(stmt)->getBody());
42            // TODO: if we analyzed the probability of this loop being executed once, twice or many times, we could
43            // determine hoisting will helpful or harmful to the expected run time.
44            hoistLoopInvariants(cast<While>(stmt));
45        }
46    }
47}
48
49
50/** ------------------------------------------------------------------------------------------------------------- *
51 * @brief isSafeToMove
52 ** ------------------------------------------------------------------------------------------------------------- */
53inline static bool isSafeToMove(Statement * stmt) {
54    return !isa<Assign>(stmt) && !isa<Next>(stmt);
55}
56
57/** ------------------------------------------------------------------------------------------------------------- *
58 * @brief calculateDepthToCurrentBlock
59 ** ------------------------------------------------------------------------------------------------------------- */
60inline static unsigned calculateDepthToCurrentBlock(const PabloBlock * scope, const PabloBlock & root) {
61    unsigned depth = 0;
62    while (scope != &root) {
63        ++depth;
64        assert (scope);
65        scope = scope->getParent();
66    }
67    return depth;
68}
69
70/** ------------------------------------------------------------------------------------------------------------- *
71 * @brief findScopeUsages
72 ** ------------------------------------------------------------------------------------------------------------- */
73template <class ScopeSet>
74inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block) {
75    for (PabloAST * use : stmt->users()) {
76        assert (isa<Statement>(use));
77        PabloBlock * const parent = cast<Statement>(use)->getParent();
78        if (LLVM_LIKELY(parent == &block)) {
79            return false;
80        }
81        scopeSet.insert(parent);
82    }
83    return true;
84}
85
86/** ------------------------------------------------------------------------------------------------------------- *
87 * @brief findScopeUsages
88 ** ------------------------------------------------------------------------------------------------------------- */
89template <class ScopeSet>
90inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block, const PabloBlock & ignored) {
91    for (PabloAST * use : stmt->users()) {
92        assert (isa<Statement>(use));
93        PabloBlock * const parent = cast<Statement>(use)->getParent();
94        if (LLVM_LIKELY(parent == &block)) {
95            return false;
96        }
97        if (parent != &ignored) {
98            scopeSet.insert(parent);
99        }
100    }
101    return true;
102}
103
104/** ------------------------------------------------------------------------------------------------------------- *
105 * @brief isAcceptableTarget
106 ** ------------------------------------------------------------------------------------------------------------- */
107inline bool CodeMotionPass::isAcceptableTarget(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block) {
108    // Scan through this statement's users to see if they're all in a nested scope. If so,
109    // find the least common ancestor of the scope blocks. If it is not the current scope,
110    // then we can sink the instruction.
111    if (isa<If>(stmt)) {
112        for (Assign * def : cast<If>(stmt)->getDefined()) {
113            if (!findScopeUsages(def, scopeSet, block, cast<If>(stmt)->getBody())) {
114                return false;
115            }
116        }
117    } else if (isa<While>(stmt)) {
118        for (Next * var : cast<While>(stmt)->getVariants()) {
119            if (escapes(var) && !findScopeUsages(var, scopeSet, block, cast<While>(stmt)->getBody())) {
120                return false;
121            }
122        }
123    } else if (isSafeToMove(stmt)) {
124        return findScopeUsages(stmt, scopeSet, block);
125    }
126    return false;
127}
128
129/** ------------------------------------------------------------------------------------------------------------- *
130 * @brief sink
131 ** ------------------------------------------------------------------------------------------------------------- */
132void CodeMotionPass::sink(PabloBlock & block) {
133    ScopeSet scopes;
134    Statement * stmt = block.back(); // note: reverse AST traversal
135    while (stmt) {
136        Statement * prevNode = stmt->getPrevNode();
137        if (isAcceptableTarget(stmt, scopes, block)) {
138            assert (scopes.size() > 0);
139            while (scopes.size() > 1) {
140                // Find the LCA of both scopes then add the LCA back to the list of scopes.
141                PabloBlock * scope1 = scopes.back(); scopes.pop_back();
142                unsigned depth1 = calculateDepthToCurrentBlock(scope1, block);
143
144                PabloBlock * scope2 = scopes.back(); scopes.pop_back();
145                unsigned depth2 = calculateDepthToCurrentBlock(scope2, block);
146
147                // If one of these scopes is nested deeper than the other, scan upwards through
148                // the scope tree until both scopes are at the same depth.
149                while (depth1 > depth2) {
150                    scope1 = scope1->getParent();
151                    --depth1;
152                }
153                while (depth1 < depth2) {
154                    scope2 = scope2->getParent();
155                    --depth2;
156                }
157
158                // Then iteratively step backwards until we find a matching set of scopes; this
159                // must be the LCA of our original scopes.
160                while (scope1 != scope2) {
161                    scope1 = scope1->getParent();
162                    scope2 = scope2->getParent();
163                }
164                assert (scope1 && scope2);
165                // But if the LCA is the current block, we can't sink the statement.
166                if (scope1 == &block) {
167                    goto abort;
168                }
169                scopes.push_back(scope1);
170            }
171            assert (scopes.size() == 1);
172            assert (isa<If>(stmt) ? &(cast<If>(stmt)->getBody()) != scopes.front() : true);
173            assert (isa<While>(stmt) ? &(cast<While>(stmt)->getBody()) != scopes.front() : true);
174            stmt->insertBefore(scopes.front()->front());
175        }
176abort:  scopes.clear();
177        stmt = prevNode;
178    }
179}
180
181/** ------------------------------------------------------------------------------------------------------------- *
182 * @brief hoistWhileLoopInvariants
183 ** ------------------------------------------------------------------------------------------------------------- */
184void CodeMotionPass::hoistLoopInvariants(While * loop) {
185    LoopVariants loopVariants;
186    for (Next * variant : loop->getVariants()) {
187        loopVariants.insert(variant);
188        loopVariants.insert(variant->getInitial());
189    }
190    Statement * outerNode = loop->getPrevNode();
191    Statement * stmt = loop->getBody().front();
192    while (stmt) {
193        if (isa<If>(stmt)) {
194            for (Assign * def : cast<If>(stmt)->getDefined()) {
195                loopVariants.insert(def);
196            }
197        } else if (isa<While>(stmt)) {
198            for (Next * var : cast<While>(stmt)->getVariants()) {
199                loopVariants.insert(var);
200            }
201        } else {
202            bool invariant = true;
203            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
204                if (loopVariants.count(stmt->getOperand(i)) != 0) {
205                    invariant = false;
206                    break;
207                }
208            }
209            if (LLVM_UNLIKELY(invariant)) {
210                Statement * next = stmt->getNextNode();
211                stmt->insertAfter(outerNode);
212                outerNode = stmt;
213                stmt = next;
214            } else {
215                loopVariants.insert(stmt);
216                stmt = stmt->getNextNode();
217            }
218        }
219    }
220}
221
222}
Note: See TracBrowser for help on using the repository browser.