source: icGREP/icgrep-devel/icgrep/pablo/optimizers/codemotionpass.cpp @ 4856

Last change on this file since 4856 was 4856, checked in by nmedfort, 3 years ago

Bug fix for use-def correctness regarding escaping values of If and While nodes.

File size: 9.0 KB
Line 
1#include "codemotionpass.h"
2#include <pablo/function.h>
3#include <pablo/ps_while.h>
4#include <pablo/analysis/pabloverifier.hpp>
5#ifdef USE_BOOST
6#include <boost/container/flat_set.hpp>
7#else
8#include <unordered_set>
9#endif
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13namespace pablo {
14
15#ifdef USE_BOOST
16using LoopVariants = boost::container::flat_set<const PabloAST *>;
17#else
18using LoopVariants = std::unordered_set<const PabloAST *>;
19#endif
20
21/** ------------------------------------------------------------------------------------------------------------- *
22 * @brief optimize
23 ** ------------------------------------------------------------------------------------------------------------- */
24bool CodeMotionPass::optimize(PabloFunction & function) {
25    CodeMotionPass::process(function.getEntryBlock());
26    #ifndef NDEBUG
27    PabloVerifier::verify(function, "post-sinking");
28    #endif
29    return true;
30}
31
32/** ------------------------------------------------------------------------------------------------------------- *
33 * @brief process
34 ** ------------------------------------------------------------------------------------------------------------- */
35void CodeMotionPass::process(PabloBlock & block) {
36    sink(block);
37    for (Statement * stmt : block) {
38        if (isa<If>(stmt)) {
39            process(cast<If>(stmt)->getBody());
40        } else if (isa<While>(stmt)) {
41            process(cast<While>(stmt)->getBody());
42            // TODO: if we analyzed the probability of this loop being executed once, twice or many times, we could
43            // determine hoisting will helpful or harmful to the expected run time.
44            hoistLoopInvariants(cast<While>(stmt));
45        }
46    }
47}
48
49/** ------------------------------------------------------------------------------------------------------------- *
50 * @brief isSafeToMove
51 ** ------------------------------------------------------------------------------------------------------------- */
52inline static bool isSafeToMove(Statement * stmt) {
53    return !isa<Assign>(stmt) && !isa<Next>(stmt);
54}
55
56/** ------------------------------------------------------------------------------------------------------------- *
57 * @brief calculateDepthToCurrentBlock
58 ** ------------------------------------------------------------------------------------------------------------- */
59inline static unsigned calculateDepthToCurrentBlock(const PabloBlock * scope, const PabloBlock & root) {
60    unsigned depth = 0;
61    while (scope != &root) {
62        ++depth;
63        assert (scope);
64        scope = scope->getParent();
65    }
66    return depth;
67}
68
69/** ------------------------------------------------------------------------------------------------------------- *
70 * @brief findScopeUsages
71 ** ------------------------------------------------------------------------------------------------------------- */
72template <class ScopeSet>
73inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block) {
74    for (PabloAST * use : stmt->users()) {
75        assert (isa<Statement>(use));
76        PabloBlock * const parent = cast<Statement>(use)->getParent();
77        if (LLVM_LIKELY(parent == &block)) {
78            return false;
79        }
80        scopeSet.insert(parent);
81    }
82    return true;
83}
84
85/** ------------------------------------------------------------------------------------------------------------- *
86 * @brief findScopeUsages
87 ** ------------------------------------------------------------------------------------------------------------- */
88template <class ScopeSet>
89inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block, const PabloBlock & ignored) {
90    for (PabloAST * use : stmt->users()) {
91        assert (isa<Statement>(use));
92        PabloBlock * const parent = cast<Statement>(use)->getParent();
93        if (LLVM_LIKELY(parent == &block)) {
94            return false;
95        }
96        if (parent != &ignored) {
97            scopeSet.insert(parent);
98        }
99    }
100    return true;
101}
102
103/** ------------------------------------------------------------------------------------------------------------- *
104 * @brief isAcceptableTarget
105 ** ------------------------------------------------------------------------------------------------------------- */
106inline bool CodeMotionPass::isAcceptableTarget(Statement * stmt, ScopeSet & scopeSet, const PabloBlock & block) {
107    // Scan through this statement's users to see if they're all in a nested scope. If so,
108    // find the least common ancestor of the scope blocks. If it is not the current scope,
109    // then we can sink the instruction.
110    if (isa<If>(stmt)) {
111        for (Assign * def : cast<If>(stmt)->getDefined()) {
112            if (!findScopeUsages(def, scopeSet, block, cast<If>(stmt)->getBody())) {
113                return false;
114            }
115        }
116    } else if (isa<While>(stmt)) {
117        for (Next * var : cast<While>(stmt)->getVariants()) {
118            if (escapes(var) && !findScopeUsages(var, scopeSet, block, cast<While>(stmt)->getBody())) {
119                return false;
120            }
121        }
122    } else if (isSafeToMove(stmt)) {
123        return findScopeUsages(stmt, scopeSet, block);
124    }
125    return false;
126}
127
128/** ------------------------------------------------------------------------------------------------------------- *
129 * @brief sink
130 ** ------------------------------------------------------------------------------------------------------------- */
131void CodeMotionPass::sink(PabloBlock & block) {
132    ScopeSet scopes;
133    Statement * stmt = block.back(); // note: reverse AST traversal
134    while (stmt) {
135        Statement * prevNode = stmt->getPrevNode();
136        if (isAcceptableTarget(stmt, scopes, block)) {
137            assert (scopes.size() > 0);
138            while (scopes.size() > 1) {
139                // Find the LCA of both scopes then add the LCA back to the list of scopes.
140                PabloBlock * scope1 = scopes.back(); scopes.pop_back();
141                unsigned depth1 = calculateDepthToCurrentBlock(scope1, block);
142
143                PabloBlock * scope2 = scopes.back(); scopes.pop_back();
144                unsigned depth2 = calculateDepthToCurrentBlock(scope2, block);
145
146                // If one of these scopes is nested deeper than the other, scan upwards through
147                // the scope tree until both scopes are at the same depth.
148                while (depth1 > depth2) {
149                    scope1 = scope1->getParent();
150                    --depth1;
151                }
152                while (depth1 < depth2) {
153                    scope2 = scope2->getParent();
154                    --depth2;
155                }
156
157                // Then iteratively step backwards until we find a matching set of scopes; this
158                // must be the LCA of our original scopes.
159                while (scope1 != scope2) {
160                    scope1 = scope1->getParent();
161                    scope2 = scope2->getParent();
162                }
163                assert (scope1 && scope2);
164                // But if the LCA is the current block, we can't sink the statement.
165                if (scope1 == &block) {
166                    goto abort;
167                }
168                scopes.push_back(scope1);
169            }
170            assert (scopes.size() == 1);
171            assert (isa<If>(stmt) ? &(cast<If>(stmt)->getBody()) != scopes.front() : true);
172            assert (isa<While>(stmt) ? &(cast<While>(stmt)->getBody()) != scopes.front() : true);
173            stmt->insertBefore(scopes.front()->front());
174        }
175abort:  scopes.clear();
176        stmt = prevNode;
177    }
178}
179
180/** ------------------------------------------------------------------------------------------------------------- *
181 * @brief hoistWhileLoopInvariants
182 ** ------------------------------------------------------------------------------------------------------------- */
183void CodeMotionPass::hoistLoopInvariants(While * loop) {
184    LoopVariants loopVariants;
185    for (Next * variant : loop->getVariants()) {
186        loopVariants.insert(variant);
187        loopVariants.insert(variant->getInitial());
188    }
189    Statement * outerNode = loop->getPrevNode();
190    Statement * stmt = loop->getBody().front();
191    while (stmt) {
192        if (isa<If>(stmt)) {
193            for (Assign * def : cast<If>(stmt)->getDefined()) {
194                loopVariants.insert(def);
195            }
196        } else if (isa<While>(stmt)) {
197            for (Next * var : cast<While>(stmt)->getVariants()) {
198                loopVariants.insert(var);
199            }
200        } else {
201            bool invariant = true;
202            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
203                if (loopVariants.count(stmt->getOperand(i)) != 0) {
204                    invariant = false;
205                    break;
206                }
207            }
208            if (LLVM_UNLIKELY(invariant)) {
209                Statement * next = stmt->getNextNode();
210                stmt->insertAfter(outerNode);
211                outerNode = stmt;
212                stmt = next;
213            } else {
214                loopVariants.insert(stmt);
215                stmt = stmt->getNextNode();
216            }
217        }
218    }
219}
220
221}
Note: See TracBrowser for help on using the repository browser.