source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 4747

Last change on this file since 4747 was 4747, checked in by nmedfort, 4 years ago

Temporary check in.

File size: 14.9 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <pablo/builder.hpp>
6#include <boost/graph/adjacency_list.hpp>
7#include <boost/graph/topological_sort.hpp>
8#include <queue>
9
10using namespace boost;
11using namespace boost::container;
12
13namespace pablo {
14
15bool BooleanReassociationPass::optimize(PabloFunction & function) {
16    BooleanReassociationPass brp;
17    brp.scan(function);
18    return true;
19}
20
21/** ------------------------------------------------------------------------------------------------------------- *
22 * @brief scan
23 ** ------------------------------------------------------------------------------------------------------------- */
24void BooleanReassociationPass::scan(PabloFunction & function) {
25    Terminals terminals;
26    for (unsigned i = 0; i != function.getNumOfResults(); ++i) {
27        terminals.push_back(function.getResult(i));
28    }
29    scan(function.getEntryBlock(), std::move(terminals));
30}
31
32/** ------------------------------------------------------------------------------------------------------------- *
33 * @brief is_power_of_2
34 * @param n an integer
35 ** ------------------------------------------------------------------------------------------------------------- */
36static inline bool is_power_of_2(const size_t n) {
37    return ((n & (n - 1)) == 0);
38}
39
40/** ------------------------------------------------------------------------------------------------------------- *
41 * @brief log2_plus_one
42 ** ------------------------------------------------------------------------------------------------------------- */
43static inline size_t ceil_log2(const size_t n) {
44    return std::log2<size_t>(n) + (is_power_of_2(n) ? 0 : 1);
45}
46
47/** ------------------------------------------------------------------------------------------------------------- *
48 * @brief isACD
49 ** ------------------------------------------------------------------------------------------------------------- */
50static inline bool isaBooleanOperation(const PabloAST * const expr) {
51    assert (expr);
52    switch (expr->getClassTypeId()) {
53        case PabloAST::ClassTypeId::And:
54        case PabloAST::ClassTypeId::Or:
55        case PabloAST::ClassTypeId::Xor:
56            return true;
57        default:
58            return false;
59    }
60}
61
62using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
63using Vertex = Graph::vertex_descriptor;
64
65/** ------------------------------------------------------------------------------------------------------------- *
66 * @brief isCutNecessary
67 ** ------------------------------------------------------------------------------------------------------------- */
68static inline bool isCutNecessary(const Vertex u, const Vertex v, const Graph & G, const std::vector<unsigned> & component) {
69    // Either this edge crosses a component boundary or the operations performed by the vertices differs, we need to cut
70    // the graph here and generate two partial equations.
71    if (LLVM_UNLIKELY(component[u] != component[v])) {
72        return true;
73    } else if (LLVM_UNLIKELY((in_degree(u, G) != 0) && (G[u]->getClassTypeId() != G[v]->getClassTypeId()))) {
74        return true;
75    }
76    return false;
77}
78
79/** ------------------------------------------------------------------------------------------------------------- *
80 * @brief scan
81 ** ------------------------------------------------------------------------------------------------------------- */
82void BooleanReassociationPass::scan(PabloBlock & block, Terminals && terminals) {
83
84    using Map = std::unordered_map<PabloAST *, Vertex>;
85    using VertexQueue = std::queue<Vertex>;
86    using EdgeQueue = std::queue<std::pair<Vertex, Vertex>>;
87
88    PabloBuilder builder(block);
89
90    for (Statement * stmt : block) {
91        if (isa<If>(stmt)) {
92            const auto & defs = cast<const If>(stmt)->getDefined();
93            Terminals terminals(defs.begin(), defs.end());
94            scan(cast<If>(stmt)->getBody(), std::move(terminals));
95        } else if (isa<While>(stmt)) {
96            const auto & vars = cast<const While>(stmt)->getVariants();
97            Terminals terminals(vars.begin(), vars.end());
98            scan(cast<While>(stmt)->getBody(), std::move(terminals));
99        } else {
100            builder.record(stmt);
101        }
102
103    }
104
105    // And, Or and Xor instructions are all associative, commutative and distributive operations. Thus we can
106    // safely rearrange expressions such as "((((a √ b) √ c) √ d) √ e) √ f" into "((a √ b) √ (c √ d)) √ (e √ f)".
107
108    bool modifiedAST = false;
109
110    for (;;) {
111
112        Graph G;
113        Map M;
114
115        // Generate a graph depicting the relationships between the terminals. If the original terminals
116        // cannot be optimized with this algorithm bypass them in favour of their operands.
117
118        VertexQueue Q;
119        unsigned count = 0;
120        for (Statement * const term : terminals) {
121            assert (term);
122            assert (Q.size() == count);
123            if (isaBooleanOperation(term)) {
124                if (LLVM_LIKELY(M.count(term) == 0)) {                   
125                    const auto v = add_vertex(term, G);
126                    assert (v < num_vertices(G));
127                    M.insert(std::make_pair(term, v));
128                    Q.push(v);
129                    ++count;
130                }
131            } else {
132                for (unsigned i = 0; i != term->getNumOperands(); ++i) {
133                    PabloAST * const op = term->getOperand(i);
134                    assert (op);
135                    if (LLVM_LIKELY(isa<Statement>(op) && M.count(op) == 0)) {
136                        const auto v = add_vertex(op, G);
137                        assert (v < num_vertices(G));
138                        M.insert(std::make_pair(op, v));
139                        Q.push(v);
140                        ++count;
141                    }
142                }
143            }           
144        }
145
146        for (;;) {
147            assert (Q.size() == count);
148            const Vertex u = Q.front(); Q.pop();
149            --count;
150            assert (u < num_vertices(G));
151            if (isaBooleanOperation(G[u])) {
152                // Scan through the use-def chains to locate any chains of rearrangable expressions and their inputs
153                Statement * stmt = cast<Statement>(G[u]);
154                for (unsigned i = 0; i != 2; ++i) {
155                    PabloAST * op = stmt->getOperand(i);
156                    auto f = M.find(op);
157                    if (f == M.end()) {
158                        const auto v = add_vertex(op, G);
159                        assert (v < num_vertices(G));
160                        f = M.insert(std::make_pair(op, v)).first;
161                        if (op->getClassTypeId() == stmt->getClassTypeId() && cast<Statement>(op)->getParent() == &block) {
162                            Q.push(v);
163                            ++count;
164                        }
165                    }
166                    add_edge(f->second, u, G);
167                }
168            }
169            if (Q.empty()) {
170                break;
171            }
172        }
173
174        // Generate a topological ordering for G; if one of our terminals happens to also be a partial computation of
175        // another terminal, we need to make sure we compute it as an independent subexpression.
176        std::vector<unsigned> ordering;
177        ordering.reserve(num_vertices(G));
178        topological_sort(G, std::back_inserter(ordering));
179        std::vector<unsigned> component(num_vertices(G));
180
181        for (;;) {
182
183            // Mark which computation component these vertices are in based on their topological (occurence) order.
184            unsigned components = 0;
185            for (auto u : ordering) {
186                unsigned id = 0;
187                // If this is a sink in G, it is the root of a new component.
188                if (out_degree(u, G) == 0) {
189                    id = ++components;
190                } else {
191                    for (auto e : make_iterator_range(out_edges(u, G))) {
192                        id = std::max(id, component[target(e, G)]);
193                    }
194                }
195                assert (id && "Topological ordering failed!");
196                component[u] = id;
197            }
198
199            // Cut the graph wherever a computation crosses a component or whenever we need to cut the graph because
200            // the instructions corresponding to the pair of nodes differs.
201            EdgeQueue Q;
202            graph_traits<Graph>::edge_iterator ei, ei_end;
203            for (std::tie(ei, ei_end) = edges(G); ei != ei_end; ) {
204                const Graph::edge_descriptor e = *ei++;
205                const Vertex u = source(e, G);
206                const Vertex v = target(e, G);
207                if (LLVM_UNLIKELY(isCutNecessary(u, v, G, component))) {
208                    Q.push(std::make_pair(u, v));
209                    remove_edge(u, v, G);
210                }
211            }
212
213            // If no cuts are necessary, we're done.
214            if (Q.empty()) {
215                break;
216            }
217
218            for (;;) {
219
220                Vertex u, v;
221                std::tie(u, v) = Q.front(); Q.pop();
222
223                // The vertex belonging to a component with a greater number must come "earlier"
224                // in the program. By replicating it, this ensures it's computed as an output of
225                // one component and used as an input of another.
226
227                if (component[u] < component[v]) {
228                    std::swap(u, v);
229                }
230
231                // Replicate u and fix the ordering and component vectors to reflect the change in G.
232                Vertex w = add_vertex(G[u], G);
233                ordering.insert(std::find(ordering.begin(), ordering.end(), u), w);
234                assert (component.size() == w);
235                component.push_back(component[v]);
236                add_edge(w, v, G);
237
238                // However, after we do so, we need to make sure the original source vertex will be a
239                // sink in G unless it is also an input variable (in which case we'd simply end up with
240                // extraneous isolated vertex. Otherwise, we need to make further cuts and replications.
241
242                if (in_degree(u, G) != 0) {
243                    for (auto e : make_iterator_range(out_edges(u, G))) {
244                        Q.push(std::make_pair(source(e, G), target(e, G)));
245                    }
246                    clear_out_edges(u, G);
247                }
248
249                if (Q.empty()) {
250                    break;
251                }
252
253            }
254        }
255
256        // Scan through the graph in reverse order so that we find all subexpressions first
257        for (const Vertex u : ordering) {
258            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
259
260                // While we're collecting our variable set V, keep track of the maximum path length L.
261                // If L == ceil(log2(|V|)), then this portion of the AST is already optimal.
262
263                flat_map<Vertex, unsigned> L;
264                flat_set<PabloAST *> V;
265                VertexQueue Q;
266
267                Vertex v = u;
268                unsigned maxPathLength = 0;
269                L.emplace(v, 0);
270                for (;;) {                   
271                    if (in_degree(v, G) == 0) {
272                        V.insert(G[v]);
273                    } else {
274                        const auto l = L[v] + 1;
275                        maxPathLength = std::max(maxPathLength, l);
276                        for (auto e : make_iterator_range(in_edges(v, G))) {
277                            const Vertex w = source(e, G);
278                            auto f = L.find(w);
279                            if (LLVM_LIKELY(f == L.end())) {
280                                L.emplace(w, l);
281                            } else {
282                                f->second = std::max(f->second, l);
283                            }
284                            Q.push(w);
285                        }
286                    }
287                    if (Q.empty()) {
288                        break;
289                    }
290                    v = Q.front();
291                    Q.pop();
292                }
293
294                // Should we optimize this portion of the AST?
295                if (maxPathLength > ceil_log2(V.size())) {
296
297                    Statement * stmt = cast<Statement>(G[u]);
298
299                    circular_buffer<PabloAST *> Q(V.size());
300                    for (PabloAST * var : V) {
301                        Q.push_back(var);
302                    }
303
304                    block.setInsertPoint(stmt->getPrevNode());
305                    if (isa<And>(stmt)) {
306                        while (Q.size() > 1) {
307                            PabloAST * e1 = Q.front(); Q.pop_front();
308                            PabloAST * e2 = Q.front(); Q.pop_front();
309                            Q.push_back(builder.createAnd(e1, e2));
310                        }
311                    } else if (isa<Or>(stmt)) {
312                        while (Q.size() > 1) {
313                            PabloAST * e1 = Q.front(); Q.pop_front();
314                            PabloAST * e2 = Q.front(); Q.pop_front();
315                            Q.push_back(builder.createOr(e1, e2));
316                        }
317                    } else { assert(isa<Xor>(stmt));
318                        while (Q.size() > 1) {
319                            PabloAST * e1 = Q.front(); Q.pop_front();
320                            PabloAST * e2 = Q.front(); Q.pop_front();
321                            Q.push_back(builder.createXor(e1, e2));
322                        }
323                    }
324                    stmt->replaceWith(Q.front(), true, false);
325                    modifiedAST = true;
326                }
327            }
328        }
329
330        // Determine the source variables of the next "layer" of the AST
331        flat_set<Statement *> nextSet;
332        for (auto u : ordering) {
333            if (in_degree(u, G) == 0) {
334                PabloAST * const var = G[u];
335                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
336                    nextSet.insert(cast<Statement>(var));
337                }
338            } else if (out_degree(u, G) == 0) { // an input may also be the output of some subgraph of G. We don't need to reevaluate it.
339                PabloAST * const var = G[u];
340                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
341                    nextSet.erase(cast<Statement>(var));
342                }
343            }
344        }
345
346        if (nextSet.empty()) {
347            break;
348        }
349
350        terminals.assign(nextSet.begin(), nextSet.end());
351    }
352
353    // If we modified the AST, we likely left dead code in it. Go through and remove any from this block.
354    if (modifiedAST) {
355        Statement * stmt = block.front();
356        while (stmt) {
357            if (stmt->getNumUses() == 0 && !(isa<If>(stmt) || isa<While>(stmt))){
358                stmt = stmt->eraseFromParent(true);
359            } else {
360                stmt = stmt->getNextNode();
361            }
362        }
363    }
364}
365
366BooleanReassociationPass::BooleanReassociationPass()
367{
368
369}
370
371
372}
Note: See TracBrowser for help on using the repository browser.