source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 4752

Last change on this file since 4752 was 4752, checked in by nmedfort, 4 years ago

More work on the reassociation pass + few additional Simplification tests

File size: 25.3 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <pablo/optimizers/pablo_simplifier.hpp>
8#include <queue>
9#include <iostream>
10#include <pablo/printer_pablos.h>
11
12
13using namespace boost;
14using namespace boost::container;
15
16namespace pablo {
17
18using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
19using Vertex = Graph::vertex_descriptor;
20using VertexQueue = circular_buffer<Vertex>;
21using Map = std::unordered_map<PabloAST *, Vertex>;
22using EdgeQueue = std::queue<std::pair<Vertex, Vertex>>;
23
24static void summarizeAST(PabloBlock & block, std::vector<Statement *> && terminals, Graph & G);
25
26/** ------------------------------------------------------------------------------------------------------------- *
27 * @brief optimize
28 ** ------------------------------------------------------------------------------------------------------------- */
29bool BooleanReassociationPass::optimize(PabloFunction & function) {
30    BooleanReassociationPass brp;
31//    raw_os_ostream out(std::cerr);
32//    out << "BEFORE:\n\n";
33//    PabloPrinter::print(function.getEntryBlock().statements(), out);
34    brp.scan(function);
35    Simplifier::optimize(function);
36
37//    out << "\n\nAFTER:\n\n";
38//    PabloPrinter::print(function.getEntryBlock().statements(), out);
39    return true;
40}
41
42/** ------------------------------------------------------------------------------------------------------------- *
43 * @brief scan
44 ** ------------------------------------------------------------------------------------------------------------- */
45void BooleanReassociationPass::scan(PabloFunction & function) {
46    std::vector<Statement *> terminals;
47    for (unsigned i = 0; i != function.getNumOfResults(); ++i) {
48        terminals.push_back(function.getResult(i));
49    }
50    scan(function.getEntryBlock(), std::move(terminals));
51}
52
53/** ------------------------------------------------------------------------------------------------------------- *
54 * @brief scan
55 ** ------------------------------------------------------------------------------------------------------------- */
56void BooleanReassociationPass::scan(PabloBlock & block, std::vector<Statement *> && terminals) {
57
58    processScope(block, std::move(terminals));
59
60    for (Statement * stmt : block) {
61        if (isa<If>(stmt)) {
62            const auto & defs = cast<const If>(stmt)->getDefined();
63            std::vector<Statement *> terminals(defs.begin(), defs.end());
64            scan(cast<If>(stmt)->getBody(), std::move(terminals));
65        } else if (isa<While>(stmt)) {
66            const auto & vars = cast<const While>(stmt)->getVariants();
67            std::vector<Statement *> terminals(vars.begin(), vars.end());
68            scan(cast<While>(stmt)->getBody(), std::move(terminals));
69        }
70    }
71
72}
73
74/** ------------------------------------------------------------------------------------------------------------- *
75 * @brief is_power_of_2
76 * @param n an integer
77 ** ------------------------------------------------------------------------------------------------------------- */
78static inline bool is_power_of_2(const size_t n) {
79    return ((n & (n - 1)) == 0);
80}
81
82/** ------------------------------------------------------------------------------------------------------------- *
83 * @brief log2_plus_one
84 ** ------------------------------------------------------------------------------------------------------------- */
85static inline size_t ceil_log2(const size_t n) {
86    return std::log2<size_t>(n) + (is_power_of_2(n) ? 0 : 1);
87}
88
89/** ------------------------------------------------------------------------------------------------------------- *
90 * @brief isOptimizable
91 *
92 * And, Or and Xor instructions are all associative, commutative and distributive operations. Thus we can
93 * safely rearrange expressions such as "((((a √ b) √ c) √ d) √ e) √ f" into "((a √ b) √ (c √ d)) √ (e √ f)".
94 ** ------------------------------------------------------------------------------------------------------------- */
95static inline bool isOptimizable(const PabloAST * const expr) {
96    assert (expr);
97    switch (expr->getClassTypeId()) {
98        case PabloAST::ClassTypeId::And:
99        case PabloAST::ClassTypeId::Or:
100        case PabloAST::ClassTypeId::Xor:
101            return true;
102        default:
103            return false;
104    }
105}
106
107/** ------------------------------------------------------------------------------------------------------------- *
108 * @brief inCurrentBlock
109 ** ------------------------------------------------------------------------------------------------------------- */
110static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock & block) {
111    return stmt->getParent() == &block;
112}
113
114static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock & block) {
115    return isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block);
116}
117
118/** ------------------------------------------------------------------------------------------------------------- *
119 * @brief isAnyUserNotInCurrentBlock
120 ** ------------------------------------------------------------------------------------------------------------- */
121static inline bool isAnyUserNotInCurrentBlock(const PabloAST * expr, const PabloBlock & block) {
122    for (PabloAST * user : expr->users()) {
123        if (!inCurrentBlock(cast<Statement>(user), block)) {
124            return true;
125        }
126    }
127    return false;
128}
129
130/** ------------------------------------------------------------------------------------------------------------- *
131 * @brief isCutNecessary
132 ** ------------------------------------------------------------------------------------------------------------- */
133static inline bool isCutNecessary(const Vertex u, const Vertex v, const Graph & G, const std::vector<unsigned> & component) {
134    // Either this edge crosses a component boundary or the operations performed by the vertices differs, we need to cut
135    // the graph here and generate two partial equations.
136    if (LLVM_UNLIKELY(component[u] != component[v])) {
137        return true;
138    } else if (LLVM_UNLIKELY((in_degree(u, G) != 0) && (G[u]->getClassTypeId() != G[v]->getClassTypeId()))) {
139        return true;
140    }
141    return false;
142}
143
144/** ------------------------------------------------------------------------------------------------------------- *
145 * @brief push
146 ** ------------------------------------------------------------------------------------------------------------- */
147static inline void push(const Vertex u, VertexQueue & Q) {
148    if (LLVM_UNLIKELY(Q.full())) {
149        Q.set_capacity(Q.capacity() * 2);
150    }
151    Q.push_back(u);
152    assert (Q.back() == u);
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief pop
157 ** ------------------------------------------------------------------------------------------------------------- */
158static inline Vertex pop(VertexQueue & Q) {
159    assert (!Q.empty() && "Popping an empty vertex queue");
160    const Vertex u = Q.front();
161    Q.pop_front();
162    return u;
163}
164
165/** ------------------------------------------------------------------------------------------------------------- *
166 * @brief getVertex
167 ** ------------------------------------------------------------------------------------------------------------- */
168static inline Vertex getVertex(PabloAST * expr, Graph & G, Map & M) {
169    const auto f = M.find(expr);
170    if (f != M.end()) {
171        return f->second;
172    }
173    const auto u = add_vertex(expr, G);
174    M.insert(std::make_pair(expr, u));
175    return u;
176}
177
178/** ------------------------------------------------------------------------------------------------------------- *
179 * @brief createTree
180 ** ------------------------------------------------------------------------------------------------------------- */
181static PabloAST * createTree(PabloBlock & block, const PabloAST::ClassTypeId typeId, circular_buffer<PabloAST *> & Q) {
182    while (Q.size() > 1) {
183        PabloAST * e1 = Q.front(); Q.pop_front();
184        PabloAST * e2 = Q.front(); Q.pop_front();
185        PabloAST * expr = nullptr;
186        switch (typeId) {
187            case PabloAST::ClassTypeId::And:
188                expr = block.createAnd(e1, e2); break;
189            case PabloAST::ClassTypeId::Or:
190                expr = block.createOr(e1, e2); break;
191            case PabloAST::ClassTypeId::Xor:
192                expr = block.createXor(e1, e2); break;
193            default: break;
194        }
195        Q.push_back(expr);
196    }
197    PabloAST * r = Q.front();
198    Q.clear();
199    return r;
200}
201
202/** ------------------------------------------------------------------------------------------------------------- *
203 * @brief applyDistributionLaw
204 ** ------------------------------------------------------------------------------------------------------------- */
205static bool applyDistributionLaw(PabloBlock & block, const PabloAST::ClassTypeId typeId, flat_set<PabloAST *> & vars) {
206    circular_buffer<PabloAST *> Q0(vars.size());
207    circular_buffer<PabloAST *> Q1(vars.size());
208    std::vector<PabloAST *> distributedVars;
209
210    for (auto vi = vars.begin(); vi != vars.end(); ) {
211        PabloAST * const e0 = *vi;
212
213        if (e0->getClassTypeId() == typeId) {
214            Statement * const s0 = cast<Statement>(e0);
215
216            for (auto vj = vi + 1; vj != vars.end(); ) {
217                PabloAST * const e1 = *vj;
218
219                if (e1->getClassTypeId() == typeId) {
220                    Statement * const s1 = cast<Statement>(e1);
221                    bool distributed = false;
222
223                    if (s0->getOperand(0) == s1->getOperand(0)) {
224                        Q0.push_back(s1->getOperand(1));
225                        distributed = true;
226                    } else if (s0->getOperand(0) == s1->getOperand(1)) {
227                        Q0.push_back(s1->getOperand(0));
228                        distributed = true;
229                    }
230
231                    if (s0->getOperand(1) == s1->getOperand(0)) {
232                        Q1.push_back(s1->getOperand(1));
233                        distributed = true;
234                    } else if (s0->getOperand(1) == s1->getOperand(1)) {
235                        Q1.push_back(s1->getOperand(0));
236                        distributed = true;
237                    }
238
239                    if (distributed) {
240                        vj = vars.erase(vj);
241                        continue;
242                    }
243                }
244
245                ++vj;
246            }
247
248            if (LLVM_UNLIKELY(Q0.size() > 0 || Q1.size() > 0)) {
249                const PabloAST::ClassTypeId innerTypeId =
250                        (typeId == PabloAST::ClassTypeId::Or) ? PabloAST::ClassTypeId::And : PabloAST::ClassTypeId::Or;
251
252                vi = vars.erase(vi);
253                if (Q0.size() > 0) {
254                    Q0.push_back(s0->getOperand(1));
255                    PabloAST * distributed = createTree(block, innerTypeId, Q0);
256                    switch (typeId) {
257                        case PabloAST::ClassTypeId::And:
258                            distributed = block.createAnd(s0->getOperand(0), distributed); break;
259                        case PabloAST::ClassTypeId::Or:
260                            distributed = block.createOr(s0->getOperand(0), distributed); break;
261                        default: break;
262                    }
263                    distributedVars.push_back(distributed);
264                }
265                if (Q1.size() > 0) {
266                    Q1.push_front(s0->getOperand(0));
267                    PabloAST * distributed = createTree(block, innerTypeId, Q1);
268                    switch (typeId) {
269                        case PabloAST::ClassTypeId::And:
270                            distributed = block.createAnd(s0->getOperand(1), distributed); break;
271                        case PabloAST::ClassTypeId::Or:
272                            distributed = block.createOr(s0->getOperand(1), distributed); break;
273                        default: break;
274                    }
275                    distributedVars.push_back(distributed);
276                }
277                continue;
278            }
279        }
280        ++vi;
281    }
282    if (distributedVars.empty()) {
283        return false;
284    }
285    for (PabloAST * var : distributedVars) {
286        vars.insert(var);
287    }
288    return true;
289}
290
291/** ------------------------------------------------------------------------------------------------------------- *
292 * @brief processScope
293 ** ------------------------------------------------------------------------------------------------------------- */
294void BooleanReassociationPass::processScope(PabloBlock & block, std::vector<Statement *> && terminals) {
295
296    Graph G;
297    summarizeAST(block, std::move(terminals), G);
298
299    raw_os_ostream out(std::cerr);
300    out << "digraph G {\n";
301    for (auto u : make_iterator_range(vertices(G))) {
302        out << "v" << u << " [label=\"";
303        PabloAST * expr = G[u];
304        if (isa<Statement>(expr)) {
305            if (LLVM_UNLIKELY(isa<If>(expr))) {
306                out << "if ";
307                PabloPrinter::print(cast<If>(expr)->getOperand(0), out);
308                out << ":";
309            } else if (LLVM_UNLIKELY(isa<While>(expr))) {
310                out << "while ";
311                PabloPrinter::print(cast<While>(expr)->getOperand(0), out);
312                out << ":";
313            } else {
314                PabloPrinter::print(cast<Statement>(expr), "", out);
315            }
316        } else {
317            PabloPrinter::print(expr, out);
318        }
319        out << "\"";
320        if (!inCurrentBlock(expr, block)) {
321            out << " style=dashed";
322        }
323        out << "];\n";
324    }
325    for (auto e : make_iterator_range(edges(G))) {
326        out << "v" << source(e, G) << " -> v" << target(e, G) << ";\n";
327    }
328
329    out << "{ rank=same;";
330    for (auto u : make_iterator_range(vertices(G))) {
331        if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
332            out << " v" << u;
333        }
334    }
335    out << "}\n";
336
337    out << "{ rank=same;";
338    for (auto u : make_iterator_range(vertices(G))) {
339        if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
340            out << " v" << u;
341        }
342    }
343    out << "}\n";
344
345    out << "}\n\n";
346    out.flush();
347
348
349
350
351
352}
353
354/** ------------------------------------------------------------------------------------------------------------- *
355 * @brief summarizeAST
356 *
357 * This function scans through a basic block (starting by its terminals) and computes a DAG in which any sequences
358 * of AND, OR or XOR functions are "flattened" and allowed to have any number of inputs. This allows us to
359 * reassociate them in the most efficient way possible.
360 ** ------------------------------------------------------------------------------------------------------------- */
361static void summarizeAST(PabloBlock & block, std::vector<Statement *> && terminals, Graph & G) {
362
363    Map M;
364    VertexQueue Q(128);
365    EdgeQueue E;
366
367    for (;;) {
368
369        Graph Gk;
370        Map Mk;
371
372        // Generate a graph depicting the relationships between the terminals. If the original terminals
373        // cannot be optimized with this algorithm bypass them in favour of their operands. If those cannot
374        // be optimized, they'll be left as the initial terminals for the next "layer" of the AST.
375
376        for (Statement * term : terminals) {
377            if (LLVM_LIKELY(Mk.count(term) == 0)) {
378                // add or find this terminal in our global graph
379                Vertex x = getVertex(term, G, M);
380                if (inCurrentBlock(term, block)) {
381                    if (isOptimizable(term)) {
382                        const Vertex u = add_vertex(term, Gk);
383                        Mk.insert(std::make_pair(term, u));
384                        push(u, Q);
385                        continue;
386                    }
387                } else if (isa<Assign>(term) || isa<Next>(term)) {
388                    // If this is an Assign (Next) node whose operand does not originate from the current block
389                    // then check to see if there is an If (While) node that does.
390                    Statement * branch = nullptr;
391                    if (isa<Assign>(term)) {
392                        for (PabloAST * user : term->users()) {
393                            if (isa<If>(user)) {
394                                const If * ifNode = cast<If>(user);
395                                if (inCurrentBlock(ifNode, block)) {
396                                    const auto & defs = ifNode->getDefined();
397                                    if (LLVM_LIKELY(std::find(defs.begin(), defs.end(), cast<Assign>(term)) != defs.end())) {
398                                        branch = cast<Statement>(user);
399                                        break;
400                                    }
401                                }
402                            }
403                        }
404                    } else { // if (isa<Next>(term))
405                        for (PabloAST * user : term->users()) {
406                            if (isa<While>(user)) {
407                                const While * whileNode = cast<While>(user);
408                                if (inCurrentBlock(whileNode, block)) {
409                                    const auto & vars = whileNode->getVariants();
410                                    if (LLVM_LIKELY(std::find(vars.begin(), vars.end(), cast<Next>(term)) != vars.end())) {
411                                        branch = cast<Statement>(user);
412                                        break;
413                                    }
414                                }
415                            }
416                        }
417                    }
418
419                    // If we didn't find a branch, then the Assign (Next) node must have come from a preceeding
420                    // block. Just skip it for now.
421                    if (branch == nullptr) {
422                        continue;
423                    }
424
425                    // Otherwise add the branch to G and test its operands rather than the original terminal
426                    const Vertex z = getVertex(branch, G, M);
427                    add_edge(z, x, G);
428                    x = z;
429                    term = branch;
430                }
431
432                for (unsigned i = 0; i != term->getNumOperands(); ++i) {
433                    PabloAST * const op = term->getOperand(i);
434                    if (LLVM_LIKELY(inCurrentBlock(op, block))) {
435                        const Vertex y = getVertex(op, G, M);
436                        add_edge(y, x, G);
437                        if (LLVM_LIKELY(Mk.count(op) == 0)) {
438                            const Vertex v = add_vertex(op, Gk);
439                            Mk.insert(std::make_pair(op, v));
440                            push(v, Q);
441                        }
442                    }
443                }
444            }
445        }
446
447        if (LLVM_UNLIKELY(Q.empty())) {
448            break;
449        }
450
451        for (;;) {
452            const Vertex u = pop(Q);
453            if (isOptimizable(Gk[u])) {
454                Statement * stmt = cast<Statement>(Gk[u]);
455                if (isAnyUserNotInCurrentBlock(stmt, block)) {
456                    const Vertex v = add_vertex(block.createZeroes(), Gk);
457                    add_edge(u, v, Gk);
458                }
459                // Scan through the use-def chains to locate any chains of rearrangable expressions and their inputs
460                for (unsigned i = 0; i != 2; ++i) {
461                    PabloAST * op = stmt->getOperand(i);
462                    auto f = Mk.find(op);
463                    if (f == Mk.end()) {
464                        const Vertex v = add_vertex(op, Gk);
465                        f = Mk.insert(std::make_pair(op, v)).first;
466                        if (op->getClassTypeId() == stmt->getClassTypeId() && inCurrentBlock(cast<Statement>(op), block)) {
467                            push(v, Q);
468                        }
469                    }
470                    add_edge(f->second, u, Gk);
471                }
472            }
473            if (Q.empty()) {
474                break;
475            }
476        }
477
478        // Generate a topological ordering for G; if one of our terminals happens to also be a partial computation of
479        // another terminal, we need to make sure we compute it as an independent subexpression.
480        std::vector<unsigned> ordering;
481        ordering.reserve(num_vertices(Gk));
482        topological_sort(Gk, std::back_inserter(ordering));
483        std::vector<unsigned> component(num_vertices(Gk));
484
485        for (;;) {
486
487            // Mark which computation component these vertices are in based on their topological (occurence) order.
488            unsigned components = 0;
489            for (auto u : ordering) {
490                unsigned id = 0;
491                // If this is a sink in G, it is the root of a new component.
492                if (out_degree(u, Gk) == 0) {
493                    id = ++components;
494                } else {
495                    for (auto e : make_iterator_range(out_edges(u, Gk))) {
496                        id = std::max(id, component[target(e, Gk)]);
497                    }
498                }
499                assert (id && "Topological ordering failed!");
500                component[u] = id;
501            }
502
503            // Cut the graph wherever a computation crosses a component or whenever we need to cut the graph because
504            // the instructions corresponding to the pair of nodes differs.
505            graph_traits<Graph>::edge_iterator ei, ei_end;
506            for (std::tie(ei, ei_end) = edges(Gk); ei != ei_end; ) {
507                const Graph::edge_descriptor e = *ei++;
508                const Vertex u = source(e, Gk);
509                const Vertex v = target(e, Gk);
510                if (LLVM_UNLIKELY(isCutNecessary(u, v, Gk, component))) {
511                    E.push(std::make_pair(u, v));
512                    remove_edge(u, v, Gk);
513                }
514            }
515
516            // If no cuts are necessary, we're done.
517            if (E.empty()) {
518                break;
519            }
520
521            for (;;) {
522
523                Vertex u, v;
524                std::tie(u, v) = E.front(); E.pop();
525
526                // The vertex belonging to a component with a greater number must come "earlier"
527                // in the program. By replicating it, this ensures it's computed as an output of
528                // one component and used as an input of another.
529
530                if (component[u] < component[v]) {
531                    std::swap(u, v);
532                }
533
534                // Replicate u and fix the ordering and component vectors to reflect the change in G.
535                Vertex w = add_vertex(Gk[u], Gk);
536                ordering.insert(std::find(ordering.begin(), ordering.end(), u), w);
537                assert (component.size() == w);
538                component.push_back(component[v]);
539                add_edge(w, v, Gk);
540
541                // However, after we do so, we need to make sure the original source vertex will be a
542                // sink in G unless it is also an input variable (in which case we'd simply end up with
543                // extraneous isolated vertex. Otherwise, we need to make further cuts and replications.
544
545                if (in_degree(u, Gk) != 0) {
546                    for (auto e : make_iterator_range(out_edges(u, Gk))) {
547                        E.push(std::make_pair(source(e, Gk), target(e, Gk)));
548                    }
549                    clear_out_edges(u, Gk);
550                }
551
552                if (E.empty()) {
553                    break;
554                }
555            }
556        }
557
558        // Scan through the graph so that we process the outermost expressions first
559        for (const Vertex u : ordering) {
560            if (LLVM_UNLIKELY(out_degree(u, Gk) == 0)) {
561                if (LLVM_UNLIKELY(isa<Zeroes>(Gk[u]))) {
562                    continue;
563                }
564                const Vertex x = getVertex(Gk[u], G, M);
565                if (LLVM_LIKELY(in_degree(u, Gk) > 0)) {
566                    flat_set<PabloAST *> vars;
567                    flat_set<Vertex> visited;
568                    for (Vertex v = u;;) {
569                        if (in_degree(v, Gk) == 0) {
570                            vars.insert(Gk[v]);
571                        } else {
572                            for (auto e : make_iterator_range(in_edges(v, Gk))) {
573                                const Vertex w = source(e, Gk);
574                                if (LLVM_LIKELY(visited.insert(w).second)) {
575                                    push(w, Q);
576                                }
577                            }
578                        }
579                        if (Q.empty()) {
580                            break;
581                        }
582                        v = pop(Q);
583                    }
584                    for (PabloAST * var : vars) {
585                        add_edge(getVertex(var, G, M), x, G);
586                    }
587                }
588            }
589        }
590
591        // Determine the source variables of the next "layer" of the AST
592        flat_set<Statement *> nextSet;
593        for (auto u : ordering) {
594            if (LLVM_UNLIKELY(in_degree(u, Gk) == 0 && isa<Statement>(Gk[u]))) {
595                nextSet.insert(cast<Statement>(Gk[u]));
596            } else if (LLVM_UNLIKELY(out_degree(u, Gk) == 0 && isa<Statement>(Gk[u]))) { // an input may also be the output of a subgraph of G. We don't need to reevaluate it.
597                nextSet.erase(cast<Statement>(Gk[u]));
598            }
599        }
600
601        if (LLVM_UNLIKELY(nextSet.empty())) {
602            break;
603        }
604
605        terminals.assign(nextSet.begin(), nextSet.end());
606    }
607
608}
609
610BooleanReassociationPass::BooleanReassociationPass()
611{
612
613}
614
615
616}
Note: See TracBrowser for help on using the repository browser.