source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 4748

Last change on this file since 4748 was 4748, checked in by nmedfort, 4 years ago

First (hopefully) working version of the boolean reassociation pass + some bug fixes.

File size: 15.0 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <pablo/builder.hpp>
6#include <boost/graph/adjacency_list.hpp>
7#include <boost/graph/topological_sort.hpp>
8#include <queue>
9
10using namespace boost;
11using namespace boost::container;
12
13namespace pablo {
14
15bool BooleanReassociationPass::optimize(PabloFunction & function) {
16    BooleanReassociationPass brp;
17    brp.scan(function);
18    return true;
19}
20
21/** ------------------------------------------------------------------------------------------------------------- *
22 * @brief scan
23 ** ------------------------------------------------------------------------------------------------------------- */
24void BooleanReassociationPass::scan(PabloFunction & function) {
25    Terminals terminals;
26    for (unsigned i = 0; i != function.getNumOfResults(); ++i) {
27        terminals.push_back(function.getResult(i));
28    }
29    scan(function.getEntryBlock(), std::move(terminals));
30}
31
32/** ------------------------------------------------------------------------------------------------------------- *
33 * @brief is_power_of_2
34 * @param n an integer
35 ** ------------------------------------------------------------------------------------------------------------- */
36static inline bool is_power_of_2(const size_t n) {
37    return ((n & (n - 1)) == 0);
38}
39
40/** ------------------------------------------------------------------------------------------------------------- *
41 * @brief log2_plus_one
42 ** ------------------------------------------------------------------------------------------------------------- */
43static inline size_t ceil_log2(const size_t n) {
44    return std::log2<size_t>(n) + (is_power_of_2(n) ? 0 : 1);
45}
46
47/** ------------------------------------------------------------------------------------------------------------- *
48 * @brief isACD
49 ** ------------------------------------------------------------------------------------------------------------- */
50static inline bool isaBooleanOperation(const PabloAST * const expr) {
51    assert (expr);
52    switch (expr->getClassTypeId()) {
53        case PabloAST::ClassTypeId::And:
54        case PabloAST::ClassTypeId::Or:
55        case PabloAST::ClassTypeId::Xor:
56            return true;
57        default:
58            return false;
59    }
60}
61
62using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
63using Vertex = Graph::vertex_descriptor;
64using VertexQueue = circular_buffer<Vertex>;
65
66/** ------------------------------------------------------------------------------------------------------------- *
67 * @brief isCutNecessary
68 ** ------------------------------------------------------------------------------------------------------------- */
69static inline bool isCutNecessary(const Vertex u, const Vertex v, const Graph & G, const std::vector<unsigned> & component) {
70    // Either this edge crosses a component boundary or the operations performed by the vertices differs, we need to cut
71    // the graph here and generate two partial equations.
72    if (LLVM_UNLIKELY(component[u] != component[v])) {
73        return true;
74    } else if (LLVM_UNLIKELY((in_degree(u, G) != 0) && (G[u]->getClassTypeId() != G[v]->getClassTypeId()))) {
75        return true;
76    }
77    return false;
78}
79
80/** ------------------------------------------------------------------------------------------------------------- *
81 * @brief push
82 ** ------------------------------------------------------------------------------------------------------------- */
83static inline void push(const Vertex u, VertexQueue & Q) {
84    if (LLVM_UNLIKELY(Q.full())) {
85        Q.set_capacity(Q.capacity() * 2);
86    }
87    Q.push_back(u);
88    assert (Q.back() == u);
89}
90
91/** ------------------------------------------------------------------------------------------------------------- *
92 * @brief pop
93 ** ------------------------------------------------------------------------------------------------------------- */
94static inline Vertex pop(VertexQueue & Q) {
95    assert (!Q.empty() && "Popping an empty vertex queue");
96    const Vertex u = Q.front();
97    Q.pop_front();
98    return u;
99}
100
101/** ------------------------------------------------------------------------------------------------------------- *
102 * @brief scan
103 ** ------------------------------------------------------------------------------------------------------------- */
104void BooleanReassociationPass::scan(PabloBlock & block, Terminals && terminals) {
105
106    using Map = std::unordered_map<PabloAST *, Vertex>;
107    using EdgeQueue = std::queue<std::pair<Vertex, Vertex>>;
108
109    for (Statement * stmt : block) {
110        if (isa<If>(stmt)) {
111            const auto & defs = cast<const If>(stmt)->getDefined();
112            Terminals terminals(defs.begin(), defs.end());
113            scan(cast<If>(stmt)->getBody(), std::move(terminals));
114        } else if (isa<While>(stmt)) {
115            const auto & vars = cast<const While>(stmt)->getVariants();
116            Terminals terminals(vars.begin(), vars.end());
117            scan(cast<While>(stmt)->getBody(), std::move(terminals));
118        }
119    }
120
121    // And, Or and Xor instructions are all associative, commutative and distributive operations. Thus we can
122    // safely rearrange expressions such as "((((a √ b) √ c) √ d) √ e) √ f" into "((a √ b) √ (c √ d)) √ (e √ f)".
123
124    VertexQueue Q(128);
125
126    for (;;) {
127
128        Graph G;
129        Map M;
130
131        // Generate a graph depicting the relationships between the terminals. If the original terminals
132        // cannot be optimized with this algorithm bypass them in favour of their operands. If those cannot
133        // be optimized, they'll be left as the initial terminals for the next "layer" of the AST.
134
135        for (Statement * const term : terminals) {
136            assert (term);
137            if (isaBooleanOperation(term)) {
138                if (LLVM_LIKELY(M.count(term) == 0)) {                   
139                    const Vertex v = add_vertex(term, G);
140                    assert (v < num_vertices(G));
141                    M.insert(std::make_pair(term, v));
142                    push(v, Q);
143                }
144            } else {
145                for (unsigned i = 0; i != term->getNumOperands(); ++i) {
146                    PabloAST * const op = term->getOperand(i);
147                    assert (op);
148                    if (LLVM_LIKELY(isa<Statement>(op) && M.count(op) == 0)) {
149                        const Vertex v = add_vertex(op, G);
150                        assert (v < num_vertices(G));
151                        M.insert(std::make_pair(op, v));
152                        push(v, Q);
153                    }
154                }
155            }           
156        }
157
158        if (Q.empty()) {
159            break;
160        }
161
162        for (;;) {
163            const Vertex u = pop(Q);
164            assert (u < num_vertices(G));
165            if (isaBooleanOperation(G[u])) {
166                // Scan through the use-def chains to locate any chains of rearrangable expressions and their inputs
167                Statement * stmt = cast<Statement>(G[u]);
168                for (unsigned i = 0; i != 2; ++i) {
169                    PabloAST * op = stmt->getOperand(i);
170                    auto f = M.find(op);
171                    if (f == M.end()) {
172                        const Vertex v = add_vertex(op, G);
173                        assert (v < num_vertices(G));
174                        f = M.insert(std::make_pair(op, v)).first;
175                        if (op->getClassTypeId() == stmt->getClassTypeId() && cast<Statement>(op)->getParent() == &block) {
176                            push(v, Q);
177                        }
178                    }
179                    add_edge(f->second, u, G);
180                }
181            }
182            if (Q.empty()) {
183                break;
184            }
185        }
186
187        // Generate a topological ordering for G; if one of our terminals happens to also be a partial computation of
188        // another terminal, we need to make sure we compute it as an independent subexpression.
189        std::vector<unsigned> ordering;
190        ordering.reserve(num_vertices(G));
191        topological_sort(G, std::back_inserter(ordering));
192        std::vector<unsigned> component(num_vertices(G));
193
194        for (;;) {
195
196            // Mark which computation component these vertices are in based on their topological (occurence) order.
197            unsigned components = 0;
198            for (auto u : ordering) {
199                unsigned id = 0;
200                // If this is a sink in G, it is the root of a new component.
201                if (out_degree(u, G) == 0) {
202                    id = ++components;
203                } else {
204                    for (auto e : make_iterator_range(out_edges(u, G))) {
205                        id = std::max(id, component[target(e, G)]);
206                    }
207                }
208                assert (id && "Topological ordering failed!");
209                component[u] = id;
210            }
211
212            // Cut the graph wherever a computation crosses a component or whenever we need to cut the graph because
213            // the instructions corresponding to the pair of nodes differs.
214            EdgeQueue E;
215            graph_traits<Graph>::edge_iterator ei, ei_end;
216            for (std::tie(ei, ei_end) = edges(G); ei != ei_end; ) {
217                const Graph::edge_descriptor e = *ei++;
218                const Vertex u = source(e, G);
219                const Vertex v = target(e, G);
220                if (LLVM_UNLIKELY(isCutNecessary(u, v, G, component))) {
221                    E.push(std::make_pair(u, v));
222                    remove_edge(u, v, G);
223                }
224            }
225
226            // If no cuts are necessary, we're done.
227            if (E.empty()) {
228                break;
229            }
230
231            for (;;) {
232
233                Vertex u, v;
234                std::tie(u, v) = E.front(); E.pop();
235
236                // The vertex belonging to a component with a greater number must come "earlier"
237                // in the program. By replicating it, this ensures it's computed as an output of
238                // one component and used as an input of another.
239
240                if (component[u] < component[v]) {
241                    std::swap(u, v);
242                }
243
244                // Replicate u and fix the ordering and component vectors to reflect the change in G.
245                Vertex w = add_vertex(G[u], G);
246                ordering.insert(std::find(ordering.begin(), ordering.end(), u), w);
247                assert (component.size() == w);
248                component.push_back(component[v]);
249                add_edge(w, v, G);
250
251                // However, after we do so, we need to make sure the original source vertex will be a
252                // sink in G unless it is also an input variable (in which case we'd simply end up with
253                // extraneous isolated vertex. Otherwise, we need to make further cuts and replications.
254
255                if (in_degree(u, G) != 0) {
256                    for (auto e : make_iterator_range(out_edges(u, G))) {
257                        E.push(std::make_pair(source(e, G), target(e, G)));
258                    }
259                    clear_out_edges(u, G);
260                }
261
262                if (E.empty()) {
263                    break;
264                }
265
266            }
267        }
268
269        // Scan through the graph in reverse order so that we find all subexpressions first
270        for (const Vertex u : ordering) {
271            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
272
273                // While we're collecting our variable set V, keep track of the maximum path length L.
274                // If L == ceil(log2(|V|)), then this portion of the AST is already optimal.
275
276                flat_map<Vertex, unsigned> L;
277                flat_set<PabloAST *> V;
278
279                Vertex v = u;
280                unsigned maxPathLength = 0;
281                L.emplace(v, 0);
282                for (;;) {                   
283                    if (in_degree(v, G) == 0) {
284                        V.insert(G[v]);
285                    } else {
286                        const auto l = L[v] + 1;
287                        maxPathLength = std::max(maxPathLength, l);
288                        for (auto e : make_iterator_range(in_edges(v, G))) {
289                            const Vertex w = source(e, G);
290                            auto f = L.find(w);
291                            if (LLVM_LIKELY(f == L.end())) {
292                                L.emplace(w, l);
293                            } else {
294                                f->second = std::max(f->second, l);
295                            }
296                            push(w, Q);
297                        }
298                    }
299                    if (Q.empty()) {
300                        break;
301                    }
302                    v = pop(Q);
303                }
304
305                // Should we optimize this portion of the AST?
306                if (maxPathLength > ceil_log2(V.size())) {
307
308                    Statement * stmt = cast<Statement>(G[u]);
309
310                    circular_buffer<PabloAST *> Q(V.size());
311                    for (PabloAST * var : V) {
312                        Q.push_back(var);
313                    }
314
315                    block.setInsertPoint(stmt->getPrevNode());
316                    if (isa<And>(stmt)) {
317                        while (Q.size() > 1) {
318                            PabloAST * e1 = Q.front(); Q.pop_front();
319                            PabloAST * e2 = Q.front(); Q.pop_front();
320                            Q.push_back(block.createAnd(e1, e2));
321                        }
322                    } else if (isa<Or>(stmt)) {
323                        while (Q.size() > 1) {
324                            PabloAST * e1 = Q.front(); Q.pop_front();
325                            PabloAST * e2 = Q.front(); Q.pop_front();
326                            Q.push_back(block.createOr(e1, e2));
327                        }
328                    } else { assert(isa<Xor>(stmt));
329                        while (Q.size() > 1) {
330                            PabloAST * e1 = Q.front(); Q.pop_front();
331                            PabloAST * e2 = Q.front(); Q.pop_front();
332                            Q.push_back(block.createXor(e1, e2));
333                        }
334                    }
335                    stmt->replaceWith(Q.front(), true, true);
336                }
337            }
338        }
339
340        // Determine the source variables of the next "layer" of the AST
341        flat_set<Statement *> nextSet;
342        for (auto u : ordering) {
343            if (in_degree(u, G) == 0) {
344                PabloAST * const var = G[u];
345                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
346                    nextSet.insert(cast<Statement>(var));
347                }
348            } else if (out_degree(u, G) == 0) { // an input may also be the output of some subgraph of G. We don't need to reevaluate it.
349                PabloAST * const var = G[u];
350                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
351                    nextSet.erase(cast<Statement>(var));
352                }
353            }
354        }
355
356        if (nextSet.empty()) {
357            break;
358        }
359
360        terminals.assign(nextSet.begin(), nextSet.end());
361    }
362}
363
364BooleanReassociationPass::BooleanReassociationPass()
365{
366
367}
368
369
370}
Note: See TracBrowser for help on using the repository browser.