source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5156

Last change on this file since 5156 was 5156, checked in by nmedfort, 3 years ago

Work on multiplexing and distribution passes + a few AST modification bug fixes.

File size: 62.4 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17//  noif-dist-mult-dist-50 \p{Cham}(?<!\p{Mc})
18
19using namespace boost;
20using namespace boost::container;
21
22
23static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
24
25static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
26                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
27                                           "Scope Block rather than at the point of first use."),
28                                  cl::cat(ReassociationOptions));
29
30namespace pablo {
31
32using TypeId = PabloAST::ClassTypeId;
33using Graph = BooleanReassociationPass::Graph;
34using Vertex = Graph::vertex_descriptor;
35using VertexData = BooleanReassociationPass::VertexData;
36using DistributionGraph = BooleanReassociationPass::DistributionGraph;
37using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
38using VertexSet = std::vector<Vertex>;
39using VertexSets = std::vector<VertexSet>;
40using Biclique = std::pair<VertexSet, VertexSet>;
41using BicliqueSet = std::vector<Biclique>;
42using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
43using DistributionSets = std::vector<DistributionSet>;
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief helper functions
47 ** ------------------------------------------------------------------------------------------------------------- */
48template<typename Iterator>
49inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
50    assert (range.first != range.second);
51    return *range.first;
52}
53
54static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
55    return stmt->getParent() == block;
56}
57
58static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
59    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
60}
61
62inline TypeId & getType(VertexData & data) {
63    return std::get<0>(data);
64}
65
66inline TypeId getType(const VertexData & data) {
67    return std::get<0>(data);
68}
69
70inline Z3_ast & getDefinition(VertexData & data) {
71    return std::get<2>(data);
72}
73
74inline PabloAST * getValue(const VertexData & data) {
75    return std::get<1>(data);
76}
77
78inline PabloAST *& getValue(VertexData & data) {
79    return std::get<1>(data);
80}
81
82inline bool isAssociative(const VertexData & data) {
83    switch (getType(data)) {
84        case TypeId::And:
85        case TypeId::Or:
86        case TypeId::Xor:
87            return true;
88        default:
89            return false;
90    }
91}
92
93inline bool isDistributive(const VertexData & data) {
94    switch (getType(data)) {
95        case TypeId::And:
96        case TypeId::Or:
97            return true;
98        default:
99            return false;
100    }
101}
102
103void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
104    // Make sure each edge is unique
105    assert (u < num_vertices(G) && v < num_vertices(G));
106    assert (u != v);
107
108    // Just because we've supplied an expr doesn't mean it's useful. Check it.
109    if (expr) {
110        if (isAssociative(G[v])) {
111            expr = nullptr;
112        } else {
113            bool clear = true;
114            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
115                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
116                    if (dest->getOperand(i) == expr) {
117                        clear = false;
118                        break;
119                    }
120                }
121            }
122            if (LLVM_LIKELY(clear)) {
123                expr = nullptr;
124            }
125        }
126    }
127
128    for (auto e : make_iterator_range(out_edges(u, G))) {
129        if (LLVM_UNLIKELY(target(e, G) == v)) {
130            if (expr) {
131                if (G[e] == nullptr) {
132                    G[e] = expr;
133                } else if (G[e] != expr) {
134                    continue;
135                }
136            }
137            return;
138        }
139    }
140    G[boost::add_edge(u, v, G).first] = expr;
141}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief intersects
145 ** ------------------------------------------------------------------------------------------------------------- */
146template <class Type>
147inline bool intersects(const Type & A, const Type & B) {
148    auto first1 = A.begin(), last1 = A.end();
149    auto first2 = B.begin(), last2 = B.end();
150    while (first1 != last1 && first2 != last2) {
151        if (*first1 < *first2) {
152            ++first1;
153        } else if (*first2 < *first1) {
154            ++first2;
155        } else {
156            return true;
157        }
158    }
159    return false;
160}
161
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief printGraph
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void printGraph(const Graph & G, const std::string name) {
167    raw_os_ostream out(std::cerr);
168
169    std::vector<unsigned> c(num_vertices(G));
170    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
171
172    out << "digraph " << name << " {\n";
173    for (auto u : make_iterator_range(vertices(G))) {
174        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
175            continue;
176        }
177        out << "v" << u << " [label=\"" << u << ": ";
178        TypeId typeId;
179        PabloAST * expr;
180        Z3_ast node;
181        std::tie(typeId, expr, node) = G[u];
182        bool temporary = false;
183        bool error = false;
184        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
185            temporary = true;
186            switch (typeId) {
187                case TypeId::And:
188                    out << "And";
189                    break;
190                case TypeId::Or:
191                    out << "Or";
192                    break;
193                case TypeId::Xor:
194                    out << "Xor";
195                    break;
196                case TypeId::Not:
197                    out << "Not";
198                    break;
199                default:
200                    out << "???";
201                    error = true;
202                    break;
203            }
204            if (expr) {
205                out << " ("; PabloPrinter::print(expr, out); out << ")";
206            }
207        } else {
208            PabloPrinter::print(expr, out);
209        }
210        if (node == nullptr) {
211            out << " (*)";
212            error = true;
213        }
214        out << "\"";
215        if (typeId == TypeId::Var) {
216            out << " style=dashed";
217        }
218        if (error) {
219            out << " color=red";
220        } else if (temporary) {
221            out << " color=blue";
222        }
223        out << "];\n";
224    }
225    for (auto e : make_iterator_range(edges(G))) {
226        const auto s = source(e, G);
227        const auto t = target(e, G);
228        out << "v" << s << " -> v" << t;
229        bool cyclic = (c[s] == c[t]);
230        if (G[e] || cyclic) {
231            out << " [";
232             if (G[e]) {
233                out << "label=\"";
234                PabloPrinter::print(G[e], out);
235                out << "\" ";
236             }
237             if (cyclic) {
238                out << "color=red ";
239             }
240             out << "]";
241        }
242        out << ";\n";
243    }
244
245    if (num_vertices(G) > 0) {
246
247        out << "{ rank=same;";
248        for (auto u : make_iterator_range(vertices(G))) {
249            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
250                out << " v" << u;
251            }
252        }
253        out << "}\n";
254
255        out << "{ rank=same;";
256        for (auto u : make_iterator_range(vertices(G))) {
257            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
258                out << " v" << u;
259            }
260        }
261        out << "}\n";
262
263    }
264
265    out << "}\n\n";
266    out.flush();
267}
268
269/** ------------------------------------------------------------------------------------------------------------- *
270 * @brief optimize
271 ** ------------------------------------------------------------------------------------------------------------- */
272bool BooleanReassociationPass::optimize(PabloFunction & function) {
273
274    Z3_config cfg = Z3_mk_config();
275    Z3_context ctx = Z3_mk_context(cfg);
276    Z3_del_config(cfg);
277
278    Z3_params params = Z3_mk_params(ctx);
279    Z3_params_inc_ref(ctx, params);
280    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "pull_cheap_ite"), true);
281    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "local_ctx"), true);
282
283    Z3_tactic ctx_solver_simplify = Z3_mk_tactic(ctx, "ctx-solver-simplify");
284    Z3_tactic_inc_ref(ctx, ctx_solver_simplify);
285
286    BooleanReassociationPass brp(ctx, params, ctx_solver_simplify, function);
287    brp.processScopes(function);
288
289    Z3_params_dec_ref(ctx, params);
290    Z3_tactic_dec_ref(ctx, ctx_solver_simplify);
291    Z3_del_context(ctx);
292
293    PabloVerifier::verify(function, "post-reassociation");
294
295    Simplifier::optimize(function);
296
297    return true;
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief processScopes
302 ** ------------------------------------------------------------------------------------------------------------- */
303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
304    PabloBlock * const entry = function.getEntryBlock();
305    CharacterizationMap map;
306    // Map the constants and input variables
307    map.add(entry->createZeroes(), Z3_mk_false(mContext));
308    map.add(entry->createOnes(), Z3_mk_true(mContext));
309    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
310        map.add(mFunction.getParameter(i), makeVar());
311    }
312    mInFile = makeVar();
313    processScopes(entry, map);
314    return mModified;
315}
316
317/** ------------------------------------------------------------------------------------------------------------- *
318 * @brief processScopes
319 ** ------------------------------------------------------------------------------------------------------------- */
320void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & map) {
321    for (Statement * stmt = block->front(); stmt; ) {
322        if (LLVM_UNLIKELY(isa<If>(stmt))) {
323            if (LLVM_UNLIKELY(isa<Zeroes>(cast<If>(stmt)->getCondition()))) {
324                stmt = stmt->eraseFromParent(true);
325            } else {
326                CharacterizationMap nested(map);
327                processScopes(cast<If>(stmt)->getBody(), nested);
328                for (Assign * def : cast<If>(stmt)->getDefined()) {
329                    map.add(def, makeVar());
330                }
331                stmt = stmt->getNextNode();
332            }
333        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
334            if (LLVM_UNLIKELY(isa<Zeroes>(cast<While>(stmt)->getCondition()))) {
335                stmt = stmt->eraseFromParent(true);
336            } else {
337                CharacterizationMap nested(map);
338                processScopes(cast<While>(stmt)->getBody(), nested);
339                for (Next * var : cast<While>(stmt)->getVariants()) {
340                    map.add(var, makeVar());
341                }
342                stmt = stmt->getNextNode();
343            }
344        } else { // characterize this statement then check whether it is equivalent to any existing one.
345            stmt = characterize(stmt, map);
346        }
347    }   
348    distributeScope(block, map);
349}
350
351/** ------------------------------------------------------------------------------------------------------------- *
352 * @brief characterize
353 ** ------------------------------------------------------------------------------------------------------------- */
354inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & map) {
355    Z3_ast node = nullptr;
356    const size_t n = stmt->getNumOperands(); assert (n > 0);
357    if (isa<Variadic>(stmt)) {
358        Z3_ast operands[n];
359        for (size_t i = 0; i < n; ++i) {
360            operands[i] = map.get(stmt->getOperand(i)); assert (operands[i]);
361        }
362        if (isa<And>(stmt)) {
363            node = Z3_mk_and(mContext, n, operands);
364        } else if (isa<Or>(stmt)) {
365            node = Z3_mk_or(mContext, n, operands);
366        } else if (isa<Xor>(stmt)) {
367            node = Z3_mk_xor(mContext, operands[0], operands[1]);
368            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
369                node = Z3_mk_xor(mContext, node, operands[i]);
370            }
371        }
372    } else if (isa<Not>(stmt)) {
373        Z3_ast op = map.get(stmt->getOperand(0)); assert (op);
374        node = Z3_mk_not(mContext, op);
375    } else if (isa<Sel>(stmt)) {
376        Z3_ast operands[3];
377        for (size_t i = 0; i < 3; ++i) {
378            operands[i] = map.get(stmt->getOperand(i)); assert (operands[i]);
379        }
380        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
381    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
382        assert (stmt->getNumOperands() == 1);
383        Z3_ast check[2];
384        check[0] = map.get(stmt->getOperand(0)); assert (check[0]);
385        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
386        node = Z3_mk_and(mContext, 2, check);
387    } else {
388        if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
389            Z3_ast op = map.get(stmt->getOperand(0)); assert (op);
390            map.add(stmt, op, true);
391        } else {
392            map.add(stmt, makeVar());
393        }
394        return stmt->getNextNode();
395    }
396    node = simplify(node); assert (node);
397    PabloAST * const replacement = map.findKey(node);
398    if (LLVM_LIKELY(replacement == nullptr)) {
399        map.add(stmt, node);
400        return stmt->getNextNode();
401    } else {
402        return stmt->replaceWith(replacement);
403    }
404
405}
406
407/** ------------------------------------------------------------------------------------------------------------- *
408 * @brief processScope
409 ** ------------------------------------------------------------------------------------------------------------- */
410inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & C) {
411    Graph G;
412    try {
413        mBlock = block;
414        transformAST(C, G);
415    } catch (std::runtime_error err) {
416        printGraph(G, "E");
417        throw err;
418    } catch (std::exception err) {
419        printGraph(G, "E");
420        throw err;
421    }
422}
423
424/** ------------------------------------------------------------------------------------------------------------- *
425 * @brief summarizeAST
426 *
427 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
428 * are "flattened" (i.e., allowed to have any number of inputs.)
429 ** ------------------------------------------------------------------------------------------------------------- */
430void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
431
432    StatementMap S;
433
434    // Compute the base def-use graph ...
435    for (Statement * stmt : *mBlock) {
436
437        const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, S, G, C.get(stmt));
438
439        for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
440            PabloAST * const op = stmt->getOperand(i);
441            if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
442                add_edge(op, makeVertex(TypeId::Var, op, C, S, G), u, G);
443            }
444        }
445
446        if (LLVM_UNLIKELY(isa<If>(stmt))) {
447            for (Assign * def : cast<const If>(stmt)->getDefined()) {
448                const Vertex v = makeVertex(TypeId::Var, def, C, S, G);
449                add_edge(def, u, v, G);
450                resolveNestedUsages(def, v, C, S, G, stmt);
451            }
452        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
453            // To keep G a DAG, we need to do a bit of surgery on loop variants because
454            // the next variables it produces can be used within the condition. Instead,
455            // we make the loop dependent on the original value of each Next node and
456            // the Next node dependent on the loop.
457            for (Next * var : cast<const While>(stmt)->getVariants()) {
458                const Vertex v = makeVertex(TypeId::Var, var, C, S, G);
459                assert (in_degree(v, G) == 1);
460                auto e = first(in_edges(v, G));
461                add_edge(G[e], source(e, G), u, G);
462                remove_edge(v, u, G);
463                add_edge(var, u, v, G);
464                resolveNestedUsages(var, v, C, S, G, stmt);
465            }
466        } else {
467            resolveNestedUsages(stmt, u, C, S, G, stmt);
468        }
469    }
470
471//    printGraph(G, "G");
472
473    VertexMap M;
474    if (redistributeGraph(C, M, G)) {
475        factorGraph(G);
476
477//        printGraph(G, "H");
478
479        rewriteAST(C, M, G);
480        mModified = true;
481    }
482
483}
484
485/** ------------------------------------------------------------------------------------------------------------- *
486 * @brief resolveNestedUsages
487 ** ------------------------------------------------------------------------------------------------------------- */
488void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
489                                                   CharacterizationMap & C, StatementMap & S, Graph & G,
490                                                   const Statement * const ignoreIfThis) const {
491    assert ("Cannot resolve nested usages of a null expression!" && expr);
492    for (PabloAST * user : expr->users()) { assert (user);
493        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
494            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
495            if (LLVM_UNLIKELY(parent != mBlock)) {
496                for (;;) {
497                    if (parent->getParent() == mBlock) {
498                        Statement * const branch = parent->getBranch();
499                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
500                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
501                            const Vertex v = makeVertex(TypeId::Var, user, C, S, G);
502                            add_edge(expr, u, v, G);
503                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
504                            add_edge(user, v, w, G);
505                        }
506                        break;
507                    }
508                    parent = parent->getParent();
509                    if (LLVM_UNLIKELY(parent == nullptr)) {
510                        assert (isa<Assign>(expr) || isa<Next>(expr));
511                        break;
512                    }
513                }
514            }
515        }
516    }
517}
518
519
520
521/** ------------------------------------------------------------------------------------------------------------- *
522 * @brief enumerateBicliques
523 *
524 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
525 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
526 * bipartition A and their adjacencies to be in B.
527  ** ------------------------------------------------------------------------------------------------------------- */
528template <class Graph>
529static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
530    using IntersectionSets = std::set<VertexSet>;
531
532    IntersectionSets B1;
533    for (auto u : A) {
534        if (in_degree(u, G) > 0) {
535            VertexSet adjacencies;
536            adjacencies.reserve(in_degree(u, G));
537            for (auto e : make_iterator_range(in_edges(u, G))) {
538                adjacencies.push_back(source(e, G));
539            }
540            std::sort(adjacencies.begin(), adjacencies.end());
541            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
542            B1.insert(std::move(adjacencies));
543        }
544    }
545
546    IntersectionSets B(B1);
547
548    IntersectionSets Bi;
549
550    VertexSet clique;
551    for (auto i = B1.begin(); i != B1.end(); ++i) {
552        for (auto j = i; ++j != B1.end(); ) {
553            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
554            if (clique.size() > 0) {
555                if (B.count(clique) == 0) {
556                    Bi.insert(clique);
557                }
558                clique.clear();
559            }
560        }
561    }
562
563    for (;;) {
564        if (Bi.empty()) {
565            break;
566        }
567        B.insert(Bi.begin(), Bi.end());
568        IntersectionSets Bk;
569        for (auto i = B1.begin(); i != B1.end(); ++i) {
570            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
571                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
572                if (clique.size() > 0) {
573                    if (B.count(clique) == 0) {
574                        Bk.insert(clique);
575                    }
576                    clique.clear();
577                }
578            }
579        }
580        Bi.swap(Bk);
581    }
582
583    BicliqueSet cliques;
584    cliques.reserve(B.size());
585    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
586        VertexSet Ai(A);
587        for (const Vertex u : *Bi) {
588            VertexSet Aj;
589            Aj.reserve(out_degree(u, G));
590            for (auto e : make_iterator_range(out_edges(u, G))) {
591                Aj.push_back(target(e, G));
592            }
593            std::sort(Aj.begin(), Aj.end());
594            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
595            VertexSet Ak;
596            Ak.reserve(std::min(Ai.size(), Aj.size()));
597            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
598            Ai.swap(Ak);
599        }
600        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
601        cliques.emplace_back(std::move(Ai), std::move(*Bi));
602    }
603    return std::move(cliques);
604}
605
606/** ------------------------------------------------------------------------------------------------------------- *
607 * @brief independentCliqueSets
608 ** ------------------------------------------------------------------------------------------------------------- */
609template <unsigned side>
610inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
611    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
612
613    const auto l = cliques.size();
614    IndependentSetGraph I(l);
615
616    // Initialize our weights
617    for (unsigned i = 0; i != l; ++i) {
618        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
619    }
620
621    // Determine our constraints
622    for (unsigned i = 0; i != l; ++i) {
623        for (unsigned j = i + 1; j != l; ++j) {
624            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
625                add_edge(i, j, I);
626            }
627        }
628    }
629
630    // Use the greedy algorithm to choose our independent set
631    VertexSet selected;
632    for (;;) {
633        unsigned w = 0;
634        Vertex u = 0;
635        for (unsigned i = 0; i != l; ++i) {
636            if (I[i] > w) {
637                w = I[i];
638                u = i;
639            }
640        }
641        if (w < minimum) break;
642        selected.push_back(u);
643        I[u] = 0;
644        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
645            I[v] = 0;
646        }
647    }
648
649    // Sort the selected list and then remove the unselected cliques
650    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
651    auto end = cliques.end();
652    for (const unsigned offset : selected) {
653        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
654    }
655    cliques.erase(cliques.begin(), end);
656
657    return std::move(cliques);
658}
659
660/** ------------------------------------------------------------------------------------------------------------- *
661 * @brief removeUnhelpfulBicliques
662 *
663 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
664 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
665 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
666 ** ------------------------------------------------------------------------------------------------------------- */
667static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
668    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
669        const auto cardinalityA = std::get<0>(*ci).size();
670        VertexSet & B = std::get<1>(*ci);
671        for (auto bi = B.begin(); bi != B.end(); ) {
672            if (out_degree(H[*bi], G) == cardinalityA) {
673                ++bi;
674            } else {
675                bi = B.erase(bi);
676            }
677        }
678        if (B.size() > 1) {
679            ++ci;
680        } else {
681            ci = cliques.erase(ci);
682        }
683    }
684    return std::move(cliques);
685}
686
687/** ------------------------------------------------------------------------------------------------------------- *
688 * @brief safeDistributionSets
689 ** ------------------------------------------------------------------------------------------------------------- */
690static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
691
692    VertexSet sinks;
693    for (const Vertex u : make_iterator_range(vertices(H))) {
694        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
695            sinks.push_back(u);
696        }
697    }
698    std::sort(sinks.begin(), sinks.end());
699
700    DistributionSets T;
701    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
702    for (Biclique & lower : lowerSet) {
703        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
704        for (Biclique & upper : upperSet) {
705            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
706        }
707    }
708    return std::move(T);
709}
710
711/** ------------------------------------------------------------------------------------------------------------- *
712 * @brief getVertex
713 ** ------------------------------------------------------------------------------------------------------------- */
714template<typename ValueType, typename GraphType, typename MapType>
715static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
716    const auto f = M.find(value);
717    if (f != M.end()) {
718        return f->second;
719    }
720    const auto u = add_vertex(value, G);
721    M.insert(std::make_pair(value, u));
722    return u;
723}
724
725/** ------------------------------------------------------------------------------------------------------------- *
726 * @brief generateDistributionGraph
727 ** ------------------------------------------------------------------------------------------------------------- */
728void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
729    DistributionMap M;
730    for (const Vertex u : make_iterator_range(vertices(G))) {
731        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
732            continue;
733        } else if (isDistributive(G[u])) {
734            const TypeId outerTypeId = getType(G[u]);
735            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
736            flat_set<Vertex> distributable;
737            for (auto e : make_iterator_range(in_edges(u, G))) {
738                const Vertex v = source(e, G);
739                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
740                    bool safe = true;
741                    for (const auto e : make_iterator_range(out_edges(v, G))) {
742                        if (getType(G[target(e, G)]) != outerTypeId) {
743                            safe = false;
744                            break;
745                        }
746                    }
747                    if (safe) {
748                        distributable.insert(v);
749                    }
750                }
751            }
752            if (LLVM_LIKELY(distributable.size() > 1)) {
753                flat_set<Vertex> observed;
754                for (const Vertex v : distributable) {
755                    for (auto e : make_iterator_range(in_edges(v, G))) {
756                        const auto v = source(e, G);
757                        observed.insert(v);
758                    }
759                }
760                for (const Vertex w : observed) {
761                    for (auto e : make_iterator_range(out_edges(w, G))) {
762                        const Vertex v = target(e, G);
763                        if (distributable.count(v)) {
764                            const Vertex y = getVertex(v, H, M);
765                            boost::add_edge(y, getVertex(u, H, M), H);
766                            boost::add_edge(getVertex(w, H, M), y, H);
767                        }
768                    }
769                }
770            }
771        }
772    }
773}
774
775/** ------------------------------------------------------------------------------------------------------------- *
776 * @brief redistributeAST
777 *
778 * Apply the distribution law to reduce computations whenever possible.
779 ** ------------------------------------------------------------------------------------------------------------- */
780bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, VertexMap & M, Graph & G) const {
781
782    bool modified = false;
783
784    DistributionGraph H;
785
786    contractGraph(G);
787
788    for (;;) {
789
790        for (;;) {
791
792            generateDistributionGraph(G, H);
793
794            // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
795            if (num_vertices(H) == 0) {
796                break;
797            }
798
799            const DistributionSets distributionSets = safeDistributionSets(G, H);
800
801            if (LLVM_UNLIKELY(distributionSets.empty())) {
802                break;
803            }
804
805            modified = true;
806
807            for (const DistributionSet & set : distributionSets) {
808
809                // Each distribution tuple consists of the sources, intermediary, and sink nodes.
810                const VertexSet & sources = std::get<0>(set);
811                const VertexSet & intermediary = std::get<1>(set);
812                const VertexSet & sinks = std::get<2>(set);
813
814                const TypeId outerTypeId = getType(G[H[sinks.front()]]);
815                assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
816                const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
817
818                const Vertex x = makeVertex(outerTypeId, nullptr, G);
819                const Vertex y = makeVertex(innerTypeId, nullptr, G);
820
821                // Update G to reflect the distributed operations (including removing the subgraph of
822                // the to-be distributed edges.)
823
824                add_edge(nullptr, x, y, G);
825
826                for (const Vertex i : sources) {
827                    const auto u = H[i];
828                    for (const Vertex j : intermediary) {
829                        const auto v = H[j];
830                        const auto e = edge(u, v, G); assert (e.second);
831                        remove_edge(e.first, G);
832                    }
833                    add_edge(nullptr, u, y, G);
834                }
835
836                for (const Vertex i : intermediary) {
837                    const auto u = H[i];
838                    for (const Vertex j : sinks) {
839                        const auto v = H[j];
840                        const auto e = edge(u, v, G); assert (e.second);
841                        add_edge(G[e.first], y, v, G);
842                        remove_edge(e.first, G);
843                    }
844                    add_edge(nullptr, u, x, G);
845                    getDefinition(G[u]) = nullptr;
846                }
847
848            }
849
850            H.clear();
851
852            contractGraph(G);
853        }
854
855        // Although exceptionally unlikely, it's possible that if we can reduce the graph, we could
856        // further simplify it. Restart the process if and only if we succeed.
857        if (reduceGraph(C, M, G)) {
858            if (LLVM_UNLIKELY(contractGraph(G))) {
859                H.clear();
860                continue;
861            }
862        }
863
864        break;
865    }
866
867    return modified;
868}
869
870/** ------------------------------------------------------------------------------------------------------------- *
871 * @brief isNonEscaping
872 ** ------------------------------------------------------------------------------------------------------------- */
873inline bool isNonEscaping(const VertexData & data) {
874    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
875    switch (getType(data)) {
876        case TypeId::Assign:
877        case TypeId::Next:
878        case TypeId::If:
879        case TypeId::While:
880        case TypeId::Count:
881            return false;
882        default:
883            return true;
884    }
885}
886
887/** ------------------------------------------------------------------------------------------------------------- *
888 * @brief unique_source
889 ** ------------------------------------------------------------------------------------------------------------- */
890inline bool has_unique_source(const Vertex u, const Graph & G) {
891    if (in_degree(u, G) > 0) {
892        graph_traits<Graph>::in_edge_iterator i, end;
893        std::tie(i, end) = in_edges(u, G);
894        const Vertex v = source(*i, G);
895        while (++i != end) {
896            if (source(*i, G) != v) {
897                return false;
898            }
899        }
900        return true;
901    }
902    return false;
903}
904
905/** ------------------------------------------------------------------------------------------------------------- *
906 * @brief unique_target
907 ** ------------------------------------------------------------------------------------------------------------- */
908inline bool has_unique_target(const Vertex u, const Graph & G) {
909    if (out_degree(u, G) > 0) {
910        graph_traits<Graph>::out_edge_iterator i, end;
911        std::tie(i, end) = out_edges(u, G);
912        const Vertex v = target(*i, G);
913        while (++i != end) {
914            if (target(*i, G) != v) {
915                return false;
916            }
917        }
918        return true;
919    }
920    return false;
921}
922
923
924/** ------------------------------------------------------------------------------------------------------------- *
925 * @brief contractGraph
926 ** ------------------------------------------------------------------------------------------------------------- */
927bool BooleanReassociationPass::contractGraph(Graph & G) const {
928
929    bool contracted = false;
930
931    circular_buffer<Vertex> ordering(num_vertices(G));
932
933    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
934
935    // first contract the graph
936    for (const Vertex u : ordering) {
937        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
938            continue;
939        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
940            if (isAssociative(G[u])) {
941                if (LLVM_UNLIKELY(has_unique_source(u, G))) {
942                    // We have a redundant node here that'll simply end up being a duplicate
943                    // of the input value. Remove it and add any of its outgoing edges to its
944                    // input node.
945                    const auto ei = first(in_edges(u, G));
946                    const Vertex v = source(ei, G);
947                    for (auto ej : make_iterator_range(out_edges(u, G))) {
948                        add_edge(G[ej], v, target(ej, G), G);
949                    }
950//                    if (LLVM_UNLIKELY(getValue(G[v]) == nullptr && getValue(G[u]) != nullptr)) {
951//                        getValue(G[v]) = getValue(G[u]);
952//                    }
953                    removeVertex(u, G);
954                    contracted = true;
955                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
956                    // Otherwise if we have a single user, we have a similar case as above but
957                    // we can only merge this vertex into the outgoing instruction if they are
958                    // of the same type.
959                    const auto ei = first(out_edges(u, G));
960                    const Vertex v = target(ei, G);
961                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
962                        for (auto ej : make_iterator_range(in_edges(u, G))) {
963                            add_edge(G[ej], source(ej, G), v, G);
964                        }
965                        if (LLVM_UNLIKELY(getValue(G[v]) == nullptr && getValue(G[u]) != nullptr)) {
966                            getValue(G[v]) = getValue(G[u]);
967                        }
968                        removeVertex(u, G);
969                        contracted = true;
970                    }
971                }
972            }
973        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
974            removeVertex(u, G);
975            contracted = true;
976        }
977    }
978    return contracted;
979}
980
981/** ------------------------------------------------------------------------------------------------------------- *
982 * @brief isReducible
983 ** ------------------------------------------------------------------------------------------------------------- */
984inline bool isReducible(const VertexData & data) {
985    switch (getType(data)) {
986        case TypeId::Var:
987        case TypeId::If:
988        case TypeId::While:
989            return false;
990        default:
991            return true;
992    }
993}
994
995/** ------------------------------------------------------------------------------------------------------------- *
996 * @brief reduceGraph
997 ** ------------------------------------------------------------------------------------------------------------- */
998bool BooleanReassociationPass::reduceVertex(const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G, const bool use_expensive_simplification) const {
999
1000    bool reduced = false;
1001
1002    assert (isReducible(G[u]));
1003
1004    Z3_ast node = getDefinition(G[u]);
1005    if (isAssociative(G[u])) {
1006        const TypeId typeId = getType(G[u]);
1007        if (node == nullptr) {
1008            const auto n = in_degree(u, G); assert (n > 1);
1009            Z3_ast operands[n];
1010            unsigned i = 0;
1011            for (auto e : make_iterator_range(in_edges(u, G))) {
1012                const Vertex v = source(e, G);
1013                assert (getDefinition(G[v]));
1014                operands[i++] = getDefinition(G[v]);
1015            }
1016            switch (typeId) {
1017                case TypeId::And:
1018                    node = Z3_mk_and(mContext, n, operands);
1019                    break;
1020                case TypeId::Or:
1021                    node = Z3_mk_or(mContext, n, operands);
1022                    break;
1023                case TypeId::Xor:
1024                    node = Z3_mk_xor(mContext, operands[0], operands[1]);
1025                    for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
1026                        node = Z3_mk_xor(mContext, node, operands[i]);
1027                    }
1028                    break;
1029                default: llvm_unreachable("unexpected type id");
1030            }
1031            assert (node);
1032            getDefinition(G[u]) = node;
1033        }
1034
1035        graph_traits<Graph>::in_edge_iterator begin, end;
1036restart:if (in_degree(u, G) > 1) {
1037            std::tie(begin, end) = in_edges(u, G);
1038            for (auto i = begin; ++i != end; ) {
1039                const auto v = source(*i, G);
1040                for (auto j = begin; j != i; ++j) {
1041                    const auto w = source(*j, G);
1042                    Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
1043                    Z3_ast test = nullptr;
1044                    switch (typeId) {
1045                        case TypeId::And:
1046                            test = Z3_mk_and(mContext, 2, operands); break;
1047                        case TypeId::Or:
1048                            test = Z3_mk_or(mContext, 2, operands); break;
1049                        case TypeId::Xor:
1050                            test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
1051                        default:
1052                            llvm_unreachable("impossible type id");
1053                    }
1054                    assert (test);
1055                    test = simplify(test, use_expensive_simplification);
1056
1057                    bool replacement = false;
1058                    Vertex x = 0;
1059                    const auto f = M.find(test);
1060                    if (LLVM_UNLIKELY(f != M.end())) {
1061                        x = f->second;
1062                        assert (getDefinition(G[x]) == test);
1063                        replacement = true;
1064                    } else {
1065                        PabloAST * const factor = C.findKey(test);
1066                        if (LLVM_UNLIKELY(!inCurrentBlock(factor, mBlock))) {
1067                            x = makeVertex(TypeId::Var, factor, G, test);
1068                            M.emplace(test, x);
1069                            replacement = true;
1070                        }
1071                    }
1072
1073                    if (LLVM_UNLIKELY(replacement)) {
1074
1075                        // note: unless both edges carry an Pablo AST replacement value, they will converge into a single edge.
1076                        PabloAST * const r1 = G[*i];
1077                        PabloAST * const r2 = G[*j];
1078
1079                        remove_edge(*i, G);
1080                        remove_edge(*j, G);
1081
1082                        if (LLVM_UNLIKELY(r1 && r2)) {
1083                            add_edge(r1, x, u, G);
1084                            add_edge(r2, x, u, G);
1085                        } else {
1086                            add_edge(r1 ? r1 : r2, x, u, G);
1087                        }
1088
1089                        reduced = true;
1090                        goto restart;
1091                    }
1092                }
1093            }
1094        }
1095    }
1096
1097    if (LLVM_UNLIKELY(node == nullptr)) {
1098        throw std::runtime_error("No Z3 characterization for vertex " + std::to_string(u));
1099    }
1100
1101    auto f = M.find(node);
1102    if (LLVM_LIKELY(f == M.end())) {
1103        M.emplace(node, u);
1104    } else if (isAssociative(G[u])) {
1105        const Vertex v = f->second;
1106        for (auto e : make_iterator_range(out_edges(u, G))) {
1107            add_edge(G[e], v, target(e, G), G);
1108        }
1109        removeVertex(u, G);
1110        reduced = true;
1111    }
1112
1113    return reduced;
1114}
1115
1116/** ------------------------------------------------------------------------------------------------------------- *
1117 * @brief reduceGraph
1118 ** ------------------------------------------------------------------------------------------------------------- */
1119bool BooleanReassociationPass::reduceGraph(CharacterizationMap & C, VertexMap & M, Graph & G) const {
1120
1121    bool reduced = false;
1122
1123    circular_buffer<Vertex> ordering(num_vertices(G));
1124
1125    topological_sort(G, std::front_inserter(ordering)); // topological ordering
1126
1127    M.clear();
1128
1129    // first contract the graph
1130    for (const Vertex u : ordering) {
1131        if (isReducible(G[u])) {
1132            if (reduceVertex(u, C, M, G, false)) {
1133                reduced = true;
1134            }
1135        }
1136    }
1137    return reduced;
1138}
1139
1140/** ------------------------------------------------------------------------------------------------------------- *
1141 * @brief factorGraph
1142 ** ------------------------------------------------------------------------------------------------------------- */
1143bool BooleanReassociationPass::factorGraph(const TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1144
1145    if (LLVM_UNLIKELY(factors.empty())) {
1146        return false;
1147    }
1148
1149    std::vector<Vertex> I, J, K;
1150
1151    bool modified = false;
1152
1153    for (unsigned i = 1; i < factors.size(); ++i) {
1154        assert (getType(G[factors[i]]) == typeId);
1155        for (auto ei : make_iterator_range(in_edges(factors[i], G))) {
1156            I.push_back(source(ei, G));
1157        }
1158        std::sort(I.begin(), I.end());
1159        for (unsigned j = 0; j < i; ++j) {
1160            for (auto ej : make_iterator_range(in_edges(factors[j], G))) {
1161                J.push_back(source(ej, G));
1162            }
1163            std::sort(J.begin(), J.end());
1164            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1165            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1166            assert (std::is_sorted(K.begin(), K.end()));
1167            // if the intersection contains at least two elements
1168            const auto n = K.size();
1169            if (n > 1) {
1170                Vertex a = factors[i];
1171                Vertex b = factors[j];
1172                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1173                    if (in_degree(a, G) != n) {
1174                        assert (in_degree(b, G) == n);
1175                        std::swap(a, b);
1176                    }
1177                    assert (in_degree(a, G) == n);
1178                    if (in_degree(b, G) == n) {
1179                        for (auto e : make_iterator_range(out_edges(b, G))) {
1180                            add_edge(G[e], a, target(e, G), G);
1181                        }
1182                        removeVertex(b, G);
1183                    } else {
1184                        for (auto u : K) {
1185                            remove_edge(u, b, G);
1186                        }
1187                        add_edge(nullptr, a, b, G);
1188                    }
1189                } else {
1190                    Vertex v = makeVertex(typeId, nullptr, G);
1191                    for (auto u : K) {
1192                        remove_edge(u, a, G);
1193                        remove_edge(u, b, G);
1194                        add_edge(nullptr, u, v, G);
1195                    }
1196                    add_edge(nullptr, v, a, G);
1197                    add_edge(nullptr, v, b, G);
1198                    factors.push_back(v);
1199                }
1200                modified = true;
1201            }
1202            K.clear();
1203            J.clear();
1204        }
1205        I.clear();
1206    }
1207    return modified;
1208}
1209
1210/** ------------------------------------------------------------------------------------------------------------- *
1211 * @brief factorGraph
1212 ** ------------------------------------------------------------------------------------------------------------- */
1213bool BooleanReassociationPass::factorGraph(Graph & G) const {
1214    // factor the associative vertices.
1215    std::vector<Vertex> factors;
1216    bool factored = false;
1217    for (unsigned i = 0; i < 3; ++i) {
1218        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1219        for (auto j : make_iterator_range(vertices(G))) {
1220            if (getType(G[j]) == typeId[i]) {
1221                factors.push_back(j);
1222            }
1223        }
1224        if (factorGraph(typeId[i], G, factors)) {
1225            factored = true;
1226        }
1227        factors.clear();
1228    }
1229    return factored;
1230}
1231
1232
1233inline bool isMutable(const Vertex u, const Graph & G) {
1234    return getType(G[u]) != TypeId::Var;
1235}
1236
1237/** ------------------------------------------------------------------------------------------------------------- *
1238 * @brief rewriteAST
1239 ** ------------------------------------------------------------------------------------------------------------- */
1240bool BooleanReassociationPass::rewriteAST(CharacterizationMap & C, VertexMap & M, Graph & G) {
1241
1242    using line_t = long long int;
1243
1244    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1245
1246    Z3_config cfg = Z3_mk_config();
1247    Z3_set_param_value(cfg, "model", "true");
1248    Z3_context ctx = Z3_mk_context(cfg);
1249    Z3_del_config(cfg);
1250    Z3_solver solver = Z3_mk_solver(ctx);
1251    Z3_solver_inc_ref(ctx, solver);
1252
1253    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1254
1255    flat_map<PabloAST *, Z3_ast> V;
1256
1257    // Generate the variables
1258    const auto ty = Z3_mk_int_sort(ctx);
1259    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1260
1261    for (const Vertex u : make_iterator_range(vertices(G))) {
1262        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1263        Z3_ast constraint = nullptr;
1264        if (in_degree(u, G) > 0) {
1265            constraint = Z3_mk_gt(ctx, var, ZERO);
1266            Z3_solver_assert(ctx, solver, constraint);
1267        }       
1268        PabloAST * const expr = getValue(G[u]);
1269        if (inCurrentBlock(expr, mBlock)) {
1270            V.emplace(expr, var);
1271        }
1272        mapping[u] = var;
1273    }
1274
1275    // Add in the dependency constraints
1276    for (const Vertex u : make_iterator_range(vertices(G))) {
1277        Z3_ast const t = mapping[u];
1278        for (auto e : make_iterator_range(in_edges(u, G))) {
1279            Z3_ast const s = mapping[source(e, G)];
1280            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1281        }
1282    }
1283
1284    // Compute the soft ordering constraints
1285    std::vector<Z3_ast> ordering(0);
1286    ordering.reserve(V.size() - 1);
1287
1288    Z3_ast prior = nullptr;
1289    unsigned gap = 1;
1290    for (Statement * stmt : *mBlock) {
1291        auto f = V.find(stmt);
1292        if (f != V.end()) {
1293            Z3_ast const node = f->second;
1294            if (prior) {
1295//                ordering.push_back(Z3_mk_lt(ctx, prior, node)); // increases the cost by 6 - 10x
1296                Z3_ast ops[2] = { node, prior };
1297                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1298            } else {
1299                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1300            }
1301            prior = node;
1302            gap = 0;
1303        }
1304        ++gap;
1305    }
1306
1307    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1308        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1309    }
1310
1311    Z3_model model = Z3_solver_get_model(ctx, solver);
1312    Z3_model_inc_ref(ctx, model);
1313
1314    std::vector<Vertex> S(0);
1315    S.reserve(num_vertices(G));
1316
1317    std::vector<line_t> L(num_vertices(G));
1318
1319
1320
1321    for (const Vertex u : make_iterator_range(vertices(G))) {
1322        line_t line = LoadEarly ? 0 : MAX_INT;
1323        if (isMutable(u, G)) {
1324            Z3_ast value;
1325            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1326                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1327            }
1328            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1329                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1330            }
1331            S.push_back(u);
1332        }
1333        L[u] = line;
1334    }
1335
1336    Z3_model_dec_ref(ctx, model);
1337
1338    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1339
1340    mBlock->setInsertPoint(nullptr);
1341
1342    std::vector<Vertex> T;
1343
1344    line_t count = 1;
1345
1346    for (auto u : S) {
1347        PabloAST *& stmt = getValue(G[u]);
1348
1349        assert (isMutable(u, G));
1350        assert (L[u] > 0 && L[u] < MAX_INT);
1351
1352        if (isAssociative(G[u])) {
1353
1354            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1355                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1356            }
1357
1358            Statement * ip = mBlock->getInsertPoint();
1359            ip = ip ? ip->getNextNode() : mBlock->front();
1360
1361            const auto typeId = getType(G[u]);
1362
1363            PabloAST * expr = nullptr;
1364
1365            for (;;) {
1366
1367                T.reserve(in_degree(u, G));
1368                for (const auto e : make_iterator_range(in_edges(u, G))) {
1369                    T.push_back(source(e, G));
1370                }
1371
1372                // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1373                std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1374
1375                if (LoadEarly) {
1376                    mBlock->setInsertPoint(nullptr);
1377                }
1378
1379                bool done = true;
1380
1381                PabloAST * join = nullptr;
1382
1383                for (auto v : T) {
1384                    expr = getValue(G[v]);
1385                    if (LLVM_UNLIKELY(expr == nullptr)) {
1386                        throw std::runtime_error("Vertex " + std::to_string(v) + " does not have an expression!");
1387                    }
1388                    if (join) {
1389
1390                        if (in_degree(v, G) > 0) {
1391
1392                            assert (L[v] > 0 && L[v] < MAX_INT);
1393
1394                            Statement * dom = cast<Statement>(expr);
1395                            for (;;) {
1396                                PabloBlock * const parent = dom->getParent();
1397                                if (parent == mBlock) {
1398                                    break;
1399                                }
1400                                dom = parent->getBranch(); assert(dom);
1401                            }
1402                            mBlock->setInsertPoint(dom);
1403
1404                            assert (dominates(join, expr));
1405                            assert (dominates(expr, ip));
1406                            assert (dominates(dom, ip));
1407                        }
1408
1409                        Statement * const currIP = mBlock->getInsertPoint();
1410
1411                        switch (typeId) {
1412                            case TypeId::And:
1413                                expr = mBlock->createAnd(join, expr); break;
1414                            case TypeId::Or:
1415                                expr = mBlock->createOr(join, expr); break;
1416                            case TypeId::Xor:
1417                                expr = mBlock->createXor(join, expr); break;
1418                            default:
1419                                llvm_unreachable("Invalid TypeId!");
1420                        }
1421
1422                        // If the insertion point hasn't "moved" then we didn't make a new statement
1423                        // and thus must have unexpectidly reused a prior statement (or Var.)
1424                        if (LLVM_UNLIKELY(currIP == mBlock->getInsertPoint())) {
1425                            if (LLVM_LIKELY(reduceVertex(u, C, M, G, true))) {
1426                                done = false;
1427                                break;
1428                            }
1429                            throw std::runtime_error("Unable to reduce vertex " + std::to_string(u));
1430                        }
1431                    }
1432                    join = expr;
1433                }
1434
1435                T.clear();
1436
1437                if (done) {
1438                    break;
1439                }
1440            }
1441
1442
1443            PabloAST * const replacement = expr; assert (replacement);
1444
1445            mBlock->setInsertPoint(ip->getPrevNode());
1446
1447            for (auto e : make_iterator_range(out_edges(u, G))) {
1448                if (G[e]) {
1449                    if (PabloAST * user = getValue(G[target(e, G)])) {
1450                        cast<Statement>(user)->replaceUsesOfWith(G[e], replacement);
1451                    }
1452                }
1453            }
1454
1455            stmt = replacement;
1456        }
1457
1458        assert (stmt);
1459
1460        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1461            for (auto e : make_iterator_range(out_edges(u, G))) {
1462                const auto v = target(e, G);
1463                assert (L[v] == std::numeric_limits<line_t>::max());
1464                L[v] = count;
1465            }
1466        }
1467
1468        mBlock->insert(cast<Statement>(stmt));
1469        L[u] = count++; // update the line count with the actual one.
1470    }
1471
1472    Z3_solver_dec_ref(ctx, solver);
1473    Z3_del_context(ctx);
1474
1475    Statement * const end = mBlock->getInsertPoint(); assert (end);
1476    for (;;) {
1477        Statement * const next = end->getNextNode();
1478        if (next == nullptr) {
1479            break;
1480        }
1481
1482        #ifndef NDEBUG
1483        for (PabloAST * user : next->users()) {
1484            if (isa<Statement>(user) && dominates(user, next)) {
1485                std::string tmp;
1486                raw_string_ostream out(tmp);
1487                out << "Erasing ";
1488                PabloPrinter::print(next, out);
1489                out << " erroneously modifies live statement ";
1490                PabloPrinter::print(cast<Statement>(user), out);
1491                throw std::runtime_error(out.str());
1492            }
1493        }
1494        #endif
1495        next->eraseFromParent(true);
1496    }
1497
1498    #ifndef NDEBUG
1499    PabloVerifier::verify(mFunction, "mid-reassociation");
1500    #endif
1501
1502    return true;
1503}
1504
1505/** ------------------------------------------------------------------------------------------------------------- *
1506 * @brief addSummaryVertex
1507 ** ------------------------------------------------------------------------------------------------------------- */
1508Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1509//    for (Vertex u : make_iterator_range(vertices(G))) {
1510//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1511//            std::get<0>(G[u]) = typeId;
1512//            std::get<1>(G[u]) = expr;
1513//            return u;
1514//        }
1515//    }
1516    return add_vertex(std::make_tuple(typeId, expr, node), G);
1517}
1518
1519/** ------------------------------------------------------------------------------------------------------------- *
1520 * @brief addSummaryVertex
1521 ** ------------------------------------------------------------------------------------------------------------- */
1522Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1523    assert (expr);
1524    const auto f = M.find(expr);
1525    if (f != M.end()) {
1526        assert (getValue(G[f->second]) == expr);
1527        return f->second;
1528    }
1529    const Vertex u = makeVertex(typeId, expr, G, node);
1530    M.emplace(expr, u);
1531    return u;
1532}
1533
1534/** ------------------------------------------------------------------------------------------------------------- *
1535 * @brief addSummaryVertex
1536 ** ------------------------------------------------------------------------------------------------------------- */
1537Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & M, Graph & G) {
1538    assert (expr);
1539    const auto f = M.find(expr);
1540    if (f != M.end()) {
1541        assert (getValue(G[f->second]) == expr);
1542        return f->second;
1543    }
1544    const Vertex u = makeVertex(typeId, expr, G, C.get(expr));
1545    M.emplace(expr, u);
1546    return u;
1547}
1548
1549/** ------------------------------------------------------------------------------------------------------------- *
1550 * @brief removeSummaryVertex
1551 ** ------------------------------------------------------------------------------------------------------------- */
1552inline void BooleanReassociationPass::removeVertex(const Vertex u, StatementMap & M, Graph & G) const {
1553    VertexData & ref = G[u];
1554    if (std::get<1>(ref)) {
1555        auto f = M.find(std::get<1>(ref));
1556        assert (f != M.end());
1557        M.erase(f);
1558    }
1559    removeVertex(u, G);
1560}
1561
1562/** ------------------------------------------------------------------------------------------------------------- *
1563 * @brief removeSummaryVertex
1564 ** ------------------------------------------------------------------------------------------------------------- */
1565inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1566    VertexData & ref = G[u];
1567    clear_vertex(u, G);
1568    std::get<0>(ref) = TypeId::Var;
1569    std::get<1>(ref) = nullptr;
1570    std::get<2>(ref) = nullptr;
1571}
1572
1573/** ------------------------------------------------------------------------------------------------------------- *
1574 * @brief make
1575 ** ------------------------------------------------------------------------------------------------------------- */
1576inline Z3_ast BooleanReassociationPass::makeVar() const {
1577    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1578//    Z3_inc_ref(mContext, node);
1579    return node;
1580}
1581
1582/** ------------------------------------------------------------------------------------------------------------- *
1583 * @brief simplify
1584 ** ------------------------------------------------------------------------------------------------------------- */
1585Z3_ast BooleanReassociationPass::simplify(Z3_ast node, bool use_expensive_minimization) const {
1586
1587    assert (node);
1588
1589    node = Z3_simplify_ex(mContext, node, mParams);
1590
1591    if (use_expensive_minimization) {
1592
1593        Z3_goal g = Z3_mk_goal(mContext, true, false, false);
1594        Z3_goal_inc_ref(mContext, g);
1595
1596        Z3_goal_assert(mContext, g, node);
1597
1598        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g);
1599        Z3_apply_result_inc_ref(mContext, r);
1600
1601        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
1602
1603        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0);
1604        Z3_goal_inc_ref(mContext, h);
1605
1606        const unsigned n = Z3_goal_size(mContext, h);
1607
1608        if (n == 1) {
1609            node = Z3_goal_formula(mContext, h, 0);
1610        } else if (n > 1) {
1611            Z3_ast operands[n];
1612            for (unsigned i = 0; i < n; ++i) {
1613                operands[i] = Z3_goal_formula(mContext, h, i);
1614            }
1615            node = Z3_mk_and(mContext, n, operands);
1616        }
1617        Z3_goal_dec_ref(mContext, h);
1618    }
1619    return node;
1620}
1621
1622/** ------------------------------------------------------------------------------------------------------------- *
1623 * @brief add
1624 ** ------------------------------------------------------------------------------------------------------------- */
1625inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1626    assert (expr && node);
1627    mForward.emplace(expr, node);
1628    if (!forwardOnly) {
1629        mBackward.emplace(node, expr);
1630    }
1631    return node;
1632}
1633
1634/** ------------------------------------------------------------------------------------------------------------- *
1635 * @brief get
1636 ** ------------------------------------------------------------------------------------------------------------- */
1637inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1638    assert (expr);
1639    auto f = mForward.find(expr);
1640    if (LLVM_UNLIKELY(f == mForward.end())) {
1641        if (mPredecessor == nullptr) {
1642            return nullptr;
1643        }
1644        return mPredecessor->get(expr);
1645    }
1646    return f->second;
1647}
1648
1649/** ------------------------------------------------------------------------------------------------------------- *
1650 * @brief get
1651 ** ------------------------------------------------------------------------------------------------------------- */
1652inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1653    assert (node);
1654    auto f = mBackward.find(node);
1655    if (LLVM_UNLIKELY(f == mBackward.end())) {
1656        if (mPredecessor == nullptr) {
1657            return nullptr;
1658        }
1659        return mPredecessor->findKey(node);
1660    }
1661    return f->second;
1662}
1663
1664/** ------------------------------------------------------------------------------------------------------------- *
1665 * @brief constructor
1666 ** ------------------------------------------------------------------------------------------------------------- */
1667inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_params params, Z3_tactic tactic, PabloFunction & f)
1668: mContext(ctx)
1669, mParams(params)
1670, mTactic(tactic)
1671, mFunction(f)
1672, mModified(false) {
1673
1674}
1675
1676}
Note: See TracBrowser for help on using the repository browser.