source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5157

Last change on this file since 5157 was 5157, checked in by nmedfort, 3 years ago

Bug fix for reassociation pass.

File size: 64.5 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17//  noif-dist-mult-dist-50 \p{Cham}(?<!\p{Mc})
18
19using namespace boost;
20using namespace boost::container;
21
22
23static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
24
25static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
26                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
27                                           "Scope Block rather than at the point of first use."),
28                                  cl::cat(ReassociationOptions));
29
30namespace pablo {
31
32using TypeId = PabloAST::ClassTypeId;
33using Graph = BooleanReassociationPass::Graph;
34using Vertex = Graph::vertex_descriptor;
35using VertexData = BooleanReassociationPass::VertexData;
36using DistributionGraph = BooleanReassociationPass::DistributionGraph;
37using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
38using VertexSet = std::vector<Vertex>;
39using VertexSets = std::vector<VertexSet>;
40using Biclique = std::pair<VertexSet, VertexSet>;
41using BicliqueSet = std::vector<Biclique>;
42using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
43using DistributionSets = std::vector<DistributionSet>;
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief helper functions
47 ** ------------------------------------------------------------------------------------------------------------- */
48template<typename Iterator>
49inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
50    assert (range.first != range.second);
51    return *range.first;
52}
53
54static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
55    return stmt->getParent() == block;
56}
57
58static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
59    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
60}
61
62inline TypeId & getType(VertexData & data) {
63    return std::get<0>(data);
64}
65
66inline TypeId getType(const VertexData & data) {
67    return std::get<0>(data);
68}
69
70inline Z3_ast & getDefinition(VertexData & data) {
71    return std::get<2>(data);
72}
73
74inline PabloAST * getValue(const VertexData & data) {
75    return std::get<1>(data);
76}
77
78inline PabloAST *& getValue(VertexData & data) {
79    return std::get<1>(data);
80}
81
82inline bool isAssociative(const VertexData & data) {
83    switch (getType(data)) {
84        case TypeId::And:
85        case TypeId::Or:
86        case TypeId::Xor:
87            return true;
88        default:
89            return false;
90    }
91}
92
93inline bool isDistributive(const VertexData & data) {
94    switch (getType(data)) {
95        case TypeId::And:
96        case TypeId::Or:
97            return true;
98        default:
99            return false;
100    }
101}
102
103void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
104    // Make sure each edge is unique
105    assert (u < num_vertices(G) && v < num_vertices(G));
106    assert (u != v);
107
108    // Just because we've supplied an expr doesn't mean it's useful. Check it.
109    if (expr) {
110        if (isAssociative(G[v])) {
111            expr = nullptr;
112        } else {
113            bool clear = true;
114            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
115                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
116                    if (dest->getOperand(i) == expr) {
117                        clear = false;
118                        break;
119                    }
120                }
121            }
122            if (LLVM_LIKELY(clear)) {
123                expr = nullptr;
124            }
125        }
126    }
127
128    for (auto e : make_iterator_range(out_edges(u, G))) {
129        if (LLVM_UNLIKELY(target(e, G) == v)) {
130            if (expr) {
131                if (G[e] == nullptr) {
132                    G[e] = expr;
133                } else if (G[e] != expr) {
134                    continue;
135                }
136            }
137            return;
138        }
139    }
140    G[boost::add_edge(u, v, G).first] = expr;
141}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief intersects
145 ** ------------------------------------------------------------------------------------------------------------- */
146template <class Type>
147inline bool intersects(const Type & A, const Type & B) {
148    auto first1 = A.begin(), last1 = A.end();
149    auto first2 = B.begin(), last2 = B.end();
150    while (first1 != last1 && first2 != last2) {
151        if (*first1 < *first2) {
152            ++first1;
153        } else if (*first2 < *first1) {
154            ++first2;
155        } else {
156            return true;
157        }
158    }
159    return false;
160}
161
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief printGraph
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void printGraph(const Graph & G, const std::string name) {
167    raw_os_ostream out(std::cerr);
168
169    std::vector<unsigned> c(num_vertices(G));
170    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
171
172    out << "digraph " << name << " {\n";
173    for (auto u : make_iterator_range(vertices(G))) {
174        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
175            continue;
176        }
177        out << "v" << u << " [label=\"" << u << ": ";
178        TypeId typeId;
179        PabloAST * expr;
180        Z3_ast node;
181        std::tie(typeId, expr, node) = G[u];
182        bool temporary = false;
183        bool error = false;
184        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
185            temporary = true;
186            switch (typeId) {
187                case TypeId::And:
188                    out << "And";
189                    break;
190                case TypeId::Or:
191                    out << "Or";
192                    break;
193                case TypeId::Xor:
194                    out << "Xor";
195                    break;
196                case TypeId::Not:
197                    out << "Not";
198                    break;
199                default:
200                    out << "???";
201                    error = true;
202                    break;
203            }
204            if (expr) {
205                out << " ("; PabloPrinter::print(expr, out); out << ")";
206            }
207        } else {
208            PabloPrinter::print(expr, out);
209        }
210        if (node == nullptr) {
211            out << " (*)";
212            error = true;
213        }
214        out << "\"";
215        if (typeId == TypeId::Var) {
216            out << " style=dashed";
217        }
218        if (error) {
219            out << " color=red";
220        } else if (temporary) {
221            out << " color=blue";
222        }
223        out << "];\n";
224    }
225    for (auto e : make_iterator_range(edges(G))) {
226        const auto s = source(e, G);
227        const auto t = target(e, G);
228        out << "v" << s << " -> v" << t;
229        bool cyclic = (c[s] == c[t]);
230        if (G[e] || cyclic) {
231            out << " [";
232             if (G[e]) {
233                out << "label=\"";
234                PabloPrinter::print(G[e], out);
235                out << "\" ";
236             }
237             if (cyclic) {
238                out << "color=red ";
239             }
240             out << "]";
241        }
242        out << ";\n";
243    }
244
245    if (num_vertices(G) > 0) {
246
247        out << "{ rank=same;";
248        for (auto u : make_iterator_range(vertices(G))) {
249            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
250                out << " v" << u;
251            }
252        }
253        out << "}\n";
254
255        out << "{ rank=same;";
256        for (auto u : make_iterator_range(vertices(G))) {
257            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
258                out << " v" << u;
259            }
260        }
261        out << "}\n";
262
263    }
264
265    out << "}\n\n";
266    out.flush();
267}
268
269/** ------------------------------------------------------------------------------------------------------------- *
270 * @brief optimize
271 ** ------------------------------------------------------------------------------------------------------------- */
272bool BooleanReassociationPass::optimize(PabloFunction & function) {
273
274    Z3_config cfg = Z3_mk_config();
275    Z3_context ctx = Z3_mk_context_rc(cfg);
276    Z3_del_config(cfg);
277
278    Z3_params params = Z3_mk_params(ctx);
279    Z3_params_inc_ref(ctx, params);
280    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "pull_cheap_ite"), true);
281    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "local_ctx"), true);
282
283    Z3_tactic ctx_solver_simplify = Z3_mk_tactic(ctx, "ctx-solver-simplify");
284    Z3_tactic_inc_ref(ctx, ctx_solver_simplify);
285
286    BooleanReassociationPass brp(ctx, params, ctx_solver_simplify, function);
287    brp.processScopes(function);
288
289    Z3_params_dec_ref(ctx, params);
290    Z3_tactic_dec_ref(ctx, ctx_solver_simplify);
291    Z3_del_context(ctx);
292
293    PabloVerifier::verify(function, "post-reassociation");
294
295    Simplifier::optimize(function);
296
297    return true;
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief processScopes
302 ** ------------------------------------------------------------------------------------------------------------- */
303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
304    mRefs.clear();
305    CharacterizationMap C;
306    PabloBlock * const entry = function.getEntryBlock();
307    // Map the constants and input variables
308    C.add(entry->createZeroes(), Z3_mk_false(mContext));
309    C.add(entry->createOnes(), Z3_mk_true(mContext));
310    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
311        C.add(mFunction.getParameter(i), makeVar());
312    }
313    mInFile = makeVar();
314    processScopes(entry, C);
315    return mModified;
316}
317
318/** ------------------------------------------------------------------------------------------------------------- *
319 * @brief processScopes
320 ** ------------------------------------------------------------------------------------------------------------- */
321void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & C) {
322    const auto offset = mRefs.size();
323    for (Statement * stmt = block->front(); stmt; ) {
324        if (LLVM_UNLIKELY(isa<If>(stmt))) {
325            if (LLVM_UNLIKELY(isa<Zeroes>(cast<If>(stmt)->getCondition()))) {
326                stmt = stmt->eraseFromParent(true);
327            } else {
328                CharacterizationMap nested(C);
329                processScopes(cast<If>(stmt)->getBody(), nested);
330                for (Assign * def : cast<If>(stmt)->getDefined()) {
331                    C.add(def, makeVar());
332                }
333                stmt = stmt->getNextNode();
334            }
335        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
336            if (LLVM_UNLIKELY(isa<Zeroes>(cast<While>(stmt)->getCondition()))) {
337                stmt = stmt->eraseFromParent(true);
338            } else {
339                CharacterizationMap nested(C);
340                processScopes(cast<While>(stmt)->getBody(), nested);
341                for (Next * var : cast<While>(stmt)->getVariants()) {
342                    C.add(var, makeVar());
343                }
344                stmt = stmt->getNextNode();
345            }
346        } else { // characterize this statement then check whether it is equivalent to any existing one.
347            stmt = characterize(stmt, C);
348        }
349    }   
350    distributeScope(block, C);
351    for (auto i = mRefs.begin() + offset; i != mRefs.end(); ++i) {
352        Z3_dec_ref(mContext, *i);
353    }
354    mRefs.erase(mRefs.begin() + offset, mRefs.end());
355}
356
357/** ------------------------------------------------------------------------------------------------------------- *
358 * @brief characterize
359 ** ------------------------------------------------------------------------------------------------------------- */
360inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & C) {
361    Z3_ast node = nullptr;
362    const size_t n = stmt->getNumOperands(); assert (n > 0);
363    if (isa<Variadic>(stmt)) {
364        Z3_ast operands[n];
365        for (size_t i = 0; i < n; ++i) {
366            operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
367        }
368        if (isa<And>(stmt)) {
369            node = Z3_mk_and(mContext, n, operands);
370        } else if (isa<Or>(stmt)) {
371            node = Z3_mk_or(mContext, n, operands);
372        } else if (isa<Xor>(stmt)) {
373            node = Z3_mk_xor(mContext, operands[0], operands[1]);
374            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
375                Z3_inc_ref(mContext, node);
376                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
377                Z3_inc_ref(mContext, temp);
378                Z3_dec_ref(mContext, node);
379                node = temp;
380            }
381        }
382    } else if (isa<Not>(stmt)) {
383        Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
384        node = Z3_mk_not(mContext, op);
385    } else if (isa<Sel>(stmt)) {
386        Z3_ast operands[3];
387        for (size_t i = 0; i < 3; ++i) {
388            operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
389        }
390        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
391    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
392        assert (stmt->getNumOperands() == 1);
393        Z3_ast check[2];
394        check[0] = C.get(stmt->getOperand(0)); assert (check[0]);
395        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
396        node = Z3_mk_and(mContext, 2, check);
397    } else {
398        if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
399            Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
400            C.add(stmt, op, true);
401        } else {
402            C.add(stmt, makeVar());
403        }
404        return stmt->getNextNode();
405    }
406    Z3_inc_ref(mContext, node);
407    node = simplify(node);
408    PabloAST * const replacement = C.findKey(node);
409    if (LLVM_LIKELY(replacement == nullptr)) {
410        C.add(stmt, node);
411        mRefs.push_back(node);
412        return stmt->getNextNode();
413    } else {
414        Z3_dec_ref(mContext, node);
415        return stmt->replaceWith(replacement);
416    }
417}
418
419/** ------------------------------------------------------------------------------------------------------------- *
420 * @brief processScope
421 ** ------------------------------------------------------------------------------------------------------------- */
422inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & C) {
423    Graph G;
424    try {
425        mBlock = block;
426        transformAST(C, G);
427    } catch (std::runtime_error err) {
428        printGraph(G, "E");
429        throw err;
430    } catch (std::exception err) {
431        printGraph(G, "E");
432        throw err;
433    }
434}
435
436/** ------------------------------------------------------------------------------------------------------------- *
437 * @brief summarizeAST
438 *
439 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
440 * are "flattened" (i.e., allowed to have any number of inputs.)
441 ** ------------------------------------------------------------------------------------------------------------- */
442void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
443
444    StatementMap S;
445
446    // Compute the base def-use graph ...
447    for (Statement * stmt : *mBlock) {
448        const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, S, G, C.get(stmt));
449        for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
450            PabloAST * const op = stmt->getOperand(i);
451            if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
452                add_edge(op, makeVertex(TypeId::Var, op, C, S, G), u, G);
453            }
454        }
455        if (LLVM_UNLIKELY(isa<If>(stmt))) {
456            for (Assign * def : cast<const If>(stmt)->getDefined()) {
457                const Vertex v = makeVertex(TypeId::Var, def, C, S, G);
458                add_edge(def, u, v, G);
459                resolveNestedUsages(def, v, C, S, G, stmt);
460            }
461        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
462            // To keep G a DAG, we need to do a bit of surgery on loop variants because
463            // the next variables it produces can be used within the condition. Instead,
464            // we make the loop dependent on the original value of each Next node and
465            // the Next node dependent on the loop.
466            for (Next * var : cast<const While>(stmt)->getVariants()) {
467                const Vertex v = makeVertex(TypeId::Var, var, C, S, G);
468                assert (in_degree(v, G) == 1);
469                auto e = first(in_edges(v, G));
470                add_edge(G[e], source(e, G), u, G);
471                remove_edge(v, u, G);
472                add_edge(var, u, v, G);
473                resolveNestedUsages(var, v, C, S, G, stmt);
474            }
475        } else {
476            resolveNestedUsages(stmt, u, C, S, G, stmt);
477        }
478    }
479
480//    printGraph(G, "G");
481
482    VertexMap M;
483    if (redistributeGraph(C, M, G)) {
484        factorGraph(G);
485
486//        printGraph(G, "H");
487
488        rewriteAST(C, M, G);
489        mModified = true;
490    }
491
492}
493
494/** ------------------------------------------------------------------------------------------------------------- *
495 * @brief resolveNestedUsages
496 ** ------------------------------------------------------------------------------------------------------------- */
497void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
498                                                   CharacterizationMap & C, StatementMap & S, Graph & G,
499                                                   const Statement * const ignoreIfThis) const {
500    assert ("Cannot resolve nested usages of a null expression!" && expr);
501    for (PabloAST * user : expr->users()) { assert (user);
502        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
503            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
504            if (LLVM_UNLIKELY(parent != mBlock)) {
505                for (;;) {
506                    if (parent->getParent() == mBlock) {
507                        Statement * const branch = parent->getBranch();
508                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
509                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
510                            const Vertex v = makeVertex(TypeId::Var, user, C, S, G);
511                            add_edge(expr, u, v, G);
512                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
513                            add_edge(user, v, w, G);
514                        }
515                        break;
516                    }
517                    parent = parent->getParent();
518                    if (LLVM_UNLIKELY(parent == nullptr)) {
519                        assert (isa<Assign>(expr) || isa<Next>(expr));
520                        break;
521                    }
522                }
523            }
524        }
525    }
526}
527
528
529
530/** ------------------------------------------------------------------------------------------------------------- *
531 * @brief enumerateBicliques
532 *
533 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
534 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
535 * bipartition A and their adjacencies to be in B.
536  ** ------------------------------------------------------------------------------------------------------------- */
537template <class Graph>
538static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
539    using IntersectionSets = std::set<VertexSet>;
540
541    IntersectionSets B1;
542    for (auto u : A) {
543        if (in_degree(u, G) > 0) {
544            VertexSet adjacencies;
545            adjacencies.reserve(in_degree(u, G));
546            for (auto e : make_iterator_range(in_edges(u, G))) {
547                adjacencies.push_back(source(e, G));
548            }
549            std::sort(adjacencies.begin(), adjacencies.end());
550            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
551            B1.insert(std::move(adjacencies));
552        }
553    }
554
555    IntersectionSets B(B1);
556
557    IntersectionSets Bi;
558
559    VertexSet clique;
560    for (auto i = B1.begin(); i != B1.end(); ++i) {
561        for (auto j = i; ++j != B1.end(); ) {
562            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
563            if (clique.size() > 0) {
564                if (B.count(clique) == 0) {
565                    Bi.insert(clique);
566                }
567                clique.clear();
568            }
569        }
570    }
571
572    for (;;) {
573        if (Bi.empty()) {
574            break;
575        }
576        B.insert(Bi.begin(), Bi.end());
577        IntersectionSets Bk;
578        for (auto i = B1.begin(); i != B1.end(); ++i) {
579            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
580                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
581                if (clique.size() > 0) {
582                    if (B.count(clique) == 0) {
583                        Bk.insert(clique);
584                    }
585                    clique.clear();
586                }
587            }
588        }
589        Bi.swap(Bk);
590    }
591
592    BicliqueSet cliques;
593    cliques.reserve(B.size());
594    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
595        VertexSet Ai(A);
596        for (const Vertex u : *Bi) {
597            VertexSet Aj;
598            Aj.reserve(out_degree(u, G));
599            for (auto e : make_iterator_range(out_edges(u, G))) {
600                Aj.push_back(target(e, G));
601            }
602            std::sort(Aj.begin(), Aj.end());
603            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
604            VertexSet Ak;
605            Ak.reserve(std::min(Ai.size(), Aj.size()));
606            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
607            Ai.swap(Ak);
608        }
609        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
610        cliques.emplace_back(std::move(Ai), std::move(*Bi));
611    }
612    return std::move(cliques);
613}
614
615/** ------------------------------------------------------------------------------------------------------------- *
616 * @brief independentCliqueSets
617 ** ------------------------------------------------------------------------------------------------------------- */
618template <unsigned side>
619inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
620    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
621
622    const auto l = cliques.size();
623    IndependentSetGraph I(l);
624
625    // Initialize our weights
626    for (unsigned i = 0; i != l; ++i) {
627        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
628    }
629
630    // Determine our constraints
631    for (unsigned i = 0; i != l; ++i) {
632        for (unsigned j = i + 1; j != l; ++j) {
633            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
634                add_edge(i, j, I);
635            }
636        }
637    }
638
639    // Use the greedy algorithm to choose our independent set
640    VertexSet selected;
641    for (;;) {
642        unsigned w = 0;
643        Vertex u = 0;
644        for (unsigned i = 0; i != l; ++i) {
645            if (I[i] > w) {
646                w = I[i];
647                u = i;
648            }
649        }
650        if (w < minimum) break;
651        selected.push_back(u);
652        I[u] = 0;
653        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
654            I[v] = 0;
655        }
656    }
657
658    // Sort the selected list and then remove the unselected cliques
659    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
660    auto end = cliques.end();
661    for (const unsigned offset : selected) {
662        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
663    }
664    cliques.erase(cliques.begin(), end);
665
666    return std::move(cliques);
667}
668
669/** ------------------------------------------------------------------------------------------------------------- *
670 * @brief removeUnhelpfulBicliques
671 *
672 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
673 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
674 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
675 ** ------------------------------------------------------------------------------------------------------------- */
676static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
677    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
678        const auto cardinalityA = std::get<0>(*ci).size();
679        VertexSet & B = std::get<1>(*ci);
680        for (auto bi = B.begin(); bi != B.end(); ) {
681            if (out_degree(H[*bi], G) == cardinalityA) {
682                ++bi;
683            } else {
684                bi = B.erase(bi);
685            }
686        }
687        if (B.size() > 1) {
688            ++ci;
689        } else {
690            ci = cliques.erase(ci);
691        }
692    }
693    return std::move(cliques);
694}
695
696/** ------------------------------------------------------------------------------------------------------------- *
697 * @brief safeDistributionSets
698 ** ------------------------------------------------------------------------------------------------------------- */
699static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
700
701    VertexSet sinks;
702    for (const Vertex u : make_iterator_range(vertices(H))) {
703        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
704            sinks.push_back(u);
705        }
706    }
707    std::sort(sinks.begin(), sinks.end());
708
709    DistributionSets T;
710    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
711    for (Biclique & lower : lowerSet) {
712        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
713        for (Biclique & upper : upperSet) {
714            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
715        }
716    }
717    return std::move(T);
718}
719
720/** ------------------------------------------------------------------------------------------------------------- *
721 * @brief getVertex
722 ** ------------------------------------------------------------------------------------------------------------- */
723template<typename ValueType, typename GraphType, typename MapType>
724static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
725    const auto f = M.find(value);
726    if (f != M.end()) {
727        return f->second;
728    }
729    const auto u = add_vertex(value, G);
730    M.insert(std::make_pair(value, u));
731    return u;
732}
733
734/** ------------------------------------------------------------------------------------------------------------- *
735 * @brief generateDistributionGraph
736 ** ------------------------------------------------------------------------------------------------------------- */
737void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
738    DistributionMap M;
739    for (const Vertex u : make_iterator_range(vertices(G))) {
740        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
741            continue;
742        } else if (isDistributive(G[u])) {
743            const TypeId outerTypeId = getType(G[u]);
744            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
745            flat_set<Vertex> distributable;
746            for (auto e : make_iterator_range(in_edges(u, G))) {
747                const Vertex v = source(e, G);
748                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
749                    bool safe = true;
750                    for (const auto e : make_iterator_range(out_edges(v, G))) {
751                        if (getType(G[target(e, G)]) != outerTypeId) {
752                            safe = false;
753                            break;
754                        }
755                    }
756                    if (safe) {
757                        distributable.insert(v);
758                    }
759                }
760            }
761            if (LLVM_LIKELY(distributable.size() > 1)) {
762                flat_set<Vertex> observed;
763                for (const Vertex v : distributable) {
764                    for (auto e : make_iterator_range(in_edges(v, G))) {
765                        const auto v = source(e, G);
766                        observed.insert(v);
767                    }
768                }
769                for (const Vertex w : observed) {
770                    for (auto e : make_iterator_range(out_edges(w, G))) {
771                        const Vertex v = target(e, G);
772                        if (distributable.count(v)) {
773                            const Vertex y = getVertex(v, H, M);
774                            boost::add_edge(y, getVertex(u, H, M), H);
775                            boost::add_edge(getVertex(w, H, M), y, H);
776                        }
777                    }
778                }
779            }
780        }
781    }
782}
783
784/** ------------------------------------------------------------------------------------------------------------- *
785 * @brief redistributeAST
786 *
787 * Apply the distribution law to reduce computations whenever possible.
788 ** ------------------------------------------------------------------------------------------------------------- */
789bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
790
791    bool modified = false;
792
793    DistributionGraph H;
794
795    contractGraph(G);
796
797    for (;;) {
798
799        for (;;) {
800
801            generateDistributionGraph(G, H);
802
803            // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
804            if (num_vertices(H) == 0) {
805                break;
806            }
807
808            const DistributionSets distributionSets = safeDistributionSets(G, H);
809
810            if (LLVM_UNLIKELY(distributionSets.empty())) {
811                break;
812            }
813
814            modified = true;
815
816            for (const DistributionSet & set : distributionSets) {
817
818                // Each distribution tuple consists of the sources, intermediary, and sink nodes.
819                const VertexSet & sources = std::get<0>(set);
820                const VertexSet & intermediary = std::get<1>(set);
821                const VertexSet & sinks = std::get<2>(set);
822
823                const TypeId outerTypeId = getType(G[H[sinks.front()]]);
824                assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
825                const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
826
827                const Vertex x = makeVertex(outerTypeId, nullptr, G);
828                const Vertex y = makeVertex(innerTypeId, nullptr, G);
829
830                // Update G to reflect the distributed operations (including removing the subgraph of
831                // the to-be distributed edges.)
832
833                add_edge(nullptr, x, y, G);
834
835                for (const Vertex i : sources) {
836                    const auto u = H[i];
837                    for (const Vertex j : intermediary) {
838                        const auto v = H[j];
839                        const auto e = edge(u, v, G); assert (e.second);
840                        remove_edge(e.first, G);
841                    }
842                    add_edge(nullptr, u, y, G);
843                }
844
845                for (const Vertex i : intermediary) {
846                    const auto u = H[i];
847                    for (const Vertex j : sinks) {
848                        const auto v = H[j];
849                        const auto e = edge(u, v, G); assert (e.second);
850                        add_edge(G[e.first], y, v, G);
851                        remove_edge(e.first, G);
852                    }
853                    add_edge(nullptr, u, x, G);
854                    getDefinition(G[u]) = nullptr;
855                }
856
857            }
858
859            H.clear();
860
861            contractGraph(G);
862        }
863
864        // Although exceptionally unlikely, it's possible that if we can reduce the graph, we could
865        // further simplify it. Restart the process if and only if we succeed.
866        if (reduceGraph(C, M, G)) {
867            if (LLVM_UNLIKELY(contractGraph(G))) {
868                H.clear();
869                continue;
870            }
871        }
872
873        break;
874    }
875
876    return modified;
877}
878
879/** ------------------------------------------------------------------------------------------------------------- *
880 * @brief isNonEscaping
881 ** ------------------------------------------------------------------------------------------------------------- */
882inline bool isNonEscaping(const VertexData & data) {
883    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
884    switch (getType(data)) {
885        case TypeId::Assign:
886        case TypeId::Next:
887        case TypeId::If:
888        case TypeId::While:
889        case TypeId::Count:
890            return false;
891        default:
892            return true;
893    }
894}
895
896/** ------------------------------------------------------------------------------------------------------------- *
897 * @brief unique_source
898 ** ------------------------------------------------------------------------------------------------------------- */
899inline bool has_unique_source(const Vertex u, const Graph & G) {
900    if (in_degree(u, G) > 0) {
901        graph_traits<Graph>::in_edge_iterator i, end;
902        std::tie(i, end) = in_edges(u, G);
903        const Vertex v = source(*i, G);
904        while (++i != end) {
905            if (source(*i, G) != v) {
906                return false;
907            }
908        }
909        return true;
910    }
911    return false;
912}
913
914/** ------------------------------------------------------------------------------------------------------------- *
915 * @brief unique_target
916 ** ------------------------------------------------------------------------------------------------------------- */
917inline bool has_unique_target(const Vertex u, const Graph & G) {
918    if (out_degree(u, G) > 0) {
919        graph_traits<Graph>::out_edge_iterator i, end;
920        std::tie(i, end) = out_edges(u, G);
921        const Vertex v = target(*i, G);
922        while (++i != end) {
923            if (target(*i, G) != v) {
924                return false;
925            }
926        }
927        return true;
928    }
929    return false;
930}
931
932
933/** ------------------------------------------------------------------------------------------------------------- *
934 * @brief contractGraph
935 ** ------------------------------------------------------------------------------------------------------------- */
936bool BooleanReassociationPass::contractGraph(Graph & G) const {
937
938    bool contracted = false;
939
940    circular_buffer<Vertex> ordering(num_vertices(G));
941
942    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
943
944    // first contract the graph
945    for (const Vertex u : ordering) {
946        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
947            continue;
948        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
949            if (isAssociative(G[u])) {
950                if (LLVM_UNLIKELY(has_unique_source(u, G))) {
951                    // We have a redundant node here that'll simply end up being a duplicate
952                    // of the input value. Remove it and add any of its outgoing edges to its
953                    // input node.
954                    const auto ei = first(in_edges(u, G));
955                    const Vertex v = source(ei, G);
956                    for (auto ej : make_iterator_range(out_edges(u, G))) {
957                        add_edge(G[ej], v, target(ej, G), G);
958                    }
959                    removeVertex(u, G);
960                    contracted = true;
961                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
962                    // Otherwise if we have a single user, we have a similar case as above but
963                    // we can only merge this vertex into the outgoing instruction if they are
964                    // of the same type.
965                    const auto ei = first(out_edges(u, G));
966                    const Vertex v = target(ei, G);
967                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
968                        for (auto ej : make_iterator_range(in_edges(u, G))) {
969                            add_edge(G[ej], source(ej, G), v, G);
970                        }
971                        removeVertex(u, G);
972                        contracted = true;
973                    }
974                }
975            }
976        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
977            removeVertex(u, G);
978            contracted = true;
979        }
980    }
981    return contracted;
982}
983
984/** ------------------------------------------------------------------------------------------------------------- *
985 * @brief isReducible
986 ** ------------------------------------------------------------------------------------------------------------- */
987inline bool isReducible(const VertexData & data) {
988    switch (getType(data)) {
989        case TypeId::Var:
990        case TypeId::If:
991        case TypeId::While:
992            return false;
993        default:
994            return true;
995    }
996}
997
998/** ------------------------------------------------------------------------------------------------------------- *
999 * @brief reduceGraph
1000 ** ------------------------------------------------------------------------------------------------------------- */
1001BooleanReassociationPass::Reduction BooleanReassociationPass::reduceVertex(const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G, const bool use_expensive_simplification) {
1002
1003    Reduction reduction = Reduction::NoChange;
1004
1005    assert (isReducible(G[u]));
1006
1007    Z3_ast node = getDefinition(G[u]);
1008    if (isAssociative(G[u])) {
1009        const TypeId typeId = getType(G[u]);
1010        if (node == nullptr) {
1011            const auto n = in_degree(u, G); assert (n > 1);
1012            Z3_ast operands[n];
1013            unsigned i = 0;
1014            for (auto e : make_iterator_range(in_edges(u, G))) {
1015                const Vertex v = source(e, G);
1016                assert (getDefinition(G[v]));
1017                operands[i++] = getDefinition(G[v]);
1018            }
1019            switch (typeId) {
1020                case TypeId::And:
1021                    node = Z3_mk_and(mContext, n, operands);
1022                    break;
1023                case TypeId::Or:
1024                    node = Z3_mk_or(mContext, n, operands);
1025                    break;
1026                case TypeId::Xor:
1027                    node = Z3_mk_xor(mContext, operands[0], operands[1]);
1028                    for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
1029                        node = Z3_mk_xor(mContext, node, operands[i]);
1030                    }
1031                    break;
1032                default: llvm_unreachable("unexpected type id");
1033            }
1034            assert (node);
1035            Z3_inc_ref(mContext, node);
1036            mRefs.push_back(node);
1037            getDefinition(G[u]) = node;
1038        }
1039
1040        graph_traits<Graph>::in_edge_iterator begin, end;
1041restart:if (in_degree(u, G) > 1) {
1042            std::tie(begin, end) = in_edges(u, G);
1043            for (auto i = begin; ++i != end; ) {
1044                const auto v = source(*i, G);
1045                for (auto j = begin; j != i; ++j) {
1046                    const auto w = source(*j, G);
1047                    Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
1048                    Z3_ast test = nullptr;
1049                    switch (typeId) {
1050                        case TypeId::And:
1051                            test = Z3_mk_and(mContext, 2, operands); break;
1052                        case TypeId::Or:
1053                            test = Z3_mk_or(mContext, 2, operands); break;
1054                        case TypeId::Xor:
1055                            test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
1056                        default:
1057                            llvm_unreachable("impossible type id");
1058                    }
1059                    assert (test);
1060                    Z3_inc_ref(mContext, test);
1061                    test = simplify(test, use_expensive_simplification);
1062                    bool replacement = false;
1063                    Vertex x = 0;
1064                    const auto f = M.find(test);
1065                    if (LLVM_UNLIKELY(f != M.end())) {
1066                        x = f->second;
1067                        assert (getDefinition(G[x]) == test);
1068                        Z3_dec_ref(mContext, test);
1069                        replacement = true;
1070                    } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
1071                        PabloAST * const factor = C.predecessor()->findKey(test);
1072                        if (LLVM_UNLIKELY(factor != nullptr)) {
1073                            x = makeVertex(TypeId::Var, factor, G, test);
1074                            M.emplace(test, x);
1075                            replacement = true;
1076                            mRefs.push_back(test);
1077                        }
1078                    }
1079
1080                    if (LLVM_UNLIKELY(replacement)) {
1081
1082                        // note: unless both edges carry an Pablo AST replacement value, they will converge into a single edge.
1083                        PabloAST * const r1 = G[*i];
1084                        PabloAST * const r2 = G[*j];
1085
1086                        remove_edge(*i, G);
1087                        remove_edge(*j, G);
1088
1089                        if (LLVM_UNLIKELY(r1 && r2)) {
1090                            add_edge(r1, x, u, G);
1091                            add_edge(r2, x, u, G);
1092                        } else {
1093                            add_edge(r1 ? r1 : r2, x, u, G);
1094                        }
1095
1096                        reduction = Reduction::Simplified;
1097
1098                        goto restart;
1099                    }
1100
1101                    Z3_dec_ref(mContext, test);
1102                }
1103            }
1104        }
1105    }
1106
1107    if (LLVM_UNLIKELY(node == nullptr)) {
1108        throw std::runtime_error("No Z3 characterization for vertex " + std::to_string(u));
1109    }
1110
1111    auto f = M.find(node);
1112    if (LLVM_LIKELY(f == M.end())) {
1113        M.emplace(node, u);
1114    } else if (isAssociative(G[u])) {
1115        const Vertex v = f->second;
1116        for (auto e : make_iterator_range(out_edges(u, G))) {
1117            add_edge(G[e], v, target(e, G), G);
1118        }
1119        removeVertex(u, G);
1120        reduction = Reduction::Removed;
1121    }
1122
1123    return reduction;
1124}
1125
1126/** ------------------------------------------------------------------------------------------------------------- *
1127 * @brief reduceGraph
1128 ** ------------------------------------------------------------------------------------------------------------- */
1129bool BooleanReassociationPass::reduceGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
1130
1131    bool reduced = false;
1132
1133    circular_buffer<Vertex> ordering(num_vertices(G));
1134
1135    topological_sort(G, std::front_inserter(ordering)); // topological ordering
1136
1137    M.clear();
1138
1139    // first contract the graph
1140    for (const Vertex u : ordering) {
1141        if (isReducible(G[u])) {
1142            if (reduceVertex(u, C, M, G, false) != Reduction::NoChange) {
1143                reduced = true;
1144            }
1145        }
1146    }
1147    return reduced;
1148}
1149
1150/** ------------------------------------------------------------------------------------------------------------- *
1151 * @brief factorGraph
1152 ** ------------------------------------------------------------------------------------------------------------- */
1153bool BooleanReassociationPass::factorGraph(const TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1154
1155    if (LLVM_UNLIKELY(factors.empty())) {
1156        return false;
1157    }
1158
1159    std::vector<Vertex> I, J, K;
1160
1161    bool modified = false;
1162
1163    for (unsigned i = 1; i < factors.size(); ++i) {
1164        assert (getType(G[factors[i]]) == typeId);
1165        for (auto ei : make_iterator_range(in_edges(factors[i], G))) {
1166            I.push_back(source(ei, G));
1167        }
1168        std::sort(I.begin(), I.end());
1169        for (unsigned j = 0; j < i; ++j) {
1170            for (auto ej : make_iterator_range(in_edges(factors[j], G))) {
1171                J.push_back(source(ej, G));
1172            }
1173            std::sort(J.begin(), J.end());
1174            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1175            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1176            assert (std::is_sorted(K.begin(), K.end()));
1177            // if the intersection contains at least two elements
1178            const auto n = K.size();
1179            if (n > 1) {
1180                Vertex a = factors[i];
1181                Vertex b = factors[j];
1182                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1183                    if (in_degree(a, G) != n) {
1184                        assert (in_degree(b, G) == n);
1185                        std::swap(a, b);
1186                    }
1187                    assert (in_degree(a, G) == n);
1188                    if (in_degree(b, G) == n) {
1189                        for (auto e : make_iterator_range(out_edges(b, G))) {
1190                            add_edge(G[e], a, target(e, G), G);
1191                        }
1192                        removeVertex(b, G);
1193                    } else {
1194                        for (auto u : K) {
1195                            remove_edge(u, b, G);
1196                        }
1197                        add_edge(nullptr, a, b, G);
1198                    }
1199                } else {
1200                    Vertex v = makeVertex(typeId, nullptr, G);
1201                    for (auto u : K) {
1202                        remove_edge(u, a, G);
1203                        remove_edge(u, b, G);
1204                        add_edge(nullptr, u, v, G);
1205                    }
1206                    add_edge(nullptr, v, a, G);
1207                    add_edge(nullptr, v, b, G);
1208                    factors.push_back(v);
1209                }
1210                modified = true;
1211            }
1212            K.clear();
1213            J.clear();
1214        }
1215        I.clear();
1216    }
1217    return modified;
1218}
1219
1220/** ------------------------------------------------------------------------------------------------------------- *
1221 * @brief factorGraph
1222 ** ------------------------------------------------------------------------------------------------------------- */
1223bool BooleanReassociationPass::factorGraph(Graph & G) const {
1224    // factor the associative vertices.
1225    std::vector<Vertex> factors;
1226    bool factored = false;
1227    for (unsigned i = 0; i < 3; ++i) {
1228        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1229        for (auto j : make_iterator_range(vertices(G))) {
1230            if (getType(G[j]) == typeId[i]) {
1231                factors.push_back(j);
1232            }
1233        }
1234        if (factorGraph(typeId[i], G, factors)) {
1235            factored = true;
1236        }
1237        factors.clear();
1238    }
1239    return factored;
1240}
1241
1242
1243inline bool isMutable(const Vertex u, const Graph & G) {
1244    return getType(G[u]) != TypeId::Var;
1245}
1246
1247/** ------------------------------------------------------------------------------------------------------------- *
1248 * @brief rewriteAST
1249 ** ------------------------------------------------------------------------------------------------------------- */
1250bool BooleanReassociationPass::rewriteAST(CharacterizationMap & C, VertexMap & M, Graph & G) {
1251
1252    using line_t = long long int;
1253
1254    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1255
1256    Z3_config cfg = Z3_mk_config();
1257    Z3_set_param_value(cfg, "model", "true");
1258    Z3_context ctx = Z3_mk_context(cfg);
1259    Z3_del_config(cfg);
1260    Z3_solver solver = Z3_mk_solver(ctx);
1261    Z3_solver_inc_ref(ctx, solver);
1262
1263    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1264
1265    flat_map<PabloAST *, Z3_ast> V;
1266
1267    // Generate the variables
1268    const auto ty = Z3_mk_int_sort(ctx);
1269    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1270
1271    for (const Vertex u : make_iterator_range(vertices(G))) {
1272        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1273        Z3_ast constraint = nullptr;
1274        if (in_degree(u, G) > 0) {
1275            constraint = Z3_mk_gt(ctx, var, ZERO);
1276            Z3_solver_assert(ctx, solver, constraint);
1277        }       
1278        PabloAST * const expr = getValue(G[u]);
1279        if (inCurrentBlock(expr, mBlock)) {
1280            V.emplace(expr, var);
1281        }
1282        mapping[u] = var;
1283    }
1284
1285    // Add in the dependency constraints
1286    for (const Vertex u : make_iterator_range(vertices(G))) {
1287        Z3_ast const t = mapping[u];
1288        for (auto e : make_iterator_range(in_edges(u, G))) {
1289            Z3_ast const s = mapping[source(e, G)];
1290            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1291        }
1292    }
1293
1294    // Compute the soft ordering constraints
1295    std::vector<Z3_ast> ordering(0);
1296    ordering.reserve(V.size() - 1);
1297
1298    Z3_ast prior = nullptr;
1299    unsigned gap = 1;
1300    for (Statement * stmt : *mBlock) {
1301        auto f = V.find(stmt);
1302        if (f != V.end()) {
1303            Z3_ast const node = f->second;
1304            if (prior) {
1305//                ordering.push_back(Z3_mk_lt(ctx, prior, node)); // increases the cost by 6 - 10x
1306                Z3_ast ops[2] = { node, prior };
1307                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1308            } else {
1309                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1310            }
1311            prior = node;
1312            gap = 0;
1313        }
1314        ++gap;
1315    }
1316
1317    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1318        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1319    }
1320
1321    Z3_model model = Z3_solver_get_model(ctx, solver);
1322    Z3_model_inc_ref(ctx, model);
1323
1324    std::vector<Vertex> S(0);
1325    S.reserve(num_vertices(G));
1326
1327    std::vector<line_t> L(num_vertices(G));
1328
1329
1330
1331    for (const Vertex u : make_iterator_range(vertices(G))) {
1332        line_t line = LoadEarly ? 0 : MAX_INT;
1333        if (isMutable(u, G)) {
1334            Z3_ast value;
1335            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1336                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1337            }
1338            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1339                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1340            }
1341            S.push_back(u);
1342        }
1343        L[u] = line;
1344    }
1345
1346    Z3_model_dec_ref(ctx, model);
1347
1348    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1349
1350    mBlock->setInsertPoint(nullptr);
1351
1352    std::vector<Vertex> T;
1353
1354    line_t count = 1;
1355
1356//    errs() << "--------------------------------------------------\n";
1357
1358//    printGraph(G, "G");
1359
1360    for (auto u : S) {
1361        PabloAST *& stmt = getValue(G[u]);
1362
1363        assert (isMutable(u, G));
1364        assert (L[u] > 0 && L[u] < MAX_INT);
1365
1366        if (isAssociative(G[u])) {
1367
1368            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1369                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1370            }
1371
1372            Statement * ip = mBlock->getInsertPoint();
1373            ip = ip ? ip->getNextNode() : mBlock->front();
1374
1375            const auto typeId = getType(G[u]);
1376
1377// retry:
1378
1379            T.clear();
1380            T.reserve(in_degree(u, G));
1381            for (const auto e : make_iterator_range(in_edges(u, G))) {
1382                T.push_back(source(e, G));
1383            }
1384
1385            // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1386            std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1387
1388            if (LoadEarly) {
1389                mBlock->setInsertPoint(nullptr);
1390            }
1391
1392            PabloAST * expr = nullptr;
1393            PabloAST * join = nullptr;
1394
1395            for (auto v : T) {
1396                expr = getValue(G[v]);
1397                if (LLVM_UNLIKELY(expr == nullptr)) {
1398                    throw std::runtime_error("Vertex " + std::to_string(v) + " does not have an expression!");
1399                }
1400                if (join) {
1401
1402                    if (in_degree(v, G) > 0) {
1403
1404                        assert (L[v] > 0 && L[v] < MAX_INT);
1405
1406                        Statement * dom = cast<Statement>(expr);
1407                        for (;;) {
1408                            PabloBlock * const parent = dom->getParent();
1409                            if (parent == mBlock) {
1410                                break;
1411                            }
1412                            dom = parent->getBranch(); assert(dom);
1413                        }
1414                        mBlock->setInsertPoint(dom);
1415
1416                        assert (dominates(join, expr));
1417                        assert (dominates(expr, ip));
1418                        assert (dominates(dom, ip));
1419                    }
1420
1421                    switch (typeId) {
1422                        case TypeId::And:
1423                            expr = mBlock->createAnd(join, expr); break;
1424                        case TypeId::Or:
1425                            expr = mBlock->createOr(join, expr); break;
1426                        case TypeId::Xor:
1427                            expr = mBlock->createXor(join, expr); break;
1428                        default:
1429                            llvm_unreachable("Invalid TypeId!");
1430                    }
1431
1432//                    // If the insertion point isn't the statement we just attempted to create
1433//                    // we must have unexpectidly reused a prior statement (or Var.)
1434//                    if (LLVM_UNLIKELY(expr != mBlock->getInsertPoint())) {
1435//                        const auto reduction = reduceVertex(u, C, M, G, true);
1436//                        if (LLVM_UNLIKELY(reduction == Reduction::NoChange)) {
1437//                            throw std::runtime_error("Unable to reduce vertex " + std::to_string(u));
1438//                        } else if (LLVM_UNLIKELY(reduction == Reduction::Simplified)) {
1439//                            goto retry;
1440//                        } else { // if (reduction == Reduction::Removed) {
1441//                            mBlock->setInsertPoint(ip->getPrevNode());
1442//                            goto next_statement;
1443//                        }
1444//                    }
1445                }
1446                join = expr;
1447            }
1448
1449            assert (expr);
1450
1451            Statement * const currIP = mBlock->getInsertPoint();
1452
1453            mBlock->setInsertPoint(ip->getPrevNode());
1454
1455            for (auto e : make_iterator_range(out_edges(u, G))) {
1456                if (G[e]) {
1457                    if (PabloAST * user = getValue(G[target(e, G)])) {
1458                        cast<Statement>(user)->replaceUsesOfWith(G[e], expr);
1459                    }
1460                }
1461            }
1462
1463            stmt = expr;
1464
1465            // If the insertion point isn't the statement we just attempted to create,
1466            // we must have unexpectidly reused a prior statement (or var.)
1467            if (LLVM_UNLIKELY(expr != currIP)) {
1468                L[u] = 0;
1469                // If that statement happened to be in the current block, locate which
1470                // statement it "duplicated" (if any) and set the duplicate's line # to
1471                // the line of the original statement to maintain a consistent view of
1472                // the program.
1473                if (inCurrentBlock(expr, mBlock)) {
1474                    for (auto v : make_iterator_range(vertices(G))) {
1475                        if (LLVM_UNLIKELY(getValue(G[v]) == expr)) {
1476                            L[u] = L[v];
1477                            break;
1478                        }
1479                    }
1480                }
1481                continue;
1482            }
1483        }
1484
1485        assert (stmt);
1486
1487        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1488            for (auto e : make_iterator_range(out_edges(u, G))) {
1489                const auto v = target(e, G);
1490                assert (L[v] == std::numeric_limits<line_t>::max());
1491                L[v] = count;
1492            }
1493        }
1494
1495//        PabloPrinter::print(cast<Statement>(stmt), errs()); errs() << "\n";
1496
1497        mBlock->insert(cast<Statement>(stmt));
1498        L[u] = count++; // update the line count with the actual one.
1499//        next_statement: continue;
1500    }
1501
1502    Z3_solver_dec_ref(ctx, solver);
1503    Z3_del_context(ctx);
1504
1505    Statement * const end = mBlock->getInsertPoint(); assert (end);
1506    for (;;) {
1507        Statement * const next = end->getNextNode();
1508        if (next == nullptr) {
1509            break;
1510        }
1511
1512        #ifndef NDEBUG
1513        for (PabloAST * user : next->users()) {
1514            if (isa<Statement>(user) && dominates(user, next)) {
1515                std::string tmp;
1516                raw_string_ostream out(tmp);
1517                out << "Erasing ";
1518                PabloPrinter::print(next, out);
1519                out << " erroneously modifies live statement ";
1520                PabloPrinter::print(cast<Statement>(user), out);
1521                throw std::runtime_error(out.str());
1522            }
1523        }
1524        #endif
1525        next->eraseFromParent(true);
1526    }
1527
1528    #ifndef NDEBUG
1529    PabloVerifier::verify(mFunction, "mid-reassociation");
1530    #endif
1531
1532    return true;
1533}
1534
1535/** ------------------------------------------------------------------------------------------------------------- *
1536 * @brief addSummaryVertex
1537 ** ------------------------------------------------------------------------------------------------------------- */
1538Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1539//    for (Vertex u : make_iterator_range(vertices(G))) {
1540//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1541//            std::get<0>(G[u]) = typeId;
1542//            std::get<1>(G[u]) = expr;
1543//            return u;
1544//        }
1545//    }
1546    return add_vertex(std::make_tuple(typeId, expr, node), G);
1547}
1548
1549/** ------------------------------------------------------------------------------------------------------------- *
1550 * @brief addSummaryVertex
1551 ** ------------------------------------------------------------------------------------------------------------- */
1552Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1553    assert (expr);
1554    const auto f = M.find(expr);
1555    if (f != M.end()) {
1556        assert (getValue(G[f->second]) == expr);
1557        return f->second;
1558    }
1559    const Vertex u = makeVertex(typeId, expr, G, node);
1560    M.emplace(expr, u);
1561    return u;
1562}
1563
1564/** ------------------------------------------------------------------------------------------------------------- *
1565 * @brief addSummaryVertex
1566 ** ------------------------------------------------------------------------------------------------------------- */
1567Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & M, Graph & G) {
1568    assert (expr);
1569    const auto f = M.find(expr);
1570    if (f != M.end()) {
1571        assert (getValue(G[f->second]) == expr);
1572        return f->second;
1573    }
1574    const Vertex u = makeVertex(typeId, expr, G, C.get(expr));
1575    M.emplace(expr, u);
1576    return u;
1577}
1578
1579/** ------------------------------------------------------------------------------------------------------------- *
1580 * @brief removeSummaryVertex
1581 ** ------------------------------------------------------------------------------------------------------------- */
1582inline void BooleanReassociationPass::removeVertex(const Vertex u, StatementMap & M, Graph & G) const {
1583    VertexData & ref = G[u];
1584    if (std::get<1>(ref)) {
1585        auto f = M.find(std::get<1>(ref));
1586        assert (f != M.end());
1587        M.erase(f);
1588    }
1589    removeVertex(u, G);
1590}
1591
1592/** ------------------------------------------------------------------------------------------------------------- *
1593 * @brief removeSummaryVertex
1594 ** ------------------------------------------------------------------------------------------------------------- */
1595inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1596    VertexData & ref = G[u];
1597    clear_vertex(u, G);
1598    std::get<0>(ref) = TypeId::Var;
1599    std::get<1>(ref) = nullptr;
1600    std::get<2>(ref) = nullptr;
1601}
1602
1603/** ------------------------------------------------------------------------------------------------------------- *
1604 * @brief make
1605 ** ------------------------------------------------------------------------------------------------------------- */
1606inline Z3_ast BooleanReassociationPass::makeVar() {
1607    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1608    Z3_inc_ref(mContext, node);
1609    mRefs.push_back(node);
1610    return node;
1611}
1612
1613/** ------------------------------------------------------------------------------------------------------------- *
1614 * @brief simplify
1615 ** ------------------------------------------------------------------------------------------------------------- */
1616Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
1617    assert (node);
1618    Z3_ast result = Z3_simplify_ex(mContext, node, mParams);
1619    Z3_inc_ref(mContext, result);
1620    if (use_expensive_minimization) {
1621        Z3_goal g = Z3_mk_goal(mContext, true, false, false);
1622        Z3_goal_inc_ref(mContext, g);
1623        Z3_goal_assert(mContext, g, result);
1624
1625        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g);
1626        Z3_apply_result_inc_ref(mContext, r);
1627        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
1628
1629        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0);
1630        Z3_goal_inc_ref(mContext, h);
1631        Z3_goal_dec_ref(mContext, g);
1632
1633        const unsigned n = Z3_goal_size(mContext, h);
1634
1635        Z3_ast optimized = nullptr;
1636        if (n == 1) {
1637            optimized = Z3_goal_formula(mContext, h, 0);
1638            Z3_inc_ref(mContext, optimized);
1639
1640        } else if (n > 1) {
1641            Z3_ast operands[n];
1642            for (unsigned i = 0; i < n; ++i) {
1643                operands[i] = Z3_goal_formula(mContext, h, i);
1644                Z3_inc_ref(mContext, operands[i]);
1645            }
1646            optimized = Z3_mk_and(mContext, n, operands);
1647            Z3_inc_ref(mContext, optimized);
1648            for (unsigned i = 0; i < n; ++i) {
1649                Z3_dec_ref(mContext, operands[i]);
1650            }
1651        }
1652        Z3_goal_dec_ref(mContext, h);
1653        Z3_apply_result_dec_ref(mContext, r);
1654        Z3_dec_ref(mContext, result);
1655        result = optimized;
1656    }
1657    Z3_dec_ref(mContext, node);
1658    return result;
1659}
1660
1661/** ------------------------------------------------------------------------------------------------------------- *
1662 * @brief add
1663 ** ------------------------------------------------------------------------------------------------------------- */
1664inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1665    assert (expr && node);
1666    mForward.emplace(expr, node);
1667    if (!forwardOnly) {
1668        mBackward.emplace(node, expr);
1669    }
1670    return node;
1671}
1672
1673/** ------------------------------------------------------------------------------------------------------------- *
1674 * @brief get
1675 ** ------------------------------------------------------------------------------------------------------------- */
1676inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1677    assert (expr);
1678    auto f = mForward.find(expr);
1679    if (LLVM_UNLIKELY(f == mForward.end())) {
1680        if (mPredecessor == nullptr) {
1681            return nullptr;
1682        }
1683        return mPredecessor->get(expr);
1684    }
1685    return f->second;
1686}
1687
1688/** ------------------------------------------------------------------------------------------------------------- *
1689 * @brief get
1690 ** ------------------------------------------------------------------------------------------------------------- */
1691inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1692    assert (node);
1693    auto f = mBackward.find(node);
1694    if (LLVM_UNLIKELY(f == mBackward.end())) {
1695        if (mPredecessor == nullptr) {
1696            return nullptr;
1697        }
1698        return mPredecessor->findKey(node);
1699    }
1700    return f->second;
1701}
1702
1703/** ------------------------------------------------------------------------------------------------------------- *
1704 * @brief constructor
1705 ** ------------------------------------------------------------------------------------------------------------- */
1706inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_params params, Z3_tactic tactic, PabloFunction & f)
1707: mContext(ctx)
1708, mParams(params)
1709, mTactic(tactic)
1710, mFunction(f)
1711, mModified(false) {
1712
1713}
1714
1715}
Note: See TracBrowser for help on using the repository browser.