source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5454

Last change on this file since 5454 was 5454, checked in by nmedfort, 2 years ago

Bug fix check in for DumpTrace?, compilation of DoBlock? / DoFinalBlock? functions. Pablo CodeMotionPass? optimized and enabled by default.

File size: 65.6 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17//  noif-dist-mult-dist-50 \p{Cham}(?<!\p{Mc})
18
19using namespace boost;
20using namespace boost::container;
21
22
23static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
24
25static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
26                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
27                                           "Scope Block rather than at the point of first use."),
28                                  cl::cat(ReassociationOptions));
29
30namespace pablo {
31
32using TypeId = PabloAST::ClassTypeId;
33using Graph = BooleanReassociationPass::Graph;
34using Vertex = Graph::vertex_descriptor;
35using VertexData = BooleanReassociationPass::VertexData;
36using DistributionGraph = BooleanReassociationPass::DistributionGraph;
37using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
38using VertexSet = std::vector<Vertex>;
39using VertexSets = std::vector<VertexSet>;
40using Biclique = std::pair<VertexSet, VertexSet>;
41using BicliqueSet = std::vector<Biclique>;
42using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
43using DistributionSets = std::vector<DistributionSet>;
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief helper functions
47 ** ------------------------------------------------------------------------------------------------------------- */
48template<typename Iterator>
49inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
50    assert (range.first != range.second);
51    return *range.first;
52}
53
54static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
55    return stmt->getParent() == block;
56}
57
58static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
59    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
60}
61
62inline TypeId & getType(VertexData & data) {
63    return std::get<0>(data);
64}
65
66inline TypeId getType(const VertexData & data) {
67    return std::get<0>(data);
68}
69
70inline Z3_ast & getDefinition(VertexData & data) {
71    return std::get<2>(data);
72}
73
74inline PabloAST * getValue(const VertexData & data) {
75    return std::get<1>(data);
76}
77
78inline PabloAST *& getValue(VertexData & data) {
79    return std::get<1>(data);
80}
81
82inline bool isAssociative(const VertexData & data) {
83    switch (getType(data)) {
84        case TypeId::And:
85        case TypeId::Or:
86        case TypeId::Xor:
87            return true;
88        default:
89            return false;
90    }
91}
92
93inline bool isDistributive(const VertexData & data) {
94    switch (getType(data)) {
95        case TypeId::And:
96        case TypeId::Or:
97            return true;
98        default:
99            return false;
100    }
101}
102
103void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
104    // Make sure each edge is unique
105    assert (u < num_vertices(G) && v < num_vertices(G));
106    assert (u != v);
107
108    // Just because we've supplied an expr doesn't mean it's useful. Check it.
109    if (expr) {
110        if (isAssociative(G[v])) {
111            expr = nullptr;
112        } else {
113            bool clear = true;
114            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
115                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
116                    if (dest->getOperand(i) == expr) {
117                        clear = false;
118                        break;
119                    }
120                }
121            }
122            if (LLVM_LIKELY(clear)) {
123                expr = nullptr;
124            }
125        }
126    }
127
128    for (auto e : make_iterator_range(out_edges(u, G))) {
129        if (LLVM_UNLIKELY(target(e, G) == v)) {
130            if (expr) {
131                if (G[e] == nullptr) {
132                    G[e] = expr;
133                } else if (G[e] != expr) {
134                    continue;
135                }
136            }
137            return;
138        }
139    }
140    G[boost::add_edge(u, v, G).first] = expr;
141}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief intersects
145 ** ------------------------------------------------------------------------------------------------------------- */
146template <class Type>
147inline bool intersects(Type & A, Type & B) {
148    auto first1 = A.begin(), last1 = A.end();
149    auto first2 = B.begin(), last2 = B.end();
150    while (first1 != last1 && first2 != last2) {
151        if (*first1 < *first2) {
152            ++first1;
153        } else if (*first2 < *first1) {
154            ++first2;
155        } else {
156            return true;
157        }
158    }
159    return false;
160}
161
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief printGraph
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void printGraph(const Graph & G, const std::string & name) {
167    raw_os_ostream out(std::cerr);
168
169    std::vector<unsigned> c(num_vertices(G));
170    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
171
172    out << "digraph " << name << " {\n";
173    for (auto u : make_iterator_range(vertices(G))) {
174        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
175            continue;
176        }
177        out << "v" << u << " [label=\"" << u << ": ";
178        TypeId typeId;
179        PabloAST * expr;
180        Z3_ast node;
181        std::tie(typeId, expr, node) = G[u];
182        bool temporary = false;
183        bool error = false;
184        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
185            temporary = true;
186            switch (typeId) {
187                case TypeId::And:
188                    out << "And";
189                    break;
190                case TypeId::Or:
191                    out << "Or";
192                    break;
193                case TypeId::Xor:
194                    out << "Xor";
195                    break;
196                case TypeId::Not:
197                    out << "Not";
198                    break;
199                default:
200                    out << "???";
201                    error = true;
202                    break;
203            }
204            if (expr) {
205                out << " ("; PabloPrinter::print(expr, out); out << ")";
206            }
207        } else {
208            PabloPrinter::print(expr, out);
209        }
210        if (node == nullptr) {
211            out << " (*)";
212            error = true;
213        }
214        out << "\"";
215        if (typeId == TypeId::Var) {
216            out << " style=dashed";
217        }
218        if (error) {
219            out << " color=red";
220        } else if (temporary) {
221            out << " color=blue";
222        }
223        out << "];\n";
224    }
225    for (auto e : make_iterator_range(edges(G))) {
226        const auto s = source(e, G);
227        const auto t = target(e, G);
228        out << "v" << s << " -> v" << t;
229        bool cyclic = (c[s] == c[t]);
230        if (G[e] || cyclic) {
231            out << " [";
232             if (G[e]) {
233                out << "label=\"";
234                PabloPrinter::print(G[e], out);
235                out << "\" ";
236             }
237             if (cyclic) {
238                out << "color=red ";
239             }
240             out << "]";
241        }
242        out << ";\n";
243    }
244
245    if (num_vertices(G) > 0) {
246
247        out << "{ rank=same;";
248        for (auto u : make_iterator_range(vertices(G))) {
249            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
250                out << " v" << u;
251            }
252        }
253        out << "}\n";
254
255        out << "{ rank=same;";
256        for (auto u : make_iterator_range(vertices(G))) {
257            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
258                out << " v" << u;
259            }
260        }
261        out << "}\n";
262
263    }
264
265    out << "}\n\n";
266    out.flush();
267}
268
269/** ------------------------------------------------------------------------------------------------------------- *
270 * @brief optimize
271 ** ------------------------------------------------------------------------------------------------------------- */
272bool BooleanReassociationPass::optimize(PabloFunction & function) {
273
274    Z3_config cfg = Z3_mk_config();
275    Z3_context ctx = Z3_mk_context_rc(cfg);
276    Z3_del_config(cfg);
277
278    Z3_params params = Z3_mk_params(ctx);
279    Z3_params_inc_ref(ctx, params);
280    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "pull_cheap_ite"), true);
281    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "local_ctx"), true);
282
283    Z3_tactic ctx_solver_simplify = Z3_mk_tactic(ctx, "ctx-solver-simplify");
284    Z3_tactic_inc_ref(ctx, ctx_solver_simplify);
285
286    BooleanReassociationPass brp(ctx, params, ctx_solver_simplify, function);
287    brp.processScopes(function);
288
289    Z3_params_dec_ref(ctx, params);
290    Z3_tactic_dec_ref(ctx, ctx_solver_simplify);
291    Z3_del_context(ctx);
292
293    PabloVerifier::verify(function, "post-reassociation");
294
295    Simplifier::optimize(function);
296
297    return true;
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief processScopes
302 ** ------------------------------------------------------------------------------------------------------------- */
303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
304    CharacterizationMap C;
305    PabloBlock * const entry = function.getEntryBlock();
306    // Map the constants and input variables
307    C.add(entry->createZeroes(), Z3_mk_false(mContext));
308    C.add(entry->createOnes(), Z3_mk_true(mContext));
309    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
310        C.add(mFunction.getParameter(i), makeVar());
311    }
312    mInFile = makeVar();
313    processScopes(entry, C);
314    for (auto i = mRefs.begin(); i != mRefs.end(); ++i) {
315        Z3_dec_ref(mContext, *i);
316    }
317    mRefs.clear();
318    return mModified;
319}
320
321/** ------------------------------------------------------------------------------------------------------------- *
322 * @brief processScopes
323 ** ------------------------------------------------------------------------------------------------------------- */
324void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & C) {
325    const auto offset = mRefs.size();
326    for (Statement * stmt = block->front(); stmt; ) {
327        if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
328            if (LLVM_UNLIKELY(isa<Zeroes>(cast<Branch>(stmt)->getCondition()))) {
329                stmt = stmt->eraseFromParent(true);
330            } else {
331                CharacterizationMap nested(C);
332                processScopes(cast<Branch>(stmt)->getBody(), nested);
333                for (Var * def : cast<Branch>(stmt)->getEscaped()) {
334                    C.add(def, makeVar());
335                }
336                stmt = stmt->getNextNode();
337            }
338        } else { // characterize this statement then check whether it is equivalent to any existing one.
339            PabloAST * const folded = Simplifier::fold(stmt, block);
340            if (LLVM_UNLIKELY(folded != nullptr)) {
341                stmt = stmt->replaceWith(folded);
342            } else {
343                stmt = characterize(stmt, C);
344            }
345        }
346    }   
347    distributeScope(block, C);
348    for (auto i = mRefs.begin() + offset; i != mRefs.end(); ++i) {
349        Z3_dec_ref(mContext, *i);
350    }
351    mRefs.erase(mRefs.begin() + offset, mRefs.end());
352}
353
354/** ------------------------------------------------------------------------------------------------------------- *
355 * @brief characterize
356 ** ------------------------------------------------------------------------------------------------------------- */
357inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & C) {
358
359    Z3_ast node = nullptr;
360    const size_t n = stmt->getNumOperands(); assert (n > 0);
361    bool use_expensive_simplification = false;
362    if (isa<Variadic>(stmt)) {
363        Z3_ast operands[n];
364        for (size_t i = 0; i < n; ++i) {
365            PabloAST * const op = stmt->getOperand(i);
366            if (isa<Not>(op)) {
367                use_expensive_simplification = true;
368            }
369            operands[i] = C.get(op); assert (operands[i]);
370        }
371        if (isa<And>(stmt)) {
372            node = Z3_mk_and(mContext, n, operands);
373        } else if (isa<Or>(stmt)) {
374            node = Z3_mk_or(mContext, n, operands);
375        } else if (isa<Xor>(stmt)) {
376            node = Z3_mk_xor(mContext, operands[0], operands[1]);
377            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
378                Z3_inc_ref(mContext, node);
379                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
380                Z3_inc_ref(mContext, temp);
381                Z3_dec_ref(mContext, node);
382                node = temp;
383            }
384        }
385    } else if (isa<Not>(stmt)) {
386        Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
387        node = Z3_mk_not(mContext, op);
388    } else if (isa<Sel>(stmt)) {
389        Z3_ast operands[3];
390        for (size_t i = 0; i < 3; ++i) {
391            operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
392        }
393        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
394    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
395        assert (stmt->getNumOperands() == 1);
396        Z3_ast check[2];
397        check[0] = C.get(stmt->getOperand(0)); assert (check[0]);
398        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
399        node = Z3_mk_and(mContext, 2, check);
400    } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
401        return stmt->getNextNode();
402    }  else {
403        C.add(stmt, makeVar());
404        return stmt->getNextNode();
405    }
406    node = simplify(node, use_expensive_simplification);
407    PabloAST * const replacement = C.findKey(node);
408    if (LLVM_LIKELY(replacement == nullptr)) {
409        C.add(stmt, node);
410        mRefs.push_back(node);
411        return stmt->getNextNode();
412    } else {
413        Z3_dec_ref(mContext, node);
414        return stmt->replaceWith(replacement);
415    }
416}
417
418/** ------------------------------------------------------------------------------------------------------------- *
419 * @brief processScope
420 ** ------------------------------------------------------------------------------------------------------------- */
421inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & C) {
422    Graph G;
423    try {
424        mBlock = block;
425        transformAST(C, G);
426    } catch (std::exception &) {
427        printGraph(G, "E");
428        throw;
429    }
430}
431
432/** ------------------------------------------------------------------------------------------------------------- *
433 * @brief summarizeAST
434 *
435 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
436 * are "flattened" (i.e., allowed to have any number of inputs.)
437 ** ------------------------------------------------------------------------------------------------------------- */
438Vertex BooleanReassociationPass::transcribeSel(Sel * const stmt, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
439
440    Z3_ast args[2];
441
442    const Vertex c = makeVertex(TypeId::Var, stmt->getCondition(), C, S, M, G);
443    const Vertex t = makeVertex(TypeId::Var, cast<Sel>(stmt)->getTrueExpr(), C, S, M, G);
444    const Vertex f = makeVertex(TypeId::Var, cast<Sel>(stmt)->getFalseExpr(), C, S, M, G);
445
446    args[0] = getDefinition(G[c]);
447    args[1] = getDefinition(G[t]);
448
449    Z3_ast trueExpr = Z3_mk_and(mContext, 2, args);
450    Z3_inc_ref(mContext, trueExpr);
451    mRefs.push_back(trueExpr);
452
453    const Vertex x = makeVertex(TypeId::And, nullptr, G, trueExpr);
454    add_edge(nullptr, c, x, G);
455    add_edge(nullptr, t, x, G);
456
457    Z3_ast notCond = Z3_mk_not(mContext, args[0]);
458    Z3_inc_ref(mContext, notCond);
459    mRefs.push_back(notCond);
460
461    args[0] = notCond;
462    args[1] = getDefinition(G[f]);
463
464    Z3_ast falseExpr = Z3_mk_and(mContext, 2, args);
465    Z3_inc_ref(mContext, falseExpr);
466    mRefs.push_back(falseExpr);
467
468    const Vertex n = makeVertex(TypeId::Not, nullptr, G, notCond);
469
470    add_edge(nullptr, c, n, G);
471
472    const Vertex y = makeVertex(TypeId::And, nullptr, G, falseExpr);
473    add_edge(nullptr, n, y, G);
474    add_edge(nullptr, f, y, G);
475
476    const Vertex u = makeVertex(TypeId::Or, stmt, C, S, M, G);
477    add_edge(nullptr, x, u, G);
478    add_edge(nullptr, y, u, G);
479
480    return u;
481}
482
483/** ------------------------------------------------------------------------------------------------------------- *
484 * @brief summarizeAST
485 *
486 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
487 * are "flattened" (i.e., allowed to have any number of inputs.)
488 ** ------------------------------------------------------------------------------------------------------------- */
489void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
490
491    StatementMap S;
492
493    VertexMap M;
494
495    // Compute the base def-use graph ...
496    for (Statement * stmt : *mBlock) {
497        if (LLVM_UNLIKELY(isa<Sel>(stmt))) {
498
499            const Vertex u = transcribeSel(cast<Sel>(stmt), C, S, M, G);
500
501            resolveNestedUsages(stmt, u, C, S, M, G, stmt);
502
503        } else {
504
505
506            const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, C, S, M, G);
507            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
508                PabloAST * const op = stmt->getOperand(i);
509                if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
510                    add_edge(op, makeVertex(TypeId::Var, op, C, S, M, G), u, G);
511                }
512            }
513            if (LLVM_UNLIKELY(isa<If>(stmt))) {
514                for (Var * def : cast<If>(stmt)->getEscaped()) {
515                    const Vertex v = makeVertex(TypeId::Var, def, C, S, M, G);
516                    add_edge(def, u, v, G);
517                    resolveNestedUsages(def, v, C, S, M, G, stmt);
518                }
519                continue;
520            } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
521                // To keep G a DAG, we need to do a bit of surgery on loop variants because
522                // the next variables it produces can be used within the condition. Instead,
523                // we make the loop dependent on the original value of each Next node and
524                // the Next node dependent on the loop.
525                for (Var * var : cast<While>(stmt)->getEscaped()) {
526                    const Vertex v = makeVertex(TypeId::Var, var, C, S, M, G);
527                    assert (in_degree(v, G) == 1);
528                    auto e = first(in_edges(v, G));
529                    add_edge(G[e], source(e, G), u, G);
530                    remove_edge(v, u, G);
531                    add_edge(var, u, v, G);
532                    resolveNestedUsages(var, v, C, S, M, G, stmt);
533                }
534                continue;
535            } else {
536                resolveNestedUsages(stmt, u, C, S, M, G, stmt);
537            }
538        }
539
540    }
541
542    if (redistributeGraph(C, M, G)) {
543        factorGraph(G);
544        rewriteAST(G);
545        mModified = true;
546    }
547
548}
549
550/** ------------------------------------------------------------------------------------------------------------- *
551 * @brief resolveNestedUsages
552 ** ------------------------------------------------------------------------------------------------------------- */
553void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
554                                                   CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G,
555                                                   const Statement * const ignoreIfThis) const {
556    assert ("Cannot resolve nested usages of a null expression!" && expr);
557    for (PabloAST * user : expr->users()) { assert (user);
558        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
559            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
560            if (LLVM_UNLIKELY(parent != mBlock)) {
561                for (;;) {
562                    if (parent->getPredecessor() == mBlock) {
563                        Statement * const branch = parent->getBranch();
564                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
565                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
566                            const Vertex v = makeVertex(TypeId::Var, user, C, S, M, G);
567                            add_edge(expr, u, v, G);
568                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
569                            add_edge(user, v, w, G);
570                        }
571                        break;
572                    }
573                    parent = parent->getPredecessor();
574                    if (LLVM_UNLIKELY(parent == nullptr)) {
575                        assert (isa<Assign>(expr));
576                        break;
577                    }
578                }
579            }
580        }
581    }
582}
583
584
585
586/** ------------------------------------------------------------------------------------------------------------- *
587 * @brief enumerateBicliques
588 *
589 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
590 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
591 * bipartition A and their adjacencies to be in B.
592  ** ------------------------------------------------------------------------------------------------------------- */
593template <class Graph>
594static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
595    using IntersectionSets = std::set<VertexSet>;
596
597    IntersectionSets B1;
598    for (auto u : A) {
599        if (in_degree(u, G) > 0) {
600            VertexSet adjacencies;
601            adjacencies.reserve(in_degree(u, G));
602            for (auto e : make_iterator_range(in_edges(u, G))) {
603                adjacencies.push_back(source(e, G));
604            }
605            std::sort(adjacencies.begin(), adjacencies.end());
606            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
607            B1.insert(std::move(adjacencies));
608        }
609    }
610
611    IntersectionSets B(B1);
612
613    IntersectionSets Bi;
614
615    VertexSet clique;
616    for (auto i = B1.begin(); i != B1.end(); ++i) {
617        for (auto j = i; ++j != B1.end(); ) {
618            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
619            if (clique.size() > 0) {
620                if (B.count(clique) == 0) {
621                    Bi.insert(clique);
622                }
623                clique.clear();
624            }
625        }
626    }
627
628    for (;;) {
629        if (Bi.empty()) {
630            break;
631        }
632        B.insert(Bi.begin(), Bi.end());
633        IntersectionSets Bk;
634        for (auto i = B1.begin(); i != B1.end(); ++i) {
635            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
636                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
637                if (clique.size() > 0) {
638                    if (B.count(clique) == 0) {
639                        Bk.insert(clique);
640                    }
641                    clique.clear();
642                }
643            }
644        }
645        Bi.swap(Bk);
646    }
647
648    BicliqueSet cliques;
649    cliques.reserve(B.size());
650    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
651        VertexSet Ai(A);
652        for (const Vertex u : *Bi) {
653            VertexSet Aj;
654            Aj.reserve(out_degree(u, G));
655            for (auto e : make_iterator_range(out_edges(u, G))) {
656                Aj.push_back(target(e, G));
657            }
658            std::sort(Aj.begin(), Aj.end());
659            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
660            VertexSet Ak;
661            Ak.reserve(std::min(Ai.size(), Aj.size()));
662            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
663            Ai.swap(Ak);
664        }
665        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
666        cliques.emplace_back(std::move(Ai), std::move(*Bi));
667    }
668    return std::move(cliques);
669}
670
671/** ------------------------------------------------------------------------------------------------------------- *
672 * @brief independentCliqueSets
673 ** ------------------------------------------------------------------------------------------------------------- */
674template <unsigned side>
675inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
676    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
677
678    const auto l = cliques.size();
679    IndependentSetGraph I(l);
680
681    // Initialize our weights
682    for (unsigned i = 0; i != l; ++i) {
683        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
684    }
685
686    // Determine our constraints
687    for (unsigned i = 0; i != l; ++i) {
688        for (unsigned j = i + 1; j != l; ++j) {
689            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
690                add_edge(i, j, I);
691            }
692        }
693    }
694
695    // Use the greedy algorithm to choose our independent set
696    VertexSet selected;
697    for (;;) {
698        unsigned w = 0;
699        Vertex u = 0;
700        for (unsigned i = 0; i != l; ++i) {
701            if (I[i] > w) {
702                w = I[i];
703                u = i;
704            }
705        }
706        if (w < minimum) break;
707        selected.push_back(u);
708        I[u] = 0;
709        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
710            I[v] = 0;
711        }
712    }
713
714    // Sort the selected list and then remove the unselected cliques
715    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
716    auto end = cliques.end();
717    for (const unsigned offset : selected) {
718        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
719    }
720    cliques.erase(cliques.begin(), end);
721
722    return cliques;
723}
724
725/** ------------------------------------------------------------------------------------------------------------- *
726 * @brief removeUnhelpfulBicliques
727 *
728 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
729 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
730 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
731 ** ------------------------------------------------------------------------------------------------------------- */
732static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
733    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
734        const auto cardinalityA = std::get<0>(*ci).size();
735        VertexSet & B = std::get<1>(*ci);
736        for (auto bi = B.begin(); bi != B.end(); ) {
737            if (out_degree(H[*bi], G) == cardinalityA) {
738                ++bi;
739            } else {
740                bi = B.erase(bi);
741            }
742        }
743        if (B.size() > 1) {
744            ++ci;
745        } else {
746            ci = cliques.erase(ci);
747        }
748    }
749    return std::move(cliques);
750}
751
752/** ------------------------------------------------------------------------------------------------------------- *
753 * @brief safeDistributionSets
754 ** ------------------------------------------------------------------------------------------------------------- */
755static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
756
757    VertexSet sinks;
758    for (const Vertex u : make_iterator_range(vertices(H))) {
759        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
760            sinks.push_back(u);
761        }
762    }
763    std::sort(sinks.begin(), sinks.end());
764
765    DistributionSets T;
766    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
767    for (Biclique & lower : lowerSet) {
768        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
769        for (Biclique & upper : upperSet) {
770            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
771        }
772    }
773    return std::move(T);
774}
775
776/** ------------------------------------------------------------------------------------------------------------- *
777 * @brief getVertex
778 ** ------------------------------------------------------------------------------------------------------------- */
779template<typename ValueType, typename GraphType, typename MapType>
780static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
781    const auto f = M.find(value);
782    if (f != M.end()) {
783        return f->second;
784    }
785    const auto u = add_vertex(value, G);
786    M.insert(std::make_pair(value, u));
787    return u;
788}
789
790/** ------------------------------------------------------------------------------------------------------------- *
791 * @brief generateDistributionGraph
792 ** ------------------------------------------------------------------------------------------------------------- */
793void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
794    DistributionMap M;
795    for (const Vertex u : make_iterator_range(vertices(G))) {
796        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
797            continue;
798        } else if (isDistributive(G[u])) {
799            TypeId outerTypeId = getType(G[u]);
800            TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
801            flat_set<Vertex> distributable;
802            for (auto e : make_iterator_range(in_edges(u, G))) {
803                const Vertex v = source(e, G);
804                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
805                    bool safe = true;
806                    for (const auto e : make_iterator_range(out_edges(v, G))) {
807                        if (getType(G[target(e, G)]) != outerTypeId) {
808                            safe = false;
809                            break;
810                        }
811                    }
812                    if (safe) {
813                        distributable.insert(v);
814                    }
815                }
816            }
817            if (LLVM_LIKELY(distributable.size() > 1)) {
818                flat_set<Vertex> observed;
819                for (const Vertex v : distributable) {
820                    for (auto e : make_iterator_range(in_edges(v, G))) {
821                        const auto v = source(e, G);
822                        observed.insert(v);
823                    }
824                }
825                for (const Vertex w : observed) {
826                    for (auto e : make_iterator_range(out_edges(w, G))) {
827                        const Vertex v = target(e, G);
828                        if (distributable.count(v)) {
829                            const Vertex y = getVertex(v, H, M);
830                            boost::add_edge(y, getVertex(u, H, M), H);
831                            boost::add_edge(getVertex(w, H, M), y, H);
832                        }
833                    }
834                }
835            }
836        }
837    }
838}
839
840/** ------------------------------------------------------------------------------------------------------------- *
841 * @brief recomputeDefinition
842 ** ------------------------------------------------------------------------------------------------------------- */
843inline Z3_ast BooleanReassociationPass::computeDefinition(TypeId typeId, const Vertex u, Graph & G, const bool use_expensive_minimization) const {
844    const unsigned n = in_degree(u, G);
845    Z3_ast operands[n];
846    unsigned k = 0;
847    for (const auto e : make_iterator_range(in_edges(u, G))) {
848        const auto v = source(e, G);
849        if (LLVM_UNLIKELY(getDefinition(G[v]) == nullptr)) {
850            throw std::runtime_error("No definition for " + std::to_string(v));
851        }
852        operands[k++] = getDefinition(G[v]);
853    }
854    assert (k == n);
855    Z3_ast const node = (typeId == TypeId::And) ? Z3_mk_and(mContext, n, operands) : Z3_mk_or(mContext, n, operands);
856    return simplify(node, use_expensive_minimization);
857}
858
859/** ------------------------------------------------------------------------------------------------------------- *
860 * @brief updateDefinition
861 *
862 * Apply the distribution law to reduce computations whenever possible.
863 ** ------------------------------------------------------------------------------------------------------------- */
864Vertex BooleanReassociationPass::updateIntermediaryDefinition(TypeId typeId, const Vertex u, VertexMap & M, Graph & G) {
865
866    Z3_ast def = computeDefinition(typeId, u, G);
867    Z3_ast orig = getDefinition(G[u]); assert (orig);
868
869    Z3_dec_ref(mContext, orig);
870
871    const auto g = M.find(orig);
872    if (LLVM_LIKELY(g != M.end())) {
873        M.erase(g);
874    }
875
876    const auto f = std::find(mRefs.rbegin(), mRefs.rend(), orig);
877    assert (f != mRefs.rend());
878    *f = def;
879
880    const auto h = M.find(def);
881    if (LLVM_UNLIKELY(h != M.end())) {
882        const auto v = h->second;
883        if (v != u) {
884            for (auto e : make_iterator_range(out_edges(u, G))) {
885                add_edge(G[e], v, target(e, G), G);
886            }
887            removeVertex(u, G);
888            return v;
889        }
890    }
891
892    getDefinition(G[u]) = def;
893    M.emplace(def, u);
894    return u;
895}
896
897/** ------------------------------------------------------------------------------------------------------------- *
898 * @brief updateDefinition
899 *
900 * Apply the distribution law to reduce computations whenever possible.
901 ** ------------------------------------------------------------------------------------------------------------- */
902Vertex BooleanReassociationPass::updateSinkDefinition(TypeId typeId, const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G) {
903
904    Z3_ast const def = computeDefinition(typeId, u, G);
905
906    auto f = M.find(def);
907
908    if (LLVM_UNLIKELY(f != M.end())) {
909        Z3_dec_ref(mContext, def);
910        Vertex v = f->second; assert (v != u);
911        for (auto e : make_iterator_range(out_edges(u, G))) {
912            add_edge(G[e], v, target(e, G), G);
913        }
914        removeVertex(u, G);
915        return v;
916    } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
917        PabloAST * const factor = C.predecessor()->findKey(def);
918        if (LLVM_UNLIKELY(factor != nullptr)) {
919            getValue(G[u]) = factor;
920            getType(G[u]) = TypeId::Var;
921            clear_in_edges(u, G);
922        }
923    }
924
925    getDefinition(G[u]) = def;
926    mRefs.push_back(def);
927
928    graph_traits<Graph>::in_edge_iterator begin, end;
929
930restart:
931
932    if (in_degree(u, G) > 1) {
933        std::tie(begin, end) = in_edges(u, G);
934        for (auto i = begin; ++i != end; ) {
935            const auto v = source(*i, G);
936            for (auto j = begin; j != i; ++j) {
937                const auto w = source(*j, G);
938                Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
939                Z3_ast test = nullptr;
940                switch (typeId) {
941                    case TypeId::And:
942                        test = Z3_mk_and(mContext, 2, operands); break;
943                    case TypeId::Or:
944                        test = Z3_mk_or(mContext, 2, operands); break;
945                    case TypeId::Xor:
946                        test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
947                    default:
948                        llvm_unreachable("impossible type id");
949                }
950                test = simplify(test, true);
951
952                bool replacement = false;
953                Vertex x = 0;
954                const auto f = M.find(test);
955                if (LLVM_UNLIKELY(f != M.end())) {
956                    x = f->second;
957                    Z3_ast orig = getDefinition(G[x]);
958                    if (LLVM_UNLIKELY(orig != test)) {
959                        std::string tmp;
960                        raw_string_ostream out(tmp);
961                        out << "vertex " << x << " is mapped to:\n"
962                            << Z3_ast_to_string(mContext, test)
963                            << "\n\nBut is recorded as:\n\n";
964                        if (orig) {
965                            out << Z3_ast_to_string(mContext, orig);
966                        } else {
967                            out << "<null>";
968                        }
969                        throw std::runtime_error(out.str());
970                    }
971                    Z3_dec_ref(mContext, test);
972                    replacement = true;
973                } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
974                    PabloAST * const factor = C.predecessor()->findKey(test);
975                    if (LLVM_UNLIKELY(factor != nullptr)) {
976                        x = makeVertex(TypeId::Var, factor, G, test);
977                        M.emplace(test, x);
978                        replacement = true;
979                        mRefs.push_back(test);
980                    }
981                }
982
983                if (LLVM_UNLIKELY(replacement)) {
984
985                    assert (G[*i] == nullptr);
986                    assert (G[*j] == nullptr);
987
988                    remove_edge(*i, G);
989                    remove_edge(*j, G);
990
991                    add_edge(nullptr, x, u, G);
992
993                    goto restart;
994                }
995
996                Z3_dec_ref(mContext, test);
997            }
998        }
999    }
1000
1001    M.emplace(def, u);
1002
1003    return u;
1004}
1005
1006/** ------------------------------------------------------------------------------------------------------------- *
1007 * @brief redistributeAST
1008 *
1009 * Apply the distribution law to reduce computations whenever possible.
1010 ** ------------------------------------------------------------------------------------------------------------- */
1011bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
1012
1013    bool modified = false;
1014
1015//    errs() << "=====================================================\n";
1016
1017    DistributionGraph H;
1018
1019    for (;;) {
1020
1021        contractGraph(M, G);
1022
1023//        printGraph(G, "G");
1024
1025        generateDistributionGraph(G, H);
1026
1027        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
1028        if (num_vertices(H) == 0) {
1029            break;
1030        }
1031
1032        const DistributionSets distributionSets = safeDistributionSets(G, H);
1033
1034        if (LLVM_UNLIKELY(distributionSets.empty())) {
1035            break;
1036        }
1037
1038        modified = true;
1039
1040        mRefs.reserve(distributionSets.size() * 2);
1041
1042        for (const DistributionSet & set : distributionSets) {
1043
1044            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
1045            const VertexSet & sources = std::get<0>(set);
1046            const VertexSet & intermediary = std::get<1>(set);
1047            const VertexSet & sinks = std::get<2>(set);
1048
1049            TypeId outerTypeId = getType(G[H[sinks.front()]]);
1050            assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
1051            TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
1052
1053            const Vertex x = makeVertex(outerTypeId, nullptr, G);
1054            const Vertex y = makeVertex(innerTypeId, nullptr, G);
1055
1056            // Update G to reflect the distributed operations (including removing the subgraph of
1057            // the to-be distributed edges.)
1058
1059            add_edge(nullptr, x, y, G);
1060
1061            for (const Vertex i : sources) {
1062                const auto u = H[i];
1063                for (const Vertex j : intermediary) {
1064                    const auto v = H[j];
1065                    assert (getType(G[v]) == innerTypeId);
1066                    const auto e = edge(u, v, G); assert (e.second);
1067                    remove_edge(e.first, G);
1068                }
1069                add_edge(nullptr, u, y, G);
1070            }
1071
1072            for (const Vertex i : intermediary) {
1073
1074                const auto u = updateIntermediaryDefinition(innerTypeId, H[i], M, G);
1075
1076                for (const Vertex j : sinks) {
1077                    const auto v = H[j];
1078                    assert (getType(G[v]) == outerTypeId);
1079                    const auto e = edge(u, v, G); assert (e.second);
1080                    add_edge(G[e.first], y, v, G);
1081                    remove_edge(e.first, G);
1082                }
1083                add_edge(nullptr, u, x, G);
1084            }
1085
1086            updateSinkDefinition(outerTypeId, x, C, M, G);
1087
1088            updateSinkDefinition(innerTypeId, y, C, M, G);
1089
1090        }
1091
1092        H.clear();
1093
1094    }
1095
1096    return modified;
1097}
1098
1099/** ------------------------------------------------------------------------------------------------------------- *
1100 * @brief isNonEscaping
1101 ** ------------------------------------------------------------------------------------------------------------- */
1102inline bool isNonEscaping(const VertexData & data) {
1103    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
1104    switch (getType(data)) {
1105        case TypeId::Assign:
1106        case TypeId::If:
1107        case TypeId::While:
1108        case TypeId::Count:
1109            return false;
1110        default:
1111            return true;
1112    }
1113}
1114
1115/** ------------------------------------------------------------------------------------------------------------- *
1116 * @brief unique_source
1117 ** ------------------------------------------------------------------------------------------------------------- */
1118inline bool has_unique_source(const Vertex u, const Graph & G) {
1119    if (in_degree(u, G) > 0) {
1120        graph_traits<Graph>::in_edge_iterator i, end;
1121        std::tie(i, end) = in_edges(u, G);
1122        const Vertex v = source(*i, G);
1123        while (++i != end) {
1124            if (source(*i, G) != v) {
1125                return false;
1126            }
1127        }
1128        return true;
1129    }
1130    return false;
1131}
1132
1133/** ------------------------------------------------------------------------------------------------------------- *
1134 * @brief unique_target
1135 ** ------------------------------------------------------------------------------------------------------------- */
1136inline bool has_unique_target(const Vertex u, const Graph & G) {
1137    if (out_degree(u, G) > 0) {
1138        graph_traits<Graph>::out_edge_iterator i, end;
1139        std::tie(i, end) = out_edges(u, G);
1140        const Vertex v = target(*i, G);
1141        while (++i != end) {
1142            if (target(*i, G) != v) {
1143                return false;
1144            }
1145        }
1146        return true;
1147    }
1148    return false;
1149}
1150
1151/** ------------------------------------------------------------------------------------------------------------- *
1152 * @brief contractGraph
1153 ** ------------------------------------------------------------------------------------------------------------- */
1154bool BooleanReassociationPass::contractGraph(VertexMap & M, Graph & G) const {
1155
1156    bool contracted = false;
1157
1158    circular_buffer<Vertex> ordering(num_vertices(G));
1159
1160    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
1161
1162    // first contract the graph
1163    for (const Vertex u : ordering) {
1164        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
1165            continue;
1166        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
1167            if (isAssociative(G[u])) {
1168                if (LLVM_UNLIKELY(has_unique_source(u, G))) {
1169                    // We have a redundant node here that'll simply end up being a duplicate
1170                    // of the input value. Remove it and add any of its outgoing edges to its
1171                    // input node.
1172                    const auto ei = first(in_edges(u, G));
1173                    const Vertex v = source(ei, G);
1174                    for (auto ej : make_iterator_range(out_edges(u, G))) {
1175                        add_edge(G[ej], v, target(ej, G), G);
1176                    }
1177                    removeVertex(u, M, G);
1178                    contracted = true;
1179                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
1180                    // Otherwise if we have a single user, we have a similar case as above but
1181                    // we can only merge this vertex into the outgoing instruction if they are
1182                    // of the same type.
1183                    const auto ei = first(out_edges(u, G));
1184                    const Vertex v = target(ei, G);
1185                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
1186                        for (auto ej : make_iterator_range(in_edges(u, G))) {
1187                            add_edge(G[ej], source(ej, G), v, G);
1188                        }
1189                        removeVertex(u, M, G);
1190                        contracted = true;
1191                    }
1192                }
1193            }
1194        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
1195            removeVertex(u, M, G);
1196            contracted = true;
1197        }
1198    }
1199    return contracted;
1200}
1201
1202/** ------------------------------------------------------------------------------------------------------------- *
1203 * @brief factorGraph
1204 ** ------------------------------------------------------------------------------------------------------------- */
1205bool BooleanReassociationPass::factorGraph(TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1206
1207    if (LLVM_UNLIKELY(factors.empty())) {
1208        return false;
1209    }
1210
1211    std::vector<Vertex> I, J, K;
1212
1213    bool modified = false;
1214
1215    for (unsigned i = 1; i < factors.size(); ++i) {
1216        assert (getType(G[factors[i]]) == typeId);
1217        for (auto ei : make_iterator_range(in_edges(factors[i], G))) {
1218            I.push_back(source(ei, G));
1219        }
1220        std::sort(I.begin(), I.end());
1221        for (unsigned j = 0; j < i; ++j) {
1222            for (auto ej : make_iterator_range(in_edges(factors[j], G))) {
1223                J.push_back(source(ej, G));
1224            }
1225            std::sort(J.begin(), J.end());
1226            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1227            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1228            assert (std::is_sorted(K.begin(), K.end()));
1229            // if the intersection contains at least two elements
1230            const auto n = K.size();
1231            if (n > 1) {
1232                Vertex a = factors[i];
1233                Vertex b = factors[j];
1234                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1235                    if (in_degree(a, G) != n) {
1236                        assert (in_degree(b, G) == n);
1237                        std::swap(a, b);
1238                    }
1239                    assert (in_degree(a, G) == n);
1240                    if (in_degree(b, G) == n) {
1241                        for (auto e : make_iterator_range(out_edges(b, G))) {
1242                            add_edge(G[e], a, target(e, G), G);
1243                        }
1244                        removeVertex(b, G);
1245                    } else {
1246                        for (auto u : K) {
1247                            remove_edge(u, b, G);
1248                        }
1249                        add_edge(nullptr, a, b, G);
1250                    }
1251                } else {
1252                    Vertex v = makeVertex(typeId, nullptr, G);
1253                    for (auto u : K) {
1254                        remove_edge(u, a, G);
1255                        remove_edge(u, b, G);
1256                        add_edge(nullptr, u, v, G);
1257                    }
1258                    add_edge(nullptr, v, a, G);
1259                    add_edge(nullptr, v, b, G);
1260                    factors.push_back(v);
1261                }
1262                modified = true;
1263            }
1264            K.clear();
1265            J.clear();
1266        }
1267        I.clear();
1268    }
1269    return modified;
1270}
1271
1272/** ------------------------------------------------------------------------------------------------------------- *
1273 * @brief factorGraph
1274 ** ------------------------------------------------------------------------------------------------------------- */
1275bool BooleanReassociationPass::factorGraph(Graph & G) const {
1276    // factor the associative vertices.
1277    std::vector<Vertex> factors;
1278    bool factored = false;
1279    for (unsigned i = 0; i < 3; ++i) {
1280        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1281        for (auto j : make_iterator_range(vertices(G))) {
1282            if (getType(G[j]) == typeId[i]) {
1283                factors.push_back(j);
1284            }
1285        }
1286        if (factorGraph(typeId[i], G, factors)) {
1287            factored = true;
1288        }
1289        factors.clear();
1290    }
1291    return factored;
1292}
1293
1294/** ------------------------------------------------------------------------------------------------------------- *
1295 * @brief isMutable
1296 ** ------------------------------------------------------------------------------------------------------------- */
1297inline bool isMutable(const Vertex u, const Graph & G) {
1298    return getType(G[u]) != TypeId::Var;
1299}
1300
1301/** ------------------------------------------------------------------------------------------------------------- *
1302 * @brief rewriteAST
1303 ** ------------------------------------------------------------------------------------------------------------- */
1304bool BooleanReassociationPass::rewriteAST(Graph & G) {
1305
1306    using line_t = long long int;
1307
1308    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1309
1310    // errs() << "---------------------------------------------------------\n";
1311
1312    // printGraph(G, "X");
1313
1314    Z3_config cfg = Z3_mk_config();
1315    Z3_set_param_value(cfg, "model", "true");
1316    Z3_context ctx = Z3_mk_context(cfg);
1317    Z3_del_config(cfg);
1318    Z3_solver solver = Z3_mk_solver(ctx);
1319    Z3_solver_inc_ref(ctx, solver);
1320
1321    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1322
1323    flat_map<PabloAST *, Z3_ast> V;
1324
1325    // Generate the variables
1326    const auto ty = Z3_mk_int_sort(ctx);
1327    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1328
1329    for (const Vertex u : make_iterator_range(vertices(G))) {
1330        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1331        Z3_ast constraint = nullptr;
1332        if (in_degree(u, G) > 0) {
1333            constraint = Z3_mk_gt(ctx, var, ZERO);
1334            Z3_solver_assert(ctx, solver, constraint);
1335        }       
1336        PabloAST * const expr = getValue(G[u]);
1337        if (inCurrentBlock(expr, mBlock)) {
1338            V.emplace(expr, var);
1339        }
1340        mapping[u] = var;
1341    }
1342
1343    // Add in the dependency constraints
1344    for (const Vertex u : make_iterator_range(vertices(G))) {
1345        Z3_ast const t = mapping[u];
1346        for (auto e : make_iterator_range(in_edges(u, G))) {
1347            Z3_ast const s = mapping[source(e, G)];
1348            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1349        }
1350    }
1351
1352    // Compute the soft ordering constraints
1353    std::vector<Z3_ast> ordering(0);
1354    ordering.reserve(V.size() - 1);
1355
1356    Z3_ast prior = nullptr;
1357    unsigned gap = 1;
1358    for (Statement * stmt : *mBlock) {
1359        auto f = V.find(stmt);
1360        if (f != V.end()) {
1361            Z3_ast const node = f->second;
1362            if (prior) {
1363//                ordering.push_back(Z3_mk_lt(ctx, prior, node)); // increases the cost by 6 - 10x
1364                Z3_ast ops[2] = { node, prior };
1365                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1366            } else {
1367                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1368            }
1369            prior = node;
1370            gap = 0;
1371        }
1372        ++gap;
1373    }
1374
1375    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1376        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1377    }
1378
1379    Z3_model model = Z3_solver_get_model(ctx, solver);
1380    Z3_model_inc_ref(ctx, model);
1381
1382    std::vector<Vertex> S(0);
1383    S.reserve(num_vertices(G));
1384
1385    std::vector<line_t> L(num_vertices(G));
1386
1387    for (const Vertex u : make_iterator_range(vertices(G))) {
1388        line_t line = LoadEarly ? 0 : MAX_INT;
1389        if (isMutable(u, G)) {
1390            Z3_ast value;
1391            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1392                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1393            }
1394            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1395                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1396            }
1397            S.push_back(u);
1398        }
1399        L[u] = line;
1400    }
1401
1402    Z3_model_dec_ref(ctx, model);
1403    Z3_solver_dec_ref(ctx, solver);
1404    Z3_del_context(ctx);
1405
1406    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1407
1408    mBlock->setInsertPoint(nullptr);
1409
1410    std::vector<Vertex> T;
1411
1412    line_t count = 1;
1413
1414    for (auto u : S) {
1415        PabloAST *& stmt = getValue(G[u]);
1416
1417        assert (isMutable(u, G));
1418        assert (L[u] > 0 && L[u] < MAX_INT);
1419
1420        bool append = true;
1421
1422        if (isAssociative(G[u])) {
1423
1424            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1425                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1426            }
1427
1428            Statement * ip = mBlock->getInsertPoint();
1429            ip = ip ? ip->getNextNode() : mBlock->front();
1430
1431            const auto typeId = getType(G[u]);
1432
1433            T.clear();
1434            T.reserve(in_degree(u, G));
1435            for (const auto e : make_iterator_range(in_edges(u, G))) {
1436                T.push_back(source(e, G));
1437            }
1438
1439            // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1440            std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1441
1442            if (LoadEarly) {
1443                mBlock->setInsertPoint(nullptr);
1444            }
1445
1446            PabloAST * expr = nullptr;
1447            PabloAST * join = nullptr;
1448
1449            for (auto v : T) {
1450                expr = getValue(G[v]);
1451                if (LLVM_UNLIKELY(expr == nullptr)) {
1452                    throw std::runtime_error("Vertex " + std::to_string(v) + " does not have an expression!");
1453                }
1454                if (join) {
1455
1456                    if (in_degree(v, G) > 0) {
1457
1458                        assert (L[v] > 0 && L[v] < MAX_INT);
1459
1460                        Statement * dom = cast<Statement>(expr);
1461                        for (;;) {
1462                            PabloBlock * const parent = dom->getParent();
1463                            if (parent == mBlock) {
1464                                break;
1465                            }
1466                            dom = parent->getBranch(); assert(dom);
1467                        }
1468                        mBlock->setInsertPoint(dom);
1469
1470                        assert (dominates(join, expr));
1471                        assert (dominates(expr, ip));
1472                        assert (dominates(dom, ip));
1473                    }
1474
1475                    switch (typeId) {
1476                        case TypeId::And:
1477                            expr = mBlock->createAnd(join, expr); break;
1478                        case TypeId::Or:
1479                            expr = mBlock->createOr(join, expr); break;
1480                        case TypeId::Xor:
1481                            expr = mBlock->createXor(join, expr); break;
1482                        default:
1483                            llvm_unreachable("Invalid TypeId!");
1484                    }
1485                }
1486                join = expr;
1487            }
1488
1489            assert (expr);
1490
1491            Statement * const currIP = mBlock->getInsertPoint();
1492
1493            mBlock->setInsertPoint(ip->getPrevNode());
1494
1495            stmt = expr;
1496
1497            // If the insertion point isn't the statement we just attempted to create,
1498            // we must have unexpectidly reused a prior statement (or var.)
1499            if (LLVM_UNLIKELY(expr != currIP)) {
1500                L[u] = 0;
1501                // If that statement happened to be in the current block, locate which
1502                // statement it "duplicated" (if any) and set the duplicate's line # to
1503                // the line of the original statement to maintain a consistent view of
1504                // the program.
1505                if (inCurrentBlock(expr, mBlock)) {
1506                    for (auto v : make_iterator_range(vertices(G))) {
1507                        if (LLVM_UNLIKELY(getValue(G[v]) == expr)) {
1508                            L[u] = L[v];
1509                            break;
1510                        }
1511                    }
1512                }
1513                append = false;
1514            }
1515        } else if (stmt == nullptr) {
1516            assert (getType(G[u]) == TypeId::Not);
1517            assert (in_degree(u, G) == 1);
1518            PabloAST * op = getValue(G[source(first(in_edges(u, G)), G)]); assert (op);
1519            stmt = mBlock->createNot(op);
1520        } else if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1521            for (auto e : make_iterator_range(out_edges(u, G))) {
1522                const auto v = target(e, G);
1523                assert (L[v] == std::numeric_limits<line_t>::max());
1524                L[v] = count;
1525            }
1526        }
1527
1528        for (auto e : make_iterator_range(out_edges(u, G))) {
1529            if (G[e]) {
1530                if (PabloAST * user = getValue(G[target(e, G)])) {
1531                    cast<Statement>(user)->replaceUsesOfWith(G[e], stmt);
1532                }
1533            }
1534        }
1535
1536        if (LLVM_LIKELY(append)) {
1537            mBlock->insert(cast<Statement>(stmt));
1538            L[u] = count++; // update the line count with the actual one.
1539        }
1540    }
1541
1542    Statement * const end = mBlock->getInsertPoint(); assert (end);
1543    for (;;) {
1544        Statement * const next = end->getNextNode();
1545        if (next == nullptr) {
1546            break;
1547        }
1548
1549        #ifndef NDEBUG
1550        for (PabloAST * user : next->users()) {
1551            if (isa<Statement>(user) && dominates(user, next)) {
1552                std::string tmp;
1553                raw_string_ostream out(tmp);
1554                out << "Erasing ";
1555                PabloPrinter::print(next, out);
1556                out << " erroneously modifies live statement ";
1557                PabloPrinter::print(cast<Statement>(user), out);
1558                throw std::runtime_error(out.str());
1559            }
1560        }
1561        #endif
1562        next->eraseFromParent(true);
1563    }
1564
1565    #ifndef NDEBUG
1566    PabloVerifier::verify(mFunction, "mid-reassociation");
1567    #endif
1568
1569    return true;
1570}
1571
1572/** ------------------------------------------------------------------------------------------------------------- *
1573* @brief makeVertex
1574** ------------------------------------------------------------------------------------------------------------- */
1575Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
1576    assert (expr);
1577    const auto f = S.find(expr);
1578    if (f != S.end()) {
1579        assert (getValue(G[f->second]) == expr);
1580        return f->second;
1581    }
1582    const auto node = C.get(expr);   
1583    const Vertex u = makeVertex(typeId, expr, G, node);
1584    S.emplace(expr, u);
1585    if (node) {
1586        M.emplace(node, u);
1587    }
1588    return u;
1589}
1590
1591/** ------------------------------------------------------------------------------------------------------------- *
1592 * @brief makeVertex
1593 ** ------------------------------------------------------------------------------------------------------------- */
1594Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1595    assert (expr);
1596    const auto f = M.find(expr);
1597    if (f != M.end()) {
1598        assert (getValue(G[f->second]) == expr);
1599        return f->second;
1600    }
1601    const Vertex u = makeVertex(typeId, expr, G, node);
1602    M.emplace(expr, u);
1603    return u;
1604}
1605
1606/** ------------------------------------------------------------------------------------------------------------- *
1607 * @brief makeVertex
1608 ** ------------------------------------------------------------------------------------------------------------- */
1609Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1610//    for (Vertex u : make_iterator_range(vertices(G))) {
1611//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1612//            std::get<0>(G[u]) = typeId;
1613//            std::get<1>(G[u]) = expr;
1614//            return u;
1615//        }
1616//    }
1617    return add_vertex(std::make_tuple(typeId, expr, node), G);
1618}
1619
1620/** ------------------------------------------------------------------------------------------------------------- *
1621 * @brief removeVertex
1622 ** ------------------------------------------------------------------------------------------------------------- */
1623inline void BooleanReassociationPass::removeVertex(const Vertex u, VertexMap & M, Graph & G) const {
1624    VertexData & ref = G[u];
1625    Z3_ast def = getDefinition(ref); assert (def);
1626    auto f = M.find(def); assert (f != M.end());
1627    M.erase(f);
1628    removeVertex(u, G);
1629}
1630
1631/** ------------------------------------------------------------------------------------------------------------- *
1632 * @brief removeVertex
1633 ** ------------------------------------------------------------------------------------------------------------- */
1634inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1635    VertexData & ref = G[u];
1636    clear_vertex(u, G);
1637    std::get<0>(ref) = TypeId::Var;
1638    std::get<1>(ref) = nullptr;
1639    std::get<2>(ref) = nullptr;
1640}
1641
1642/** ------------------------------------------------------------------------------------------------------------- *
1643 * @brief make
1644 ** ------------------------------------------------------------------------------------------------------------- */
1645inline Z3_ast BooleanReassociationPass::makeVar() {
1646    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1647    Z3_inc_ref(mContext, node);
1648    mRefs.push_back(node);
1649    return node;
1650}
1651
1652/** ------------------------------------------------------------------------------------------------------------- *
1653 * @brief simplify
1654 ** ------------------------------------------------------------------------------------------------------------- */
1655inline Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
1656    assert (node);
1657    Z3_inc_ref(mContext, node);
1658    Z3_ast result = nullptr;
1659    if (use_expensive_minimization) {
1660
1661        Z3_goal g = Z3_mk_goal(mContext, true, false, false); assert (g);
1662        Z3_goal_inc_ref(mContext, g);
1663        Z3_goal_assert(mContext, g, node);
1664
1665        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g); assert (r);
1666        Z3_apply_result_inc_ref(mContext, r);
1667        Z3_goal_dec_ref(mContext, g);
1668
1669        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
1670
1671        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0); assert (h);
1672        Z3_goal_inc_ref(mContext, h);
1673        Z3_apply_result_dec_ref(mContext, r);
1674
1675        const unsigned n = Z3_goal_size(mContext, h);
1676
1677        if (n == 1) {
1678            result = Z3_goal_formula(mContext, h, 0); assert (result);
1679            Z3_inc_ref(mContext, result);
1680        } else if (n > 1) {
1681            Z3_ast operands[n];
1682            for (unsigned i = 0; i < n; ++i) {
1683                operands[i] = Z3_goal_formula(mContext, h, i);
1684                Z3_inc_ref(mContext, operands[i]);
1685            }
1686            result = Z3_mk_and(mContext, n, operands); assert (result);
1687            Z3_inc_ref(mContext, result);
1688            for (unsigned i = 0; i < n; ++i) {
1689                Z3_dec_ref(mContext, operands[i]);
1690            }
1691        } else {
1692            result = Z3_mk_true(mContext); assert (result);
1693        }
1694        Z3_goal_dec_ref(mContext, h);
1695
1696    } else {       
1697        result = Z3_simplify_ex(mContext, node, mParams); assert (result);
1698        Z3_inc_ref(mContext, result);
1699    }   
1700    Z3_dec_ref(mContext, node);
1701    assert (result);
1702    return result;
1703}
1704
1705/** ------------------------------------------------------------------------------------------------------------- *
1706 * @brief add
1707 ** ------------------------------------------------------------------------------------------------------------- */
1708inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1709    assert (expr && node);
1710    mForward.emplace(expr, node);
1711    if (!forwardOnly) {
1712        mBackward.emplace(node, expr);
1713    }
1714    return node;
1715}
1716
1717/** ------------------------------------------------------------------------------------------------------------- *
1718 * @brief get
1719 ** ------------------------------------------------------------------------------------------------------------- */
1720inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1721    assert (expr);
1722    auto f = mForward.find(expr);
1723    if (LLVM_UNLIKELY(f == mForward.end())) {
1724        if (mPredecessor == nullptr) {
1725            return nullptr;
1726        }
1727        return mPredecessor->get(expr);
1728    }
1729    return f->second;
1730}
1731
1732/** ------------------------------------------------------------------------------------------------------------- *
1733 * @brief get
1734 ** ------------------------------------------------------------------------------------------------------------- */
1735inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1736    assert (node);
1737    auto f = mBackward.find(node);
1738    if (LLVM_UNLIKELY(f == mBackward.end())) {
1739        if (mPredecessor == nullptr) {
1740            return nullptr;
1741        }
1742        return mPredecessor->findKey(node);
1743    }
1744    return f->second;
1745}
1746
1747/** ------------------------------------------------------------------------------------------------------------- *
1748 * @brief constructor
1749 ** ------------------------------------------------------------------------------------------------------------- */
1750inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_params params, Z3_tactic tactic, PabloFunction & f)
1751: mBlock(nullptr)
1752, mContext(ctx)
1753, mParams(params)
1754, mTactic(tactic)
1755, mFunction(f)
1756, mModified(false) {
1757
1758}
1759
1760}
Note: See TracBrowser for help on using the repository browser.