source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 4741

Last change on this file since 4741 was 4741, checked in by nmedfort, 4 years ago

More work on the reassociation pass.

File size: 16.1 KB
Line 
1#include "booleanreassociationpass.h"
2#include <pablo/printer_pablos.h>
3#include <boost/container/flat_set.hpp>
4#include <boost/container/flat_map.hpp>
5#include <iostream>
6#include <boost/circular_buffer.hpp>
7#include <queue>
8#include <pablo/builder.hpp>
9#include <boost/graph/adjacency_list.hpp>
10#include <boost/graph/topological_sort.hpp>
11
12using namespace boost;
13using namespace boost::container;
14
15namespace pablo {
16
17bool BooleanReassociationPass::optimize(PabloFunction & function) {
18    BooleanReassociationPass brp;
19    brp.scan(function);
20    return true;
21}
22
23/** ------------------------------------------------------------------------------------------------------------- *
24 * @brief scan
25 ** ------------------------------------------------------------------------------------------------------------- */
26void BooleanReassociationPass::scan(PabloFunction & function) {
27    Terminals terminals;
28    for (unsigned i = 0; i != function.getNumOfResults(); ++i) {
29        terminals.push_back(function.getResult(i));
30    }
31    scan(function.getEntryBlock(), std::move(terminals));
32}
33
34/** ------------------------------------------------------------------------------------------------------------- *
35 * @brief is_power_of_2
36 * @param n an integer
37 ** ------------------------------------------------------------------------------------------------------------- */
38static inline bool is_power_of_2(const size_t n) {
39    return ((n & (n - 1)) == 0);
40}
41
42/** ------------------------------------------------------------------------------------------------------------- *
43 * @brief log2_plus_one
44 ** ------------------------------------------------------------------------------------------------------------- */
45static inline size_t ceil_log2(const size_t n) {
46    return std::log2<size_t>(n) + (is_power_of_2(n) ? 0 : 1);
47}
48
49/** ------------------------------------------------------------------------------------------------------------- *
50 * @brief isACD
51 ** ------------------------------------------------------------------------------------------------------------- */
52static inline bool isaBooleanOperation(const PabloAST * const expr) {
53    switch (expr->getClassTypeId()) {
54        case PabloAST::ClassTypeId::And:
55        case PabloAST::ClassTypeId::Or:
56        case PabloAST::ClassTypeId::Xor:
57            return true;
58        default:
59            return false;
60    }
61}
62
63using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
64using Vertex = Graph::vertex_descriptor;
65
66/** ------------------------------------------------------------------------------------------------------------- *
67 * @brief isCutNecessary
68 ** ------------------------------------------------------------------------------------------------------------- */
69static inline bool isCutNecessary(const Vertex u, const Vertex v, const Graph & G, const std::vector<unsigned> & component) {
70    if (LLVM_UNLIKELY(component[u] != component[v])) {
71        return true;
72    } else if (LLVM_UNLIKELY(out_degree(v, G) && in_degree(u, G) && G[u]->getClassTypeId() != G[v]->getClassTypeId())) {
73        return true;
74    }
75    return false;
76}
77
78/** ------------------------------------------------------------------------------------------------------------- *
79 * @brief isCutNecessary
80 ** ------------------------------------------------------------------------------------------------------------- */
81static bool print(const Graph & G, const std::vector<unsigned> & component, raw_os_ostream & out) {
82    bool hasError = false;
83    out << "digraph G {\n";
84    unsigned i = 0;
85    for (auto u : make_iterator_range(vertices(G))) {
86        out << "u" << u << " [label=\"";
87        out << i++ << " : ";
88        PabloPrinter::print(G[u], out);
89        out << " (" << component[u] << ')';
90        out << "\"";
91        if (isaBooleanOperation(G[u]) && (in_degree(u, G) == 1 || in_degree(u, G) > 2)) {
92            out << " color=red";
93            hasError = true;
94        }
95        out << "];\n";
96    }
97    for (auto e : make_iterator_range(edges(G))) {
98        Vertex u = source(e, G);
99        Vertex v = target(e, G);
100        out << "u" << u << " -> u" << v;
101        if (isCutNecessary(u, v, G, component)) {
102            out << " [color=red]";
103            hasError = true;
104        }
105        out << ";\n";
106    }
107    out << "}\n";
108    out.flush();
109    return hasError;
110}
111
112/** ------------------------------------------------------------------------------------------------------------- *
113 * @brief scan
114 ** ------------------------------------------------------------------------------------------------------------- */
115void BooleanReassociationPass::scan(PabloBlock & block, Terminals && terminals) {
116
117    using Map = std::unordered_map<PabloAST *, Vertex>;
118    using VertexQueue = std::queue<Vertex>;
119    using EdgeQueue = std::queue<std::pair<Vertex, Vertex>>;
120
121    for (Statement * stmt : block) {
122        if (isa<If>(stmt)) {
123            const auto & defs = cast<const If>(stmt)->getDefined();
124            Terminals terminals(defs.begin(), defs.end());
125            scan(cast<If>(stmt)->getBody(), std::move(terminals));
126        } else if (isa<While>(stmt)) {
127            const auto & vars = cast<const While>(stmt)->getVariants();
128            Terminals terminals(vars.begin(), vars.end());
129            scan(cast<While>(stmt)->getBody(), std::move(terminals));
130        }
131    }
132
133    // And, Or and Xor instructions are all associative, commutative and distributive operations. Thus we can
134    // safely rearrange expressions such as "((((a √ b) √ c) √ d) √ e) √ f" into "((a √ b) √ (c √ d)) √ (e √ f)".
135
136    raw_os_ostream out(std::cerr);
137
138    out << "=================================================\n";
139
140    PabloBuilder builder(block);
141
142    for (;;) {
143
144        Graph G;
145        Map M;
146
147        // Generate a graph depicting the relationships between the terminals. If the original terminals
148        // cannot be optimized with this algorithm bypass them in favour of their operands.
149
150        VertexQueue Q;
151
152        for (Statement * const term : terminals) {
153            if (isaBooleanOperation(term)) {
154                if (LLVM_LIKELY(M.count(term) == 0)) {
155                    const auto v = add_vertex(term, G);
156                    M.insert(std::make_pair(term, v));
157                    Q.push(v);
158                }
159            } else {
160                for (unsigned i = 0; i != term->getNumOperands(); ++i) {
161                    PabloAST * const op = term->getOperand(i);
162                    if (LLVM_LIKELY(isa<Statement>(op) && M.count(op) == 0)) {
163                        const auto v = add_vertex(op, G);
164                        M.insert(std::make_pair(op, v));
165                        Q.push(v);
166                    }
167                }
168            }
169
170        }
171
172        for (;;) {
173            const Vertex u = Q.front(); Q.pop();
174            if (isaBooleanOperation(G[u])) {
175                // Scan through the use-def chains to locate any chains of rearrangable expressions and their inputs
176                Statement * stmt = cast<Statement>(G[u]);
177                for (unsigned i = 0; i != 2; ++i) {
178                    PabloAST * op = stmt->getOperand(i);
179                    auto f = M.find(op);
180                    if (f == M.end()) {
181                        const auto v = add_vertex(op, G);
182                        f = M.insert(std::make_pair(op, v)).first;
183                        if (op->getClassTypeId() == stmt->getClassTypeId() && cast<Statement>(op)->getParent() == &block) {
184                            Q.push(v);
185                        }
186                    }
187                    add_edge(f->second, u, G);
188                }
189            }
190            if (Q.empty()) {
191                break;
192            }
193        }
194
195        // Generate a topological ordering for G; if one of our terminals happens to also be a partial computation of
196        // another terminal, we need to make sure we compute it.
197        std::vector<unsigned> ordering;
198        ordering.reserve(num_vertices(G));
199        topological_sort(G, std::back_inserter(ordering));
200        std::vector<unsigned> component(num_vertices(G));
201
202        for (;;) {
203
204            // Mark which computation component these vertices are in based on their topological (occurence) order.
205            unsigned count = 0;
206            for (auto u : ordering) {
207                unsigned id = 0;
208                // If this one of our original terminals or a sink in G, it is the root of a new component.
209                if (out_degree(u, G) == 0) {
210                    id = ++count;
211                } else {
212                    for (auto e : make_iterator_range(out_edges(u, G))) {
213                        id = std::max(id, component[target(e, G)]);
214                    }
215                }
216                assert (id && "Topological ordering failed!");
217                component[u] = id;
218            }
219
220            if (count > 1) {
221                // Cut the graph wherever a computation crosses a component
222                EdgeQueue Q;
223                graph_traits<Graph>::edge_iterator ei, ei_end;
224
225                for (std::tie(ei, ei_end) = edges(G); ei != ei_end; ) {
226                    const Graph::edge_descriptor e = *ei++;
227                    const Vertex u = source(e, G);
228                    const Vertex v = target(e, G);
229                    if (LLVM_UNLIKELY(isCutNecessary(u, v, G, component))) {
230                        Q.push(std::make_pair(u, v));
231                        remove_edge(u, v, G);
232                    }
233                }
234
235                // If no edges cross a component, we're done.
236                if (Q.empty()) {
237                    break; // outer for loop
238                }
239
240                for (;;) {
241
242                    Vertex u, v;
243                    std::tie(u, v) = Q.front(); Q.pop();
244
245                    // The vertex belonging to a component with a greater number must come "earlier"
246                    // in the program. By replicating it, this ensures it's computed as an output of
247                    // one component and used as an input of another.
248
249                    if (component[u] < component[v]) {
250                        std::swap(u, v);
251                    }
252
253                    // Replicate u and fix the ordering and component vectors to reflect the change in G.
254                    Vertex w = add_vertex(G[u], G);
255                    ordering.insert(std::find(ordering.begin(), ordering.end(), u), w);
256                    assert (component.size() == w);
257                    component.push_back(component[v]);
258                    add_edge(w, v, G);
259
260                    // However, after we do so, we need to make sure the original source vertex will be a
261                    // sink in G unless it is also an input variable (in which case we'd simply end up with
262                    // extraneous isolated vertex. Otherwise, we need to make further cuts and replications.
263
264                    if (in_degree(u, G) != 0) {
265                        for (auto e : make_iterator_range(out_edges(u, G))) {
266                            Q.push(std::make_pair(source(e, G), target(e, G)));
267                        }
268                        clear_out_edges(u, G);
269                    }
270
271                    if (Q.empty()) {
272                        break;
273                    }
274
275                }
276                continue; // outer for loop
277            }
278            break; // outer for loop
279        }
280
281        if (print(G, component, out)) {
282            PabloPrinter::print(block.statements(), out);
283            out.flush();
284            throw std::runtime_error("Illegal graph generated!");
285        }
286
287        // Scan through the graph in reverse order so that we find all subexpressions first
288        for (auto ui = ordering.begin(); ui != ordering.end(); ++ui) {
289            const Vertex u = *ui;
290            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
291
292                out << " -- checking component " << component[u] << " : ";
293                PabloPrinter::print(G[u], out);
294                out << "\n";
295                out.flush();
296
297                // While we're collecting our variable set V, keep track of the maximum path length L.
298                // If L == ceil(log2(|V|)), then this portion of the AST is already optimal.
299
300                flat_map<Vertex, unsigned> L;
301                flat_set<PabloAST *> V;
302                VertexQueue Q;
303
304                Vertex v = u;
305                unsigned maxPathLength = 0;
306                L.emplace(v, 0);
307                for (;;) {
308                    assert (isa<Statement>(G[v]) ? cast<Statement>(G[v])->getParent() != nullptr : true);
309                    if (in_degree(v, G) == 0) {
310                        V.insert(G[v]);
311                    } else {
312                        const auto l = L[v] + 1;
313                        maxPathLength = std::max(maxPathLength, l);
314                        for (auto e : make_iterator_range(in_edges(v, G))) {
315                            const Vertex w = source(e, G);
316                            auto f = L.find(w);
317                            if (LLVM_LIKELY(f == L.end())) {
318                                L.emplace(w, l);
319                            } else {
320                                f->second = std::max(f->second, l);
321                            }
322                            Q.push(w);
323                        }
324                    }
325                    if (Q.empty()) {
326                        break;
327                    }
328                    v = Q.front();
329                    Q.pop();
330                }
331
332                // Should we optimize this portion of the AST?
333                if (maxPathLength > ceil_log2(V.size())) {
334
335                    out << " -- rewriting component " << component[u] << " : |P|=" << maxPathLength << ", |V|=" << V.size() << " (" << ceil_log2(V.size()) << ")\n";
336                    out.flush();
337
338                    circular_buffer<PabloAST *> Q(V.size());
339                    for (PabloAST * var : V) {
340                        Q.push_back(var);
341                    }
342                    Statement * stmt = cast<Statement>(G[u]);
343
344                    block.setInsertPoint(stmt->getPrevNode());
345                    if (isa<And>(stmt)) {
346                        while (Q.size() > 1) {
347                            PabloAST * e1 = Q.front(); Q.pop_front();
348                            PabloAST * e2 = Q.front(); Q.pop_front();
349                            Q.push_back(builder.createAnd(e1, e2));
350                        }
351                    } else if (isa<Or>(stmt)) {
352                        while (Q.size() > 1) {
353                            PabloAST * e1 = Q.front(); Q.pop_front();
354                            PabloAST * e2 = Q.front(); Q.pop_front();
355                            Q.push_back(builder.createOr(e1, e2));
356                        }
357                    } else { // if (isa<Xor>(stmt)) {
358                        while (Q.size() > 1) {
359                            PabloAST * e1 = Q.front(); Q.pop_front();
360                            PabloAST * e2 = Q.front(); Q.pop_front();
361                            Q.push_back(builder.createXor(e1, e2));
362                        }
363                    }
364                    stmt->replaceWith(Q.front(), true, true);
365                }
366
367                for (auto uj = ui; ++uj != ordering.end(); ) {
368                    assert (isa<Statement>(G[*uj]) ? cast<Statement>(G[*uj])->getParent() != nullptr : true);
369                }
370
371            }
372        }
373
374        // Determine the source variables of the next "layer" of the AST
375        flat_set<Statement *> nextSet;
376        for (auto u : ordering) {
377            if (in_degree(u, G) == 0) {
378                PabloAST * const var = G[u];
379                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
380                    nextSet.insert(cast<Statement>(var));
381                }
382            } else if (out_degree(u, G) == 0) { // an input may also be the output of some subgraph of G. We don't need to reevaluate it.
383                PabloAST * const var = G[u];
384                if (LLVM_LIKELY(isa<Statement>(var) && cast<Statement>(var)->getParent() == &block)) {
385                    nextSet.erase(cast<Statement>(var));
386                }
387            }
388        }
389
390        if (nextSet.empty()) {
391            break;
392        }
393
394        terminals.assign(nextSet.begin(), nextSet.end());
395
396        out << "-------------------------------------------------\n";
397    }
398}
399
400BooleanReassociationPass::BooleanReassociationPass()
401{
402
403}
404
405
406}
Note: See TracBrowser for help on using the repository browser.