source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5152

Last change on this file since 5152 was 5152, checked in by nmedfort, 3 years ago

Work on a Z3 based reassociation pass.

File size: 59.1 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17using namespace boost;
18using namespace boost::container;
19
20
21static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
22
23static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
24                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
25                                           "Scope Block rather than at the point of first use."),
26                                  cl::cat(ReassociationOptions));
27
28namespace pablo {
29
30using TypeId = PabloAST::ClassTypeId;
31using Graph = BooleanReassociationPass::Graph;
32using Vertex = Graph::vertex_descriptor;
33using VertexData = BooleanReassociationPass::VertexData;
34using DistributionGraph = BooleanReassociationPass::DistributionGraph;
35using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
36using VertexSet = std::vector<Vertex>;
37using VertexSets = std::vector<VertexSet>;
38using Biclique = std::pair<VertexSet, VertexSet>;
39using BicliqueSet = std::vector<Biclique>;
40using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
41using DistributionSets = std::vector<DistributionSet>;
42
43/** ------------------------------------------------------------------------------------------------------------- *
44 * @brief helper functions
45 ** ------------------------------------------------------------------------------------------------------------- */
46template<typename Iterator>
47inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
48    assert (range.first != range.second);
49    return *range.first;
50}
51
52static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
53    return stmt->getParent() == block;
54}
55
56static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
57    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
58}
59
60inline TypeId & getType(VertexData & data) {
61    return std::get<0>(data);
62}
63
64inline TypeId getType(const VertexData & data) {
65    return std::get<0>(data);
66}
67
68inline Z3_ast & getDefinition(VertexData & data) {
69    return std::get<2>(data);
70}
71
72inline PabloAST * getValue(const VertexData & data) {
73    return std::get<1>(data);
74}
75
76inline PabloAST *& getValue(VertexData & data) {
77    return std::get<1>(data);
78}
79
80inline bool isAssociative(const VertexData & data) {
81    switch (getType(data)) {
82        case TypeId::And:
83        case TypeId::Or:
84        case TypeId::Xor:
85            return true;
86        default:
87            return false;
88    }
89}
90
91inline bool isDistributive(const VertexData & data) {
92    switch (getType(data)) {
93        case TypeId::And:
94        case TypeId::Or:
95            return true;
96        default:
97            return false;
98    }
99}
100
101void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
102    // Make sure each edge is unique
103    assert (u < num_vertices(G) && v < num_vertices(G));
104    assert (u != v);
105
106    // Just because we've supplied an expr doesn't mean it's useful. Check it.
107    if (expr) {
108        if (isAssociative(G[v])) {
109            expr = nullptr;
110        } else {
111            bool clear = true;
112            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
113                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
114                    if (dest->getOperand(i) == expr) {
115                        clear = false;
116                        break;
117                    }
118                }
119            }
120            if (LLVM_LIKELY(clear)) {
121                expr = nullptr;
122            }
123        }
124    }
125
126    for (auto e : make_iterator_range(out_edges(u, G))) {
127        if (LLVM_UNLIKELY(target(e, G) == v)) {
128            if (expr) {
129                if (G[e] == nullptr) {
130                    G[e] = expr;
131                } else if (G[e] != expr) {
132                    continue;
133                }
134            }
135            return;
136        }
137    }
138    G[boost::add_edge(u, v, G).first] = expr;
139}
140
141/** ------------------------------------------------------------------------------------------------------------- *
142 * @brief intersects
143 ** ------------------------------------------------------------------------------------------------------------- */
144template <class Type>
145inline bool intersects(const Type & A, const Type & B) {
146    auto first1 = A.begin(), last1 = A.end();
147    auto first2 = B.begin(), last2 = B.end();
148    while (first1 != last1 && first2 != last2) {
149        if (*first1 < *first2) {
150            ++first1;
151        } else if (*first2 < *first1) {
152            ++first2;
153        } else {
154            return true;
155        }
156    }
157    return false;
158}
159
160
161/** ------------------------------------------------------------------------------------------------------------- *
162 * @brief printGraph
163 ** ------------------------------------------------------------------------------------------------------------- */
164static void printGraph(const Graph & G, const std::string name) {
165    raw_os_ostream out(std::cerr);
166
167    std::vector<unsigned> c(num_vertices(G));
168    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
169
170    out << "digraph " << name << " {\n";
171    for (auto u : make_iterator_range(vertices(G))) {
172        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
173            continue;
174        }
175        out << "v" << u << " [label=\"" << u << ": ";
176        TypeId typeId;
177        PabloAST * expr;
178        Z3_ast node;
179        std::tie(typeId, expr, node) = G[u];
180        bool temporary = false;
181        bool error = false;
182        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
183            temporary = true;
184            switch (typeId) {
185                case TypeId::And:
186                    out << "And";
187                    break;
188                case TypeId::Or:
189                    out << "Or";
190                    break;
191                case TypeId::Xor:
192                    out << "Xor";
193                    break;
194                case TypeId::Not:
195                    out << "Not";
196                    break;
197                default:
198                    out << "???";
199                    error = true;
200                    break;
201            }
202            if (expr) {
203                out << " ("; PabloPrinter::print(expr, out); out << ")";
204            }
205        } else {
206            PabloPrinter::print(expr, out);
207        }
208        if (node == nullptr) {
209            out << " (*)";
210            error = true;
211        }
212        out << "\"";
213        if (typeId == TypeId::Var) {
214            out << " style=dashed";
215        }
216        if (error) {
217            out << " color=red";
218        } else if (temporary) {
219            out << " color=blue";
220        }
221        out << "];\n";
222    }
223    for (auto e : make_iterator_range(edges(G))) {
224        const auto s = source(e, G);
225        const auto t = target(e, G);
226        out << "v" << s << " -> v" << t;
227        bool cyclic = (c[s] == c[t]);
228        if (G[e] || cyclic) {
229            out << " [";
230             if (G[e]) {
231                out << "label=\"";
232                PabloPrinter::print(G[e], out);
233                out << "\" ";
234             }
235             if (cyclic) {
236                out << "color=red ";
237             }
238             out << "]";
239        }
240        out << ";\n";
241    }
242
243    if (num_vertices(G) > 0) {
244
245        out << "{ rank=same;";
246        for (auto u : make_iterator_range(vertices(G))) {
247            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
248                out << " v" << u;
249            }
250        }
251        out << "}\n";
252
253        out << "{ rank=same;";
254        for (auto u : make_iterator_range(vertices(G))) {
255            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
256                out << " v" << u;
257            }
258        }
259        out << "}\n";
260
261    }
262
263    out << "}\n\n";
264    out.flush();
265}
266
267/** ------------------------------------------------------------------------------------------------------------- *
268 * @brief optimize
269 ** ------------------------------------------------------------------------------------------------------------- */
270bool BooleanReassociationPass::optimize(PabloFunction & function) {
271
272    Z3_config cfg = Z3_mk_config();
273    Z3_context ctx = Z3_mk_context(cfg);
274    Z3_del_config(cfg);
275    Z3_solver solver = Z3_mk_solver(ctx);
276    Z3_solver_inc_ref(ctx, solver);
277
278    BooleanReassociationPass brp(ctx, solver, function);
279    brp.processScopes(function);
280
281    Z3_solver_dec_ref(ctx, solver);
282    Z3_del_context(ctx);
283
284    Simplifier::optimize(function);
285
286    return true;
287}
288
289/** ------------------------------------------------------------------------------------------------------------- *
290 * @brief processScopes
291 ** ------------------------------------------------------------------------------------------------------------- */
292inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
293    PabloBlock * const entry = function.getEntryBlock();
294    CharacterizationMap map;
295    // Map the constants and input variables
296    map.add(entry->createZeroes(), Z3_mk_false(mContext));
297    map.add(entry->createOnes(), Z3_mk_true(mContext));
298    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
299        map.add(mFunction.getParameter(i), makeVar());
300    }
301    mInFile = makeVar();
302    processScopes(entry, map);
303    return mModified;
304}
305
306/** ------------------------------------------------------------------------------------------------------------- *
307 * @brief processScopes
308 ** ------------------------------------------------------------------------------------------------------------- */
309void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & map) {
310    Z3_solver_push(mContext, mSolver);
311    for (Statement * stmt = block->front(); stmt; ) {
312        if (LLVM_UNLIKELY(isa<If>(stmt))) {
313            if (LLVM_UNLIKELY(isa<Zeroes>(cast<If>(stmt)->getCondition()))) {
314                stmt = stmt->eraseFromParent(true);
315            } else {
316                CharacterizationMap nested(map);
317                processScopes(cast<If>(stmt)->getBody(), nested);
318                for (Assign * def : cast<If>(stmt)->getDefined()) {
319                    map.add(def, makeVar());
320                }
321                stmt = stmt->getNextNode();
322            }
323        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
324            if (LLVM_UNLIKELY(isa<Zeroes>(cast<While>(stmt)->getCondition()))) {
325                stmt = stmt->eraseFromParent(true);
326            } else {
327                CharacterizationMap nested(map);
328                processScopes(cast<While>(stmt)->getBody(), nested);
329                for (Next * var : cast<While>(stmt)->getVariants()) {
330                    map.add(var, makeVar());
331                }
332                stmt = stmt->getNextNode();
333            }
334        } else { // characterize this statement then check whether it is equivalent to any existing one.
335            stmt = characterize(stmt, map);
336        }
337    }
338    distributeScope(block, map);
339    Z3_solver_pop(mContext, mSolver, 1);
340}
341
342/** ------------------------------------------------------------------------------------------------------------- *
343 * @brief characterize
344 ** ------------------------------------------------------------------------------------------------------------- */
345inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & map) {
346    Z3_ast node = nullptr;
347    const size_t n = stmt->getNumOperands(); assert (n > 0);
348    if (isa<Variadic>(stmt)) {
349        Z3_ast operands[n];
350        for (size_t i = 0; i < n; ++i) {
351            operands[i] = map.get(stmt->getOperand(i)); assert (operands[i]);
352        }
353        if (isa<And>(stmt)) {
354            node = Z3_mk_and(mContext, n, operands);
355        } else if (isa<Or>(stmt)) {
356            node = Z3_mk_or(mContext, n, operands);
357        } else if (isa<Xor>(stmt)) {
358            node = Z3_mk_xor(mContext, operands[0], operands[1]);
359            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
360                node = Z3_mk_xor(mContext, node, operands[i]);
361            }
362        }
363    } else if (isa<Not>(stmt)) {
364        Z3_ast op = map.get(stmt->getOperand(0)); assert (op);
365        node = Z3_mk_not(mContext, op);
366    } else if (isa<Sel>(stmt)) {
367        Z3_ast operands[3];
368        for (size_t i = 0; i < 3; ++i) {
369            operands[i] = map.get(stmt->getOperand(i)); assert (operands[i]);
370        }
371        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
372    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
373        assert (stmt->getNumOperands() == 1);
374        Z3_ast check[2];
375        check[0] = map.get(stmt->getOperand(0)); assert (check[0]);
376        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
377        node = Z3_mk_and(mContext, 2, check);
378    } else {
379        if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
380            Z3_ast op = map.get(stmt->getOperand(0)); assert (op);
381            map.add(stmt, op, true);
382        } else {
383            map.add(stmt, makeVar());
384        }
385        return stmt->getNextNode();
386    }
387    node = Z3_simplify(mContext, node); assert (node);
388    PabloAST * const replacement = map.findKey(node);
389    if (LLVM_LIKELY(replacement == nullptr)) {
390        map.add(stmt, node);
391        return stmt->getNextNode();
392    } else {
393        return stmt->replaceWith(replacement);
394    }
395
396}
397
398/** ------------------------------------------------------------------------------------------------------------- *
399 * @brief processScope
400 ** ------------------------------------------------------------------------------------------------------------- */
401inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & map) {
402    Graph G;
403    try {
404        transformAST(block, map, G);
405    } catch (std::runtime_error err) {
406        printGraph(G, "G");
407        throw err;
408    }
409}
410
411/** ------------------------------------------------------------------------------------------------------------- *
412 * @brief summarizeAST
413 *
414 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
415 * are "flattened" (i.e., allowed to have any number of inputs.)
416 ** ------------------------------------------------------------------------------------------------------------- */
417void BooleanReassociationPass::transformAST(PabloBlock * const block, CharacterizationMap & C, Graph & G) {
418    StatementMap S;
419    // Compute the base def-use graph ...
420    for (Statement * stmt : *block) {
421
422        const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, S, G, C.get(stmt));
423
424        for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
425            PabloAST * const op = stmt->getOperand(i);
426            if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
427                add_edge(op, makeVertex(TypeId::Var, op, C, S, G), u, G);
428            }
429        }
430
431        if (LLVM_UNLIKELY(isa<If>(stmt))) {
432            for (Assign * def : cast<const If>(stmt)->getDefined()) {
433                const Vertex v = makeVertex(TypeId::Var, def, C, S, G);
434                add_edge(def, u, v, G);
435                resolveNestedUsages(block, def, v, C, S, G, stmt);
436            }
437        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
438            // To keep G a DAG, we need to do a bit of surgery on loop variants because
439            // the next variables it produces can be used within the condition. Instead,
440            // we make the loop dependent on the original value of each Next node and
441            // the Next node dependent on the loop.
442            for (Next * var : cast<const While>(stmt)->getVariants()) {
443                const Vertex v = makeVertex(TypeId::Var, var, C, S, G);
444                assert (in_degree(v, G) == 1);
445                auto e = first(in_edges(v, G));
446                add_edge(G[e], source(e, G), u, G);
447                remove_edge(v, u, G);
448                add_edge(var, u, v, G);
449                resolveNestedUsages(block, var, v, C, S, G, stmt);
450            }
451        } else {           
452            resolveNestedUsages(block, stmt, u, C, S, G, stmt);
453        }
454    }
455
456    if (redistributeGraph(C, S, G)) {
457        factorGraph(G);
458        rewriteAST(block, G);
459        mModified = true;
460    }
461
462}
463
464/** ------------------------------------------------------------------------------------------------------------- *
465 * @brief resolveNestedUsages
466 ** ------------------------------------------------------------------------------------------------------------- */
467void BooleanReassociationPass::resolveNestedUsages(PabloBlock * const block, PabloAST * const expr, const Vertex u,
468                                                   CharacterizationMap & C, StatementMap & S, Graph & G,
469                                                   const Statement * const ignoreIfThis) const {
470    assert ("Cannot resolve nested usages of a null expression!" && expr);
471    for (PabloAST * user : expr->users()) { assert (user);
472        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
473            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
474            if (LLVM_UNLIKELY(parent != block)) {
475                for (;;) {
476                    if (parent->getParent() == block) {
477                        Statement * const branch = parent->getBranch();
478                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
479                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
480                            const Vertex v = makeVertex(TypeId::Var, user, C, S, G);
481                            add_edge(expr, u, v, G);
482                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
483                            add_edge(user, v, w, G);
484                        }
485                        break;
486                    }
487                    parent = parent->getParent();
488                    if (LLVM_UNLIKELY(parent == nullptr)) {
489                        assert (isa<Assign>(expr) || isa<Next>(expr));
490                        break;
491                    }
492                }
493            }
494        }
495    }
496}
497
498
499
500/** ------------------------------------------------------------------------------------------------------------- *
501 * @brief enumerateBicliques
502 *
503 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
504 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
505 * bipartition A and their adjacencies to be in B.
506  ** ------------------------------------------------------------------------------------------------------------- */
507template <class Graph>
508static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
509    using IntersectionSets = std::set<VertexSet>;
510
511    IntersectionSets B1;
512    for (auto u : A) {
513        if (in_degree(u, G) > 0) {
514            VertexSet adjacencies;
515            adjacencies.reserve(in_degree(u, G));
516            for (auto e : make_iterator_range(in_edges(u, G))) {
517                adjacencies.push_back(source(e, G));
518            }
519            std::sort(adjacencies.begin(), adjacencies.end());
520            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
521            B1.insert(std::move(adjacencies));
522        }
523    }
524
525    IntersectionSets B(B1);
526
527    IntersectionSets Bi;
528
529    VertexSet clique;
530    for (auto i = B1.begin(); i != B1.end(); ++i) {
531        for (auto j = i; ++j != B1.end(); ) {
532            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
533            if (clique.size() > 0) {
534                if (B.count(clique) == 0) {
535                    Bi.insert(clique);
536                }
537                clique.clear();
538            }
539        }
540    }
541
542    for (;;) {
543        if (Bi.empty()) {
544            break;
545        }
546        B.insert(Bi.begin(), Bi.end());
547        IntersectionSets Bk;
548        for (auto i = B1.begin(); i != B1.end(); ++i) {
549            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
550                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
551                if (clique.size() > 0) {
552                    if (B.count(clique) == 0) {
553                        Bk.insert(clique);
554                    }
555                    clique.clear();
556                }
557            }
558        }
559        Bi.swap(Bk);
560    }
561
562    BicliqueSet cliques;
563    cliques.reserve(B.size());
564    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
565        VertexSet Ai(A);
566        for (const Vertex u : *Bi) {
567            VertexSet Aj;
568            Aj.reserve(out_degree(u, G));
569            for (auto e : make_iterator_range(out_edges(u, G))) {
570                Aj.push_back(target(e, G));
571            }
572            std::sort(Aj.begin(), Aj.end());
573            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
574            VertexSet Ak;
575            Ak.reserve(std::min(Ai.size(), Aj.size()));
576            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
577            Ai.swap(Ak);
578        }
579        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
580        cliques.emplace_back(std::move(Ai), std::move(*Bi));
581    }
582    return std::move(cliques);
583}
584
585/** ------------------------------------------------------------------------------------------------------------- *
586 * @brief independentCliqueSets
587 ** ------------------------------------------------------------------------------------------------------------- */
588template <unsigned side>
589inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
590    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
591
592    const auto l = cliques.size();
593    IndependentSetGraph I(l);
594
595    // Initialize our weights
596    for (unsigned i = 0; i != l; ++i) {
597        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
598    }
599
600    // Determine our constraints
601    for (unsigned i = 0; i != l; ++i) {
602        for (unsigned j = i + 1; j != l; ++j) {
603            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
604                add_edge(i, j, I);
605            }
606        }
607    }
608
609    // Use the greedy algorithm to choose our independent set
610    VertexSet selected;
611    for (;;) {
612        unsigned w = 0;
613        Vertex u = 0;
614        for (unsigned i = 0; i != l; ++i) {
615            if (I[i] > w) {
616                w = I[i];
617                u = i;
618            }
619        }
620        if (w < minimum) break;
621        selected.push_back(u);
622        I[u] = 0;
623        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
624            I[v] = 0;
625        }
626    }
627
628    // Sort the selected list and then remove the unselected cliques
629    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
630    auto end = cliques.end();
631    for (const unsigned offset : selected) {
632        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
633    }
634    cliques.erase(cliques.begin(), end);
635
636    return std::move(cliques);
637}
638
639/** ------------------------------------------------------------------------------------------------------------- *
640 * @brief removeUnhelpfulBicliques
641 *
642 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
643 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
644 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
645 ** ------------------------------------------------------------------------------------------------------------- */
646static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
647    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
648        const auto cardinalityA = std::get<0>(*ci).size();
649        VertexSet & B = std::get<1>(*ci);
650        for (auto bi = B.begin(); bi != B.end(); ) {
651            if (out_degree(H[*bi], G) == cardinalityA) {
652                ++bi;
653            } else {
654                bi = B.erase(bi);
655            }
656        }
657        if (B.size() > 1) {
658            ++ci;
659        } else {
660            ci = cliques.erase(ci);
661        }
662    }
663    return std::move(cliques);
664}
665
666/** ------------------------------------------------------------------------------------------------------------- *
667 * @brief safeDistributionSets
668 ** ------------------------------------------------------------------------------------------------------------- */
669static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
670
671    VertexSet sinks;
672    for (const Vertex u : make_iterator_range(vertices(H))) {
673        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
674            sinks.push_back(u);
675        }
676    }
677    std::sort(sinks.begin(), sinks.end());
678
679    DistributionSets T;
680    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
681    for (Biclique & lower : lowerSet) {
682        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
683        for (Biclique & upper : upperSet) {
684            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
685        }
686    }
687    return std::move(T);
688}
689
690/** ------------------------------------------------------------------------------------------------------------- *
691 * @brief getVertex
692 ** ------------------------------------------------------------------------------------------------------------- */
693template<typename ValueType, typename GraphType, typename MapType>
694static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
695    const auto f = M.find(value);
696    if (f != M.end()) {
697        return f->second;
698    }
699    const auto u = add_vertex(value, G);
700    M.insert(std::make_pair(value, u));
701    return u;
702}
703
704/** ------------------------------------------------------------------------------------------------------------- *
705 * @brief generateDistributionGraph
706 ** ------------------------------------------------------------------------------------------------------------- */
707void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
708    DistributionMap M;
709    for (const Vertex u : make_iterator_range(vertices(G))) {
710        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
711            continue;
712        } else if (isDistributive(G[u])) {
713            const TypeId outerTypeId = getType(G[u]);
714            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
715            flat_set<Vertex> distributable;
716            for (auto e : make_iterator_range(in_edges(u, G))) {
717                const Vertex v = source(e, G);
718                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
719                    bool safe = true;
720                    for (const auto e : make_iterator_range(out_edges(v, G))) {
721                        if (getType(G[target(e, G)]) != outerTypeId) {
722                            safe = false;
723                            break;
724                        }
725                    }
726                    if (safe) {
727                        distributable.insert(v);
728                    }
729                }
730            }
731            if (LLVM_LIKELY(distributable.size() > 1)) {
732                flat_set<Vertex> observed;
733                for (const Vertex v : distributable) {
734                    for (auto e : make_iterator_range(in_edges(v, G))) {
735                        const auto v = source(e, G);
736                        observed.insert(v);
737                    }
738                }
739                for (const Vertex w : observed) {
740                    for (auto e : make_iterator_range(out_edges(w, G))) {
741                        const Vertex v = target(e, G);
742                        if (distributable.count(v)) {
743                            const Vertex y = getVertex(v, H, M);
744                            boost::add_edge(y, getVertex(u, H, M), H);
745                            boost::add_edge(getVertex(w, H, M), y, H);
746                        }
747                    }
748                }
749            }
750        }
751    }
752}
753
754/** ------------------------------------------------------------------------------------------------------------- *
755 * @brief redistributeAST
756 *
757 * Apply the distribution law to reduce computations whenever possible.
758 ** ------------------------------------------------------------------------------------------------------------- */
759bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, StatementMap & M, Graph & G) const {
760
761    bool modified = false;
762
763    DistributionGraph H;
764
765    contractGraph(M, G);
766
767    for (;;) {
768
769        for (;;) {
770
771            generateDistributionGraph(G, H);
772
773            // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
774            if (num_vertices(H) == 0) {
775                break;
776            }
777
778            const DistributionSets distributionSets = safeDistributionSets(G, H);
779
780            if (LLVM_UNLIKELY(distributionSets.empty())) {
781                break;
782            }
783
784            modified = true;
785
786            for (const DistributionSet & set : distributionSets) {
787
788                // Each distribution tuple consists of the sources, intermediary, and sink nodes.
789                const VertexSet & sources = std::get<0>(set);
790                const VertexSet & intermediary = std::get<1>(set);
791                const VertexSet & sinks = std::get<2>(set);
792
793                const TypeId outerTypeId = getType(G[H[sinks.front()]]);
794                assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
795                const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
796
797                const Vertex x = makeVertex(outerTypeId, nullptr, G);
798                const Vertex y = makeVertex(innerTypeId, nullptr, G);
799
800                // Update G to reflect the distributed operations (including removing the subgraph of
801                // the to-be distributed edges.)
802
803                add_edge(nullptr, x, y, G);
804
805                for (const Vertex i : sources) {
806                    const auto u = H[i];
807                    for (const Vertex j : intermediary) {
808                        const auto v = H[j];
809                        const auto e = edge(u, v, G); assert (e.second);
810                        remove_edge(e.first, G);
811                    }
812                    add_edge(nullptr, u, y, G);
813                }
814
815                for (const Vertex i : intermediary) {
816                    const auto u = H[i];
817                    for (const Vertex j : sinks) {
818                        const auto v = H[j];
819                        const auto e = edge(u, v, G); assert (e.second);
820                        add_edge(G[e.first], y, v, G);
821                        remove_edge(e.first, G);
822                    }
823                    add_edge(nullptr, u, x, G);
824                    getDefinition(G[u]) = nullptr;
825                }
826
827            }
828
829            H.clear();
830
831            contractGraph(M, G);
832        }
833
834        // Although exceptionally unlikely, it's possible that if we can reduce the graph, we could
835        // further simplify it. Restart the process if and only if we succeed.
836        if (LLVM_UNLIKELY(reduceGraph(C, M, G))) {
837            if (LLVM_UNLIKELY(contractGraph(M, G))) {
838                H.clear();
839                continue;
840            }
841        }
842
843        break;
844    }
845
846    return modified;
847}
848
849/** ------------------------------------------------------------------------------------------------------------- *
850 * @brief isNonEscaping
851 ** ------------------------------------------------------------------------------------------------------------- */
852inline bool isNonEscaping(const VertexData & data) {
853    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
854    switch (getType(data)) {
855        case TypeId::Assign:
856        case TypeId::Next:
857        case TypeId::If:
858        case TypeId::While:
859        case TypeId::Count:
860            return false;
861        default:
862            return true;
863    }
864}
865
866/** ------------------------------------------------------------------------------------------------------------- *
867 * @brief contractGraph
868 ** ------------------------------------------------------------------------------------------------------------- */
869bool BooleanReassociationPass::contractGraph(StatementMap & M, Graph & G) const {
870
871    bool contracted = false;
872
873    circular_buffer<Vertex> ordering(num_vertices(G));
874
875    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
876
877    // first contract the graph
878    for (const Vertex u : ordering) {
879        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
880            continue;
881        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
882            if (isAssociative(G[u])) {
883                if (LLVM_UNLIKELY(in_degree(u, G) == 1)) {
884                    // We have a redundant node here that'll simply end up being a duplicate
885                    // of the input value. Remove it and add any of its outgoing edges to its
886                    // input node.
887                    const auto ei = first(in_edges(u, G));
888                    const Vertex v = source(ei, G);
889                    for (auto ej : make_iterator_range(out_edges(u, G))) {
890                        const Vertex w = target(ej, G);
891                        add_edge(G[ei], v, w, G);
892                    }
893                    removeVertex(u, M, G);
894                    contracted = true;
895                } else if (LLVM_UNLIKELY(out_degree(u, G) == 1)) {
896                    // Otherwise if we have a single user, we have a similar case as above but
897                    // we can only merge this vertex into the outgoing instruction if they are
898                    // of the same type.
899                    const auto ei = first(out_edges(u, G));
900                    const Vertex v = target(ei, G);
901                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
902                        for (auto ej : make_iterator_range(in_edges(u, G))) {
903                            add_edge(G[ei], source(ej, G), v, G);
904                        }
905                        removeVertex(u, M, G);
906                        contracted = true;
907                    }                   
908                }
909            }
910        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
911            removeVertex(u, M, G);
912            contracted = true;
913        }
914    }
915    return contracted;
916}
917
918/** ------------------------------------------------------------------------------------------------------------- *
919 * @brief isReducible
920 ** ------------------------------------------------------------------------------------------------------------- */
921inline bool isReducible(const VertexData & data) {
922    switch (getType(data)) {
923        case TypeId::Var:
924        case TypeId::If:
925        case TypeId::While:
926            return false;
927        default:
928            return true;
929    }
930}
931
932/** ------------------------------------------------------------------------------------------------------------- *
933 * @brief reduceGraph
934 ** ------------------------------------------------------------------------------------------------------------- */
935bool BooleanReassociationPass::reduceGraph(CharacterizationMap & C, StatementMap & S, Graph & G) const {
936
937    bool reduced = false;
938
939    VertexMap M;
940    circular_buffer<Vertex> ordering(num_vertices(G));
941
942    topological_sort(G, std::front_inserter(ordering)); // topological ordering
943
944    // first contract the graph
945    for (const Vertex u : ordering) {
946        if (isReducible(G[u])) {
947            Z3_ast & node = getDefinition(G[u]);
948            if (isAssociative(G[u])) {
949                const TypeId typeId = getType(G[u]);
950                if (node == nullptr) {
951                    const auto n = in_degree(u, G); assert (n > 1);
952                    Z3_ast operands[n];
953                    unsigned i = 0;
954                    for (auto e : make_iterator_range(in_edges(u, G))) {
955                        const Vertex v = source(e, G);
956                        assert (getDefinition(G[v]));
957                        operands[i++] = getDefinition(G[v]);
958                    }
959                    switch (typeId) {
960                        case TypeId::And:
961                            node = Z3_mk_and(mContext, n, operands);
962                            break;
963                        case TypeId::Or:
964                            node = Z3_mk_or(mContext, n, operands);
965                            break;
966                        case TypeId::Xor:
967                            node = Z3_mk_xor(mContext, operands[0], operands[1]);
968                            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
969                                node = Z3_mk_xor(mContext, node, operands[i]);
970                            }
971                            break;
972                        default: llvm_unreachable("unexpected type id");
973                    }
974
975                    assert (node);
976                    node = Z3_simplify(mContext, node);
977                }
978                graph_traits<Graph>::in_edge_iterator begin, end;
979restart:        if (in_degree(u, G) > 1) {
980                    std::tie(begin, end) = in_edges(u, G);
981                    for (auto i = begin; ++i != end; ) {
982                        const auto v = source(*i, G);
983                        for (auto j = begin; j != i; ++j) {
984                            const auto w = source(*j, G);
985                            Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
986                            assert (operands[0]);
987                            assert (operands[1]);
988                            Z3_ast test = nullptr;
989                            switch (typeId) {
990                                case TypeId::And:
991                                    test = Z3_mk_and(mContext, 2, operands); break;
992                                case TypeId::Or:
993                                    test = Z3_mk_or(mContext, 2, operands); break;
994                                case TypeId::Xor:
995                                    test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
996                                default:
997                                    llvm_unreachable("impossible type id");
998                            }
999                            assert (test);
1000                            test = Z3_simplify(mContext, test);
1001                            PabloAST * const factor = C.findKey(test);
1002                            if (LLVM_UNLIKELY(factor != nullptr)) {
1003                                const Vertex a = makeVertex(TypeId::Var, factor, S, G, test);
1004                                // note: unless both edges carry an Pablo AST replacement value, they will converge into a single edge.
1005                                PabloAST * const r1 = G[*i];
1006                                PabloAST * const r2 = G[*j];
1007
1008                                remove_edge(*i, G);
1009                                remove_edge(*j, G);
1010
1011                                if (LLVM_UNLIKELY(r1 && r2)) {
1012                                    add_edge(r1, a, u, G);
1013                                    add_edge(r2, a, u, G);
1014                                } else {
1015                                    add_edge(r1 ? r1 : r2, a, u, G);
1016                                }
1017
1018//                                errs() << " -- subsituting (" << a << ',' << u << ")=" << Z3_ast_to_string(mContext, test)
1019//                                       << " for (" << v << ',' << u << ")=" << Z3_ast_to_string(mContext, getDefinition(G[v]));
1020
1021//                                switch (typeId) {
1022//                                    case TypeId::And:
1023//                                        errs() << " ∧ "; break;
1024//                                    case TypeId::Or:
1025//                                        errs() << " √ ";  break;
1026//                                    case TypeId::Xor:
1027//                                        errs() << " ⊕ ";  break;
1028//                                    default:
1029//                                        llvm_unreachable("impossible type id");
1030//                                }
1031
1032//                                errs() << "(" << w << ',' << u << ")=" << Z3_ast_to_string(mContext, getDefinition(G[w]))
1033//                                       << "\n";
1034
1035                                reduced = true;
1036                                goto restart;
1037                            }
1038                        }
1039                    }
1040                }
1041            }
1042
1043            if (LLVM_UNLIKELY(node == nullptr)) {
1044                throw std::runtime_error("No Z3 characterization for vertex " + std::to_string(u));
1045            }
1046
1047            auto f = M.find(node);
1048            if (LLVM_LIKELY(f == M.end())) {
1049                M.emplace(node, u);
1050            } else if (isAssociative(G[u])) {
1051                const Vertex v = f->second;
1052
1053//                errs() << " -- subsituting " << u << ":=\n" << Z3_ast_to_string(mContext, getDefinition(G[u]))
1054//                       << "\n for " << v << ":=\n" << Z3_ast_to_string(mContext, getDefinition(G[v])) << "\n";
1055
1056                for (auto e : make_iterator_range(out_edges(u, G))) {
1057                    add_edge(G[e], v, target(e, G), G);
1058                }
1059                removeVertex(u, S, G);
1060                reduced = true;
1061            }
1062        }
1063    }
1064    return reduced;
1065}
1066
1067/** ------------------------------------------------------------------------------------------------------------- *
1068 * @brief factorGraph
1069 ** ------------------------------------------------------------------------------------------------------------- */
1070bool BooleanReassociationPass::factorGraph(const TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1071
1072    if (LLVM_UNLIKELY(factors.empty())) {
1073        return false;
1074    }
1075
1076    std::vector<Vertex> I, J, K;
1077
1078    bool modified = false;
1079
1080    for (auto i = factors.begin(); ++i != factors.end(); ) {
1081        assert (getType(G[*i]) == typeId);
1082        for (auto ei : make_iterator_range(in_edges(*i, G))) {
1083            I.push_back(source(ei, G));
1084        }
1085        std::sort(I.begin(), I.end());
1086        for (auto j = factors.begin(); j != i; ++j) {
1087            for (auto ej : make_iterator_range(in_edges(*j, G))) {
1088                J.push_back(source(ej, G));
1089            }
1090            std::sort(J.begin(), J.end());
1091            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1092            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1093            assert (std::is_sorted(K.begin(), K.end()));
1094            // if the intersection contains at least two elements
1095            const auto n = K.size();
1096            if (n > 1) {
1097                Vertex a = *i;
1098                Vertex b = *j;
1099                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1100                    if (in_degree(a, G) != n) {
1101                        assert (in_degree(b, G) == n);
1102                        std::swap(a, b);
1103                    }
1104                    assert (in_degree(a, G) == n);
1105                    if (in_degree(b, G) == n) {
1106                        for (auto e : make_iterator_range(out_edges(b, G))) {
1107                            add_edge(G[e], a, target(e, G), G);
1108                        }
1109                        removeVertex(b, G);
1110                    } else {
1111                        for (auto u : K) {
1112                            remove_edge(u, b, G);
1113                        }
1114                        add_edge(nullptr, a, b, G);
1115                    }
1116                } else {
1117                    Vertex v = makeVertex(typeId, nullptr, G);
1118                    for (auto u : K) {
1119                        remove_edge(u, a, G);
1120                        remove_edge(u, b, G);
1121                        add_edge(nullptr, u, v, G);
1122                    }
1123                    add_edge(nullptr, v, a, G);
1124                    add_edge(nullptr, v, b, G);
1125                    factors.push_back(v);
1126                }
1127                modified = true;
1128            }
1129            K.clear();
1130            J.clear();
1131        }
1132        I.clear();
1133    }
1134    return modified;
1135}
1136
1137/** ------------------------------------------------------------------------------------------------------------- *
1138 * @brief factorGraph
1139 ** ------------------------------------------------------------------------------------------------------------- */
1140bool BooleanReassociationPass::factorGraph(Graph & G) const {
1141    // factor the associative vertices.
1142    std::vector<Vertex> factors;
1143    bool factored = false;
1144    for (unsigned i = 0; i < 3; ++i) {
1145        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1146        for (auto j : make_iterator_range(vertices(G))) {
1147            if (getType(G[j]) == typeId[i]) {
1148                factors.push_back(j);
1149            }
1150        }
1151        if (factorGraph(typeId[i], G, factors)) {
1152            factored = true;
1153        }
1154        factors.clear();
1155    }
1156    return factored;
1157}
1158
1159
1160inline bool isMutable(const Vertex u, const Graph & G) {
1161    return getType(G[u]) != TypeId::Var;
1162}
1163
1164/** ------------------------------------------------------------------------------------------------------------- *
1165 * @brief rewriteAST
1166 ** ------------------------------------------------------------------------------------------------------------- */
1167void BooleanReassociationPass::rewriteAST(PabloBlock * const block, Graph & G) {
1168
1169    using line_t = long long int;
1170
1171    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1172
1173    Z3_config cfg = Z3_mk_config();
1174    Z3_set_param_value(cfg, "MODEL", "true");
1175    Z3_context ctx = Z3_mk_context(cfg);
1176    Z3_del_config(cfg);
1177    Z3_solver solver = Z3_mk_solver(ctx);
1178    Z3_solver_inc_ref(ctx, solver);
1179
1180    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1181
1182    flat_map<PabloAST *, Z3_ast> M;
1183
1184    // Generate the variables
1185    const auto ty = Z3_mk_int_sort(ctx);
1186    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1187
1188    for (const Vertex u : make_iterator_range(vertices(G))) {
1189        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1190        Z3_ast constraint = nullptr;
1191        if (in_degree(u, G) == 0) {
1192            constraint = Z3_mk_eq(ctx, var, ZERO);
1193        } else {
1194            constraint = Z3_mk_gt(ctx, var, ZERO);
1195        }
1196        Z3_solver_assert(ctx, solver, constraint);
1197
1198        PabloAST * const expr = getValue(G[u]);
1199        if (expr) {
1200            const bool added = M.emplace(expr, var).second;
1201            assert ("G contains duplicate vertices for the same statement!" && added);
1202        }
1203        mapping[u] = var;
1204    }
1205
1206    // Add in the dependency constraints
1207    for (const Vertex u : make_iterator_range(vertices(G))) {
1208        Z3_ast const t = mapping[u];
1209        for (auto e : make_iterator_range(in_edges(u, G))) {
1210            Z3_ast const s = mapping[source(e, G)];
1211            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1212        }
1213    }
1214
1215    // Compute the soft ordering constraints
1216    std::vector<Z3_ast> ordering(0);
1217    ordering.reserve(2 * M.size() - 1);
1218
1219    Z3_ast prior = nullptr;
1220    unsigned gap = 1;
1221    for (Statement * stmt : *block) {
1222        auto f = M.find(stmt);
1223        if (f != M.end()) {
1224            Z3_ast const node = f->second;
1225            if (prior) {
1226                ordering.push_back(Z3_mk_lt(ctx, prior, node));
1227                Z3_ast ops[2] = { node, prior };
1228                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1229            } else {
1230                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1231
1232            }
1233            prior = node;
1234            gap = 0;
1235        }
1236        ++gap;
1237    }
1238
1239    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1240        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1241    }
1242
1243    Z3_model model = Z3_solver_get_model(ctx, solver);
1244    Z3_model_inc_ref(ctx, model);
1245
1246    std::vector<Vertex> S(0);
1247    S.reserve(num_vertices(G));
1248
1249    std::vector<line_t> L(num_vertices(G));
1250
1251
1252
1253    for (const Vertex u : make_iterator_range(vertices(G))) {
1254        line_t line = LoadEarly ? 0 : MAX_INT;
1255        if (isMutable(u, G)) {
1256            Z3_ast value;
1257            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1258                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1259            }
1260            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1261                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1262            }
1263            S.push_back(u);
1264        }
1265        L[u] = line;
1266    }
1267
1268    Z3_model_dec_ref(ctx, model);
1269
1270    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1271
1272    block->setInsertPoint(nullptr);
1273
1274    std::vector<Vertex> T;
1275
1276    line_t count = 1;
1277
1278    for (auto u : S) {
1279        PabloAST *& stmt = getValue(G[u]);
1280
1281        assert (isMutable(u, G));
1282        assert (L[u] > 0 && L[u] < MAX_INT);
1283
1284        if (isAssociative(G[u])) {
1285
1286            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1287                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1288            }
1289
1290            Statement * ip = block->getInsertPoint();
1291            ip = ip ? ip->getNextNode() : block->front();
1292
1293            const auto typeId = getType(G[u]);
1294
1295            T.reserve(in_degree(u, G));
1296            for (const auto e : make_iterator_range(in_edges(u, G))) {
1297                T.push_back(source(e, G));
1298            }
1299
1300            // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1301            std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1302            if (LoadEarly) {
1303                block->setInsertPoint(nullptr);
1304            }
1305
1306            circular_buffer<PabloAST *> Q(2); // in_degree(u, G));
1307            for (auto u : T) {
1308                PabloAST * expr = getValue(G[u]);
1309                if (LLVM_UNLIKELY(expr == nullptr)) {
1310                    throw std::runtime_error("Vertex " + std::to_string(u) + " does not have an expression!");
1311                }
1312                Q.push_back(expr);
1313                if (Q.size() > 1) {
1314
1315                    PabloAST * e1 = Q.front(); Q.pop_front();
1316                    PabloAST * e2 = Q.front(); Q.pop_front();
1317
1318                    if (in_degree(u, G) > 0) {
1319                        if (LLVM_UNLIKELY(!dominates(e1, e2))) {
1320                            std::string tmp;
1321                            raw_string_ostream out(tmp);
1322                            out << "e1: ";
1323                            PabloPrinter::print(e1, out);
1324                            out << " does not dominate e2: ";
1325                            PabloPrinter::print(e2, out);
1326                            throw std::runtime_error(out.str());
1327                        }
1328                        assert (dominates(e1, e2));
1329                        assert (dominates(e2, ip));
1330                        Statement * dom = cast<Statement>(expr);
1331                        for (;;) {
1332                            PabloBlock * const parent = dom->getParent();
1333                            if (parent == block) {
1334                                block->setInsertPoint(dom);
1335                                break;
1336                            }
1337                            dom = parent->getBranch(); assert(dom);
1338                        }
1339                        assert (dominates(dom, ip));
1340                    }
1341
1342                    switch (typeId) {
1343                        case TypeId::And:
1344                            expr = block->createAnd(e1, e2); break;
1345                        case TypeId::Or:
1346                            expr = block->createOr(e1, e2); break;
1347                        case TypeId::Xor:
1348                            expr = block->createXor(e1, e2); break;
1349                        default: break;
1350                    }
1351                    Q.push_front(expr);
1352                }
1353            }
1354
1355            assert (Q.size() == 1);
1356
1357            T.clear();
1358
1359            block->setInsertPoint(ip->getPrevNode());
1360
1361            PabloAST * const replacement = Q.front(); assert (replacement);
1362            for (auto e : make_iterator_range(out_edges(u, G))) {
1363                if (G[e]) {
1364                    if (PabloAST * user = getValue(G[target(e, G)])) {
1365                        cast<Statement>(user)->replaceUsesOfWith(G[e], replacement);
1366                    }
1367                }
1368            }
1369
1370            stmt = replacement;
1371        }
1372
1373        assert (stmt);
1374        assert (inCurrentBlock(stmt, block));
1375
1376        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1377            for (auto e : make_iterator_range(out_edges(u, G))) {
1378                const auto v = target(e, G);
1379                assert (L[v] == std::numeric_limits<line_t>::max());
1380                L[v] = count;
1381            }
1382        }
1383
1384        block->insert(cast<Statement>(stmt));
1385        L[u] = count++; // update the line count with the actual one.
1386    }
1387
1388    Z3_solver_dec_ref(ctx, solver);
1389    Z3_del_context(ctx);
1390
1391    Statement * const end = block->getInsertPoint(); assert (end);
1392    for (;;) {
1393        Statement * const next = end->getNextNode();
1394        if (next == nullptr) {
1395            break;
1396        }
1397        next->eraseFromParent(true);
1398    }
1399
1400    #ifndef NDEBUG
1401    PabloVerifier::verify(mFunction, "post-reassociation");
1402    #endif
1403
1404}
1405
1406/** ------------------------------------------------------------------------------------------------------------- *
1407 * @brief addSummaryVertex
1408 ** ------------------------------------------------------------------------------------------------------------- */
1409Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1410//    for (Vertex u : make_iterator_range(vertices(G))) {
1411//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1412//            std::get<0>(G[u]) = typeId;
1413//            std::get<1>(G[u]) = expr;
1414//            return u;
1415//        }
1416//    }
1417    return add_vertex(std::make_tuple(typeId, expr, node), G);
1418}
1419
1420/** ------------------------------------------------------------------------------------------------------------- *
1421 * @brief addSummaryVertex
1422 ** ------------------------------------------------------------------------------------------------------------- */
1423Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1424    assert (expr);
1425    const auto f = M.find(expr);
1426    if (f != M.end()) {
1427        assert (getValue(G[f->second]) == expr);
1428        return f->second;
1429    }
1430    const Vertex u = makeVertex(typeId, expr, G, node);
1431    M.emplace(expr, u);
1432    return u;
1433}
1434
1435/** ------------------------------------------------------------------------------------------------------------- *
1436 * @brief addSummaryVertex
1437 ** ------------------------------------------------------------------------------------------------------------- */
1438Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & M, Graph & G) {
1439    assert (expr);
1440    const auto f = M.find(expr);
1441    if (f != M.end()) {
1442        assert (getValue(G[f->second]) == expr);
1443        return f->second;
1444    }
1445    const Vertex u = makeVertex(typeId, expr, G, C.get(expr));
1446    M.emplace(expr, u);
1447    return u;
1448}
1449
1450/** ------------------------------------------------------------------------------------------------------------- *
1451 * @brief removeSummaryVertex
1452 ** ------------------------------------------------------------------------------------------------------------- */
1453inline void BooleanReassociationPass::removeVertex(const Vertex u, StatementMap & M, Graph & G) const {
1454    VertexData & ref = G[u];
1455    if (std::get<1>(ref)) {
1456        auto f = M.find(std::get<1>(ref));
1457        assert (f != M.end());
1458        M.erase(f);
1459    }
1460    removeVertex(u, G);
1461}
1462
1463/** ------------------------------------------------------------------------------------------------------------- *
1464 * @brief removeSummaryVertex
1465 ** ------------------------------------------------------------------------------------------------------------- */
1466inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1467    VertexData & ref = G[u];
1468    clear_vertex(u, G);
1469    std::get<0>(ref) = TypeId::Var;
1470    std::get<1>(ref) = nullptr;
1471    std::get<2>(ref) = nullptr;
1472}
1473
1474/** ------------------------------------------------------------------------------------------------------------- *
1475 * @brief make
1476 ** ------------------------------------------------------------------------------------------------------------- */
1477inline Z3_ast BooleanReassociationPass::makeVar() const {
1478    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1479//    Z3_inc_ref(mContext, node);
1480    return node;
1481}
1482
1483/** ------------------------------------------------------------------------------------------------------------- *
1484 * @brief add
1485 ** ------------------------------------------------------------------------------------------------------------- */
1486inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1487    assert (expr && node);
1488    mForward.emplace(expr, node);
1489    if (!forwardOnly) {
1490        mBackward.emplace(node, expr);
1491    }
1492    return node;
1493}
1494
1495/** ------------------------------------------------------------------------------------------------------------- *
1496 * @brief get
1497 ** ------------------------------------------------------------------------------------------------------------- */
1498inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1499    assert (expr);
1500    auto f = mForward.find(expr);
1501    if (LLVM_UNLIKELY(f == mForward.end())) {
1502        if (mPredecessor == nullptr) {
1503            return nullptr;
1504        }
1505        return mPredecessor->get(expr);
1506    }
1507    return f->second;
1508}
1509
1510/** ------------------------------------------------------------------------------------------------------------- *
1511 * @brief get
1512 ** ------------------------------------------------------------------------------------------------------------- */
1513inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1514    assert (node);
1515    auto f = mBackward.find(node);
1516    if (LLVM_UNLIKELY(f == mBackward.end())) {
1517        if (mPredecessor == nullptr) {
1518            return nullptr;
1519        }
1520        return mPredecessor->findKey(node);
1521    }
1522    return f->second;
1523}
1524
1525/** ------------------------------------------------------------------------------------------------------------- *
1526 * @brief constructor
1527 ** ------------------------------------------------------------------------------------------------------------- */
1528inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_solver solver, PabloFunction & f)
1529: mContext(ctx)
1530, mSolver(solver)
1531, mFunction(f)
1532, mModified(false) {
1533
1534}
1535
1536}
Note: See TracBrowser for help on using the repository browser.