source: icGREP/icgrep-devel/icgrep/pablo/optimizers/codemotionpass.cpp @ 4896

Last change on this file since 4896 was 4896, checked in by nmedfort, 3 years ago

Work on coalescing algorithm + minor changes.

File size: 8.1 KB
Line 
1#include "codemotionpass.h"
2#include <pablo/codegenstate.h>
3#include <pablo/analysis/pabloverifier.hpp>
4#include <boost/container/flat_set.hpp>
5
6using namespace boost;
7using namespace boost::container;
8
9namespace pablo {
10
11/** ------------------------------------------------------------------------------------------------------------- *
12 * @brief optimize
13 ** ------------------------------------------------------------------------------------------------------------- */
14bool CodeMotionPass::optimize(PabloFunction & function) {
15    CodeMotionPass::process(function.getEntryBlock());
16    #ifndef NDEBUG
17    PabloVerifier::verify(function, "post-code-motion");
18    #endif
19    return true;
20}
21
22/** ------------------------------------------------------------------------------------------------------------- *
23 * @brief process
24 ** ------------------------------------------------------------------------------------------------------------- */
25void CodeMotionPass::process(PabloBlock * const block) {
26    sink(block);
27    for (Statement * stmt : *block) {
28        if (isa<If>(stmt)) {
29            process(cast<If>(stmt)->getBody());
30        } else if (isa<While>(stmt)) {
31            process(cast<While>(stmt)->getBody());
32            // TODO: if we analyzed the probability of this loop being executed once, twice, or many times, we could
33            // determine whether hoisting will helpful or harmful to the expected run time.
34            hoistLoopInvariants(cast<While>(stmt));
35        }
36    }
37}
38
39/** ------------------------------------------------------------------------------------------------------------- *
40 * @brief isSafeToMove
41 ** ------------------------------------------------------------------------------------------------------------- */
42inline static bool isSafeToMove(Statement * stmt) {
43    return !isa<Assign>(stmt) && !isa<Next>(stmt);
44}
45
46/** ------------------------------------------------------------------------------------------------------------- *
47 * @brief calculateDepthToCurrentBlock
48 ** ------------------------------------------------------------------------------------------------------------- */
49inline static unsigned calculateDepthToCurrentBlock(const PabloBlock * scope, const PabloBlock * const root) {
50    unsigned depth = 0;
51    while (scope != root) {
52        ++depth;
53        assert (scope);
54        scope = scope->getParent();
55    }
56    return depth;
57}
58
59/** ------------------------------------------------------------------------------------------------------------- *
60 * @brief findScopeUsages
61 ** ------------------------------------------------------------------------------------------------------------- */
62template <class ScopeSet>
63inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock * const block, const PabloBlock * const blocker) {
64    for (PabloAST * use : stmt->users()) {
65        assert (isa<Statement>(use));
66        PabloBlock * const parent = cast<Statement>(use)->getParent();
67        if (LLVM_LIKELY(parent == block)) {
68            return false;
69        }
70        if (parent != blocker) {
71            scopeSet.insert(parent);
72        }
73    }
74    return true;
75}
76
77/** ------------------------------------------------------------------------------------------------------------- *
78 * @brief isAcceptableTarget
79 ** ------------------------------------------------------------------------------------------------------------- */
80inline bool CodeMotionPass::isAcceptableTarget(Statement * stmt, ScopeSet & scopeSet, const PabloBlock * const block) {
81    // Scan through this statement's users to see if they're all in a nested scope. If so,
82    // find the least common ancestor of the scope blocks. If it is not the current scope,
83    // then we can sink the instruction.
84    if (isa<If>(stmt)) {
85        for (Assign * def : cast<If>(stmt)->getDefined()) {
86            if (!findScopeUsages(def, scopeSet, block, cast<If>(stmt)->getBody())) {
87                return false;
88            }
89        }
90    } else if (isa<While>(stmt)) {
91        for (Next * var : cast<While>(stmt)->getVariants()) {
92            if (escapes(var) && !findScopeUsages(var, scopeSet, block, cast<While>(stmt)->getBody())) {
93                return false;
94            }
95        }
96    } else if (isSafeToMove(stmt)) {
97        return findScopeUsages(stmt, scopeSet, block, nullptr);
98    }
99    return false;
100}
101
102/** ------------------------------------------------------------------------------------------------------------- *
103 * @brief sink
104 ** ------------------------------------------------------------------------------------------------------------- */
105void CodeMotionPass::sink(PabloBlock * const block) {
106    ScopeSet scopes;
107    Statement * stmt = block->back(); // note: reverse AST traversal
108    while (stmt) {
109        Statement * prevNode = stmt->getPrevNode();
110        if (isAcceptableTarget(stmt, scopes, block)) {
111            assert (scopes.size() > 0);
112            while (scopes.size() > 1) {
113                // Find the LCA of both scopes then add the LCA back to the list of scopes.
114                PabloBlock * scope1 = scopes.back(); scopes.pop_back();
115                unsigned depth1 = calculateDepthToCurrentBlock(scope1, block);
116
117                PabloBlock * scope2 = scopes.back(); scopes.pop_back();
118                unsigned depth2 = calculateDepthToCurrentBlock(scope2, block);
119
120                // If one of these scopes is nested deeper than the other, scan upwards through
121                // the scope tree until both scopes are at the same depth.
122                while (depth1 > depth2) {
123                    scope1 = scope1->getParent();
124                    --depth1;
125                }
126                while (depth1 < depth2) {
127                    scope2 = scope2->getParent();
128                    --depth2;
129                }
130
131                // Then iteratively step backwards until we find a matching set of scopes; this
132                // must be the LCA of our original scopes.
133                while (scope1 != scope2) {
134                    scope1 = scope1->getParent();
135                    scope2 = scope2->getParent();
136                }
137                assert (scope1 && scope2);
138                // But if the LCA is the current block, we can't sink the statement.
139                if (scope1 == block) {
140                    goto abort;
141                }
142                scopes.push_back(scope1);
143            }
144            assert (scopes.size() == 1);
145            assert (isa<If>(stmt) ? (cast<If>(stmt)->getBody() != scopes.front()) : true);
146            assert (isa<While>(stmt) ? (cast<While>(stmt)->getBody() != scopes.front()) : true);
147            stmt->insertBefore(scopes.front()->front());
148        }
149abort:  scopes.clear();
150        stmt = prevNode;
151    }
152}
153
154/** ------------------------------------------------------------------------------------------------------------- *
155 * @brief hoistWhileLoopInvariants
156 ** ------------------------------------------------------------------------------------------------------------- */
157void CodeMotionPass::hoistLoopInvariants(While * loop) {
158    flat_set<const PabloAST *> loopVariants;
159    for (Next * variant : loop->getVariants()) {
160        loopVariants.insert(variant);
161        loopVariants.insert(variant->getInitial());
162    }
163    Statement * outerNode = loop->getPrevNode();
164    Statement * stmt = loop->getBody()->front();
165    while (stmt) {
166        if (isa<If>(stmt)) {
167            for (Assign * def : cast<If>(stmt)->getDefined()) {
168                loopVariants.insert(def);
169            }
170        } else if (isa<While>(stmt)) {
171            for (Next * var : cast<While>(stmt)->getVariants()) {
172                loopVariants.insert(var);
173            }
174        } else {
175            bool invariant = true;
176            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
177                if (loopVariants.count(stmt->getOperand(i)) != 0) {
178                    invariant = false;
179                    break;
180                }
181            }
182            if (LLVM_UNLIKELY(invariant)) {
183                Statement * next = stmt->getNextNode();
184                stmt->insertAfter(outerNode);
185                outerNode = stmt;
186                stmt = next;
187            } else {
188                loopVariants.insert(stmt);
189                stmt = stmt->getNextNode();
190            }
191        }
192    }
193}
194
195}
Note: See TracBrowser for help on using the repository browser.