source: icGREP/icgrep-devel/icgrep/pablo/optimizers/codemotionpass.cpp @ 4870

Last change on this file since 4870 was 4870, checked in by nmedfort, 3 years ago

Bug fix for Multiplexing. Added ability to set the body of a If/While? node after creation.

File size: 8.4 KB
Line 
1#include "codemotionpass.h"
2#include <pablo/function.h>
3#include <pablo/ps_while.h>
4#include <pablo/analysis/pabloverifier.hpp>
5#ifdef USE_BOOST
6#include <boost/container/flat_set.hpp>
7#else
8#include <unordered_set>
9#endif
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13namespace pablo {
14
15#ifdef USE_BOOST
16using LoopVariants = boost::container::flat_set<const PabloAST *>;
17#else
18using LoopVariants = std::unordered_set<const PabloAST *>;
19#endif
20
21/** ------------------------------------------------------------------------------------------------------------- *
22 * @brief optimize
23 ** ------------------------------------------------------------------------------------------------------------- */
24bool CodeMotionPass::optimize(PabloFunction & function) {
25    CodeMotionPass::process(function.getEntryBlock());
26    #ifndef NDEBUG
27    PabloVerifier::verify(function, "post-code-motion");
28    #endif
29    return true;
30}
31
32/** ------------------------------------------------------------------------------------------------------------- *
33 * @brief process
34 ** ------------------------------------------------------------------------------------------------------------- */
35void CodeMotionPass::process(PabloBlock * const block) {
36    sink(block);
37    for (Statement * stmt : *block) {
38        if (isa<If>(stmt)) {
39            process(cast<If>(stmt)->getBody());
40        } else if (isa<While>(stmt)) {
41            process(cast<While>(stmt)->getBody());
42            // TODO: if we analyzed the probability of this loop being executed once, twice, or many times, we could
43            // determine whether hoisting will helpful or harmful to the expected run time.
44            hoistLoopInvariants(cast<While>(stmt));
45        }
46    }
47}
48
49/** ------------------------------------------------------------------------------------------------------------- *
50 * @brief isSafeToMove
51 ** ------------------------------------------------------------------------------------------------------------- */
52inline static bool isSafeToMove(Statement * stmt) {
53    return !isa<Assign>(stmt) && !isa<Next>(stmt);
54}
55
56/** ------------------------------------------------------------------------------------------------------------- *
57 * @brief calculateDepthToCurrentBlock
58 ** ------------------------------------------------------------------------------------------------------------- */
59inline static unsigned calculateDepthToCurrentBlock(const PabloBlock * scope, const PabloBlock * const root) {
60    unsigned depth = 0;
61    while (scope != root) {
62        ++depth;
63        assert (scope);
64        scope = scope->getParent();
65    }
66    return depth;
67}
68
69/** ------------------------------------------------------------------------------------------------------------- *
70 * @brief findScopeUsages
71 ** ------------------------------------------------------------------------------------------------------------- */
72template <class ScopeSet>
73inline bool findScopeUsages(Statement * stmt, ScopeSet & scopeSet, const PabloBlock * const block, const PabloBlock * const ignored = nullptr) {
74    for (PabloAST * use : stmt->users()) {
75        assert (isa<Statement>(use));
76        PabloBlock * const parent = cast<Statement>(use)->getParent();
77        if (LLVM_LIKELY(parent == block)) {
78            return false;
79        }
80        if (parent != ignored) {
81            scopeSet.insert(parent);
82        }
83    }
84    return true;
85}
86
87/** ------------------------------------------------------------------------------------------------------------- *
88 * @brief isAcceptableTarget
89 ** ------------------------------------------------------------------------------------------------------------- */
90inline bool CodeMotionPass::isAcceptableTarget(Statement * stmt, ScopeSet & scopeSet, const PabloBlock * const block) {
91    // Scan through this statement's users to see if they're all in a nested scope. If so,
92    // find the least common ancestor of the scope blocks. If it is not the current scope,
93    // then we can sink the instruction.
94    if (isa<If>(stmt)) {
95        for (Assign * def : cast<If>(stmt)->getDefined()) {
96            if (!findScopeUsages(def, scopeSet, block, cast<If>(stmt)->getBody())) {
97                return false;
98            }
99        }
100    } else if (isa<While>(stmt)) {
101        for (Next * var : cast<While>(stmt)->getVariants()) {
102            if (escapes(var) && !findScopeUsages(var, scopeSet, block, cast<While>(stmt)->getBody())) {
103                return false;
104            }
105        }
106    } else if (isSafeToMove(stmt)) {
107        return findScopeUsages(stmt, scopeSet, block);
108    }
109    return false;
110}
111
112/** ------------------------------------------------------------------------------------------------------------- *
113 * @brief sink
114 ** ------------------------------------------------------------------------------------------------------------- */
115void CodeMotionPass::sink(PabloBlock * const block) {
116    ScopeSet scopes;
117    Statement * stmt = block->back(); // note: reverse AST traversal
118    while (stmt) {
119        Statement * prevNode = stmt->getPrevNode();
120        if (isAcceptableTarget(stmt, scopes, block)) {
121            assert (scopes.size() > 0);
122            while (scopes.size() > 1) {
123                // Find the LCA of both scopes then add the LCA back to the list of scopes.
124                PabloBlock * scope1 = scopes.back(); scopes.pop_back();
125                unsigned depth1 = calculateDepthToCurrentBlock(scope1, block);
126
127                PabloBlock * scope2 = scopes.back(); scopes.pop_back();
128                unsigned depth2 = calculateDepthToCurrentBlock(scope2, block);
129
130                // If one of these scopes is nested deeper than the other, scan upwards through
131                // the scope tree until both scopes are at the same depth.
132                while (depth1 > depth2) {
133                    scope1 = scope1->getParent();
134                    --depth1;
135                }
136                while (depth1 < depth2) {
137                    scope2 = scope2->getParent();
138                    --depth2;
139                }
140
141                // Then iteratively step backwards until we find a matching set of scopes; this
142                // must be the LCA of our original scopes.
143                while (scope1 != scope2) {
144                    scope1 = scope1->getParent();
145                    scope2 = scope2->getParent();
146                }
147                assert (scope1 && scope2);
148                // But if the LCA is the current block, we can't sink the statement.
149                if (scope1 == block) {
150                    goto abort;
151                }
152                scopes.push_back(scope1);
153            }
154            assert (scopes.size() == 1);
155            assert (isa<If>(stmt) ? (cast<If>(stmt)->getBody() != scopes.front()) : true);
156            assert (isa<While>(stmt) ? (cast<While>(stmt)->getBody() != scopes.front()) : true);
157            stmt->insertBefore(scopes.front()->front());
158        }
159abort:  scopes.clear();
160        stmt = prevNode;
161    }
162}
163
164/** ------------------------------------------------------------------------------------------------------------- *
165 * @brief hoistWhileLoopInvariants
166 ** ------------------------------------------------------------------------------------------------------------- */
167void CodeMotionPass::hoistLoopInvariants(While * loop) {
168    LoopVariants loopVariants;
169    for (Next * variant : loop->getVariants()) {
170        loopVariants.insert(variant);
171        loopVariants.insert(variant->getInitial());
172    }
173    Statement * outerNode = loop->getPrevNode();
174    Statement * stmt = loop->getBody()->front();
175    while (stmt) {
176        if (isa<If>(stmt)) {
177            for (Assign * def : cast<If>(stmt)->getDefined()) {
178                loopVariants.insert(def);
179            }
180        } else if (isa<While>(stmt)) {
181            for (Next * var : cast<While>(stmt)->getVariants()) {
182                loopVariants.insert(var);
183            }
184        } else {
185            bool invariant = true;
186            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
187                if (loopVariants.count(stmt->getOperand(i)) != 0) {
188                    invariant = false;
189                    break;
190                }
191            }
192            if (LLVM_UNLIKELY(invariant)) {
193                Statement * next = stmt->getNextNode();
194                stmt->insertAfter(outerNode);
195                outerNode = stmt;
196                stmt = next;
197            } else {
198                loopVariants.insert(stmt);
199                stmt = stmt->getNextNode();
200            }
201        }
202    }
203}
204
205}
Note: See TracBrowser for help on using the repository browser.