source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5202

Last change on this file since 5202 was 5202, checked in by nmedfort, 3 years ago

Initial work on adding types to PabloAST and mutable Var objects.

File size: 65.7 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17//  noif-dist-mult-dist-50 \p{Cham}(?<!\p{Mc})
18
19using namespace boost;
20using namespace boost::container;
21
22
23static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
24
25static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
26                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
27                                           "Scope Block rather than at the point of first use."),
28                                  cl::cat(ReassociationOptions));
29
30namespace pablo {
31
32using TypeId = PabloAST::ClassTypeId;
33using Graph = BooleanReassociationPass::Graph;
34using Vertex = Graph::vertex_descriptor;
35using VertexData = BooleanReassociationPass::VertexData;
36using DistributionGraph = BooleanReassociationPass::DistributionGraph;
37using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
38using VertexSet = std::vector<Vertex>;
39using VertexSets = std::vector<VertexSet>;
40using Biclique = std::pair<VertexSet, VertexSet>;
41using BicliqueSet = std::vector<Biclique>;
42using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
43using DistributionSets = std::vector<DistributionSet>;
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief helper functions
47 ** ------------------------------------------------------------------------------------------------------------- */
48template<typename Iterator>
49inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
50    assert (range.first != range.second);
51    return *range.first;
52}
53
54static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
55    return stmt->getParent() == block;
56}
57
58static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
59    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
60}
61
62inline TypeId & getType(VertexData & data) {
63    return std::get<0>(data);
64}
65
66inline TypeId getType(const VertexData & data) {
67    return std::get<0>(data);
68}
69
70inline Z3_ast & getDefinition(VertexData & data) {
71    return std::get<2>(data);
72}
73
74inline PabloAST * getValue(const VertexData & data) {
75    return std::get<1>(data);
76}
77
78inline PabloAST *& getValue(VertexData & data) {
79    return std::get<1>(data);
80}
81
82inline bool isAssociative(const VertexData & data) {
83    switch (getType(data)) {
84        case TypeId::And:
85        case TypeId::Or:
86        case TypeId::Xor:
87            return true;
88        default:
89            return false;
90    }
91}
92
93inline bool isDistributive(const VertexData & data) {
94    switch (getType(data)) {
95        case TypeId::And:
96        case TypeId::Or:
97            return true;
98        default:
99            return false;
100    }
101}
102
103void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
104    // Make sure each edge is unique
105    assert (u < num_vertices(G) && v < num_vertices(G));
106    assert (u != v);
107
108    // Just because we've supplied an expr doesn't mean it's useful. Check it.
109    if (expr) {
110        if (isAssociative(G[v])) {
111            expr = nullptr;
112        } else {
113            bool clear = true;
114            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
115                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
116                    if (dest->getOperand(i) == expr) {
117                        clear = false;
118                        break;
119                    }
120                }
121            }
122            if (LLVM_LIKELY(clear)) {
123                expr = nullptr;
124            }
125        }
126    }
127
128    for (auto e : make_iterator_range(out_edges(u, G))) {
129        if (LLVM_UNLIKELY(target(e, G) == v)) {
130            if (expr) {
131                if (G[e] == nullptr) {
132                    G[e] = expr;
133                } else if (G[e] != expr) {
134                    continue;
135                }
136            }
137            return;
138        }
139    }
140    G[boost::add_edge(u, v, G).first] = expr;
141}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief intersects
145 ** ------------------------------------------------------------------------------------------------------------- */
146template <class Type>
147inline bool intersects(Type & A, Type & B) {
148    auto first1 = A.begin(), last1 = A.end();
149    auto first2 = B.begin(), last2 = B.end();
150    while (first1 != last1 && first2 != last2) {
151        if (*first1 < *first2) {
152            ++first1;
153        } else if (*first2 < *first1) {
154            ++first2;
155        } else {
156            return true;
157        }
158    }
159    return false;
160}
161
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief printGraph
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void printGraph(const Graph & G, const std::string name) {
167    raw_os_ostream out(std::cerr);
168
169    std::vector<unsigned> c(num_vertices(G));
170    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
171
172    out << "digraph " << name << " {\n";
173    for (auto u : make_iterator_range(vertices(G))) {
174        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
175            continue;
176        }
177        out << "v" << u << " [label=\"" << u << ": ";
178        TypeId typeId;
179        PabloAST * expr;
180        Z3_ast node;
181        std::tie(typeId, expr, node) = G[u];
182        bool temporary = false;
183        bool error = false;
184        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
185            temporary = true;
186            switch (typeId) {
187                case TypeId::And:
188                    out << "And";
189                    break;
190                case TypeId::Or:
191                    out << "Or";
192                    break;
193                case TypeId::Xor:
194                    out << "Xor";
195                    break;
196                case TypeId::Not:
197                    out << "Not";
198                    break;
199                default:
200                    out << "???";
201                    error = true;
202                    break;
203            }
204            if (expr) {
205                out << " ("; PabloPrinter::print(expr, out); out << ")";
206            }
207        } else {
208            PabloPrinter::print(expr, out);
209        }
210        if (node == nullptr) {
211            out << " (*)";
212            error = true;
213        }
214        out << "\"";
215        if (typeId == TypeId::Var) {
216            out << " style=dashed";
217        }
218        if (error) {
219            out << " color=red";
220        } else if (temporary) {
221            out << " color=blue";
222        }
223        out << "];\n";
224    }
225    for (auto e : make_iterator_range(edges(G))) {
226        const auto s = source(e, G);
227        const auto t = target(e, G);
228        out << "v" << s << " -> v" << t;
229        bool cyclic = (c[s] == c[t]);
230        if (G[e] || cyclic) {
231            out << " [";
232             if (G[e]) {
233                out << "label=\"";
234                PabloPrinter::print(G[e], out);
235                out << "\" ";
236             }
237             if (cyclic) {
238                out << "color=red ";
239             }
240             out << "]";
241        }
242        out << ";\n";
243    }
244
245    if (num_vertices(G) > 0) {
246
247        out << "{ rank=same;";
248        for (auto u : make_iterator_range(vertices(G))) {
249            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
250                out << " v" << u;
251            }
252        }
253        out << "}\n";
254
255        out << "{ rank=same;";
256        for (auto u : make_iterator_range(vertices(G))) {
257            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
258                out << " v" << u;
259            }
260        }
261        out << "}\n";
262
263    }
264
265    out << "}\n\n";
266    out.flush();
267}
268
269/** ------------------------------------------------------------------------------------------------------------- *
270 * @brief optimize
271 ** ------------------------------------------------------------------------------------------------------------- */
272bool BooleanReassociationPass::optimize(PabloFunction & function) {
273
274    Z3_config cfg = Z3_mk_config();
275    Z3_context ctx = Z3_mk_context_rc(cfg);
276    Z3_del_config(cfg);
277
278    Z3_params params = Z3_mk_params(ctx);
279    Z3_params_inc_ref(ctx, params);
280    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "pull_cheap_ite"), true);
281    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "local_ctx"), true);
282
283    Z3_tactic ctx_solver_simplify = Z3_mk_tactic(ctx, "ctx-solver-simplify");
284    Z3_tactic_inc_ref(ctx, ctx_solver_simplify);
285
286    BooleanReassociationPass brp(ctx, params, ctx_solver_simplify, function);
287    brp.processScopes(function);
288
289    Z3_params_dec_ref(ctx, params);
290    Z3_tactic_dec_ref(ctx, ctx_solver_simplify);
291    Z3_del_context(ctx);
292
293    PabloVerifier::verify(function, "post-reassociation");
294
295    Simplifier::optimize(function);
296
297    return true;
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief processScopes
302 ** ------------------------------------------------------------------------------------------------------------- */
303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
304    CharacterizationMap C;
305    PabloBlock * const entry = function.getEntryBlock();
306    // Map the constants and input variables
307    C.add(entry->createZeroes(), Z3_mk_false(mContext));
308    C.add(entry->createOnes(), Z3_mk_true(mContext));
309    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
310        C.add(mFunction.getParameter(i), makeVar());
311    }
312    mInFile = makeVar();
313    processScopes(entry, C);
314    for (auto i = mRefs.begin(); i != mRefs.end(); ++i) {
315        Z3_dec_ref(mContext, *i);
316    }
317    mRefs.clear();
318    return mModified;
319}
320
321/** ------------------------------------------------------------------------------------------------------------- *
322 * @brief processScopes
323 ** ------------------------------------------------------------------------------------------------------------- */
324void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & C) {
325    const auto offset = mRefs.size();
326    for (Statement * stmt = block->front(); stmt; ) {
327        if (LLVM_UNLIKELY(isa<Branch>(stmt))) {
328            if (LLVM_UNLIKELY(isa<Zeroes>(cast<Branch>(stmt)->getCondition()))) {
329                stmt = stmt->eraseFromParent(true);
330            } else {
331                CharacterizationMap nested(C);
332                processScopes(cast<Branch>(stmt)->getBody(), nested);
333                for (Var * def : cast<Branch>(stmt)->getEscaped()) {
334                    C.add(def, makeVar());
335                }
336                stmt = stmt->getNextNode();
337            }
338        } else { // characterize this statement then check whether it is equivalent to any existing one.
339            PabloAST * const folded = Simplifier::fold(stmt, block);
340            if (LLVM_UNLIKELY(folded != nullptr)) {
341                stmt = stmt->replaceWith(folded);
342            } else {
343                stmt = characterize(stmt, C);
344            }
345        }
346    }   
347    distributeScope(block, C);
348    for (auto i = mRefs.begin() + offset; i != mRefs.end(); ++i) {
349        Z3_dec_ref(mContext, *i);
350    }
351    mRefs.erase(mRefs.begin() + offset, mRefs.end());
352}
353
354/** ------------------------------------------------------------------------------------------------------------- *
355 * @brief characterize
356 ** ------------------------------------------------------------------------------------------------------------- */
357inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & C) {
358
359    Z3_ast node = nullptr;
360    const size_t n = stmt->getNumOperands(); assert (n > 0);
361    bool use_expensive_simplification = false;
362    if (isa<Variadic>(stmt)) {
363        Z3_ast operands[n];
364        for (size_t i = 0; i < n; ++i) {
365            PabloAST * const op = stmt->getOperand(i);
366            if (isa<Not>(op)) {
367                use_expensive_simplification = true;
368            }
369            operands[i] = C.get(op); assert (operands[i]);
370        }
371        if (isa<And>(stmt)) {
372            node = Z3_mk_and(mContext, n, operands);
373        } else if (isa<Or>(stmt)) {
374            node = Z3_mk_or(mContext, n, operands);
375        } else if (isa<Xor>(stmt)) {
376            node = Z3_mk_xor(mContext, operands[0], operands[1]);
377            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
378                Z3_inc_ref(mContext, node);
379                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
380                Z3_inc_ref(mContext, temp);
381                Z3_dec_ref(mContext, node);
382                node = temp;
383            }
384        }
385    } else if (isa<Not>(stmt)) {
386        Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
387        node = Z3_mk_not(mContext, op);
388    } else if (isa<Sel>(stmt)) {
389        Z3_ast operands[3];
390        for (size_t i = 0; i < 3; ++i) {
391            operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
392        }
393        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
394    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
395        assert (stmt->getNumOperands() == 1);
396        Z3_ast check[2];
397        check[0] = C.get(stmt->getOperand(0)); assert (check[0]);
398        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
399        node = Z3_mk_and(mContext, 2, check);
400    } else if (LLVM_UNLIKELY(isa<Assign>(stmt))) {
401        return stmt->getNextNode();
402    }  else {
403        C.add(stmt, makeVar());
404        return stmt->getNextNode();
405    }
406    node = simplify(node, use_expensive_simplification);
407    PabloAST * const replacement = C.findKey(node);
408    if (LLVM_LIKELY(replacement == nullptr)) {
409        C.add(stmt, node);
410        mRefs.push_back(node);
411        return stmt->getNextNode();
412    } else {
413        Z3_dec_ref(mContext, node);
414        return stmt->replaceWith(replacement);
415    }
416}
417
418/** ------------------------------------------------------------------------------------------------------------- *
419 * @brief processScope
420 ** ------------------------------------------------------------------------------------------------------------- */
421inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & C) {
422    Graph G;
423    try {
424        mBlock = block;
425        transformAST(C, G);
426    } catch (std::runtime_error err) {
427        printGraph(G, "E");
428        throw err;
429    } catch (std::exception err) {
430        printGraph(G, "E");
431        throw err;
432    }
433}
434
435/** ------------------------------------------------------------------------------------------------------------- *
436 * @brief summarizeAST
437 *
438 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
439 * are "flattened" (i.e., allowed to have any number of inputs.)
440 ** ------------------------------------------------------------------------------------------------------------- */
441Vertex BooleanReassociationPass::transcribeSel(Sel * const stmt, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
442
443    Z3_ast args[2];
444
445    const Vertex c = makeVertex(TypeId::Var, stmt->getCondition(), C, S, M, G);
446    const Vertex t = makeVertex(TypeId::Var, cast<Sel>(stmt)->getTrueExpr(), C, S, M, G);
447    const Vertex f = makeVertex(TypeId::Var, cast<Sel>(stmt)->getFalseExpr(), C, S, M, G);
448
449    args[0] = getDefinition(G[c]);
450    args[1] = getDefinition(G[t]);
451
452    Z3_ast trueExpr = Z3_mk_and(mContext, 2, args);
453    Z3_inc_ref(mContext, trueExpr);
454    mRefs.push_back(trueExpr);
455
456    const Vertex x = makeVertex(TypeId::And, nullptr, G, trueExpr);
457    add_edge(nullptr, c, x, G);
458    add_edge(nullptr, t, x, G);
459
460    Z3_ast notCond = Z3_mk_not(mContext, args[0]);
461    Z3_inc_ref(mContext, notCond);
462    mRefs.push_back(notCond);
463
464    args[0] = notCond;
465    args[1] = getDefinition(G[f]);
466
467    Z3_ast falseExpr = Z3_mk_and(mContext, 2, args);
468    Z3_inc_ref(mContext, falseExpr);
469    mRefs.push_back(falseExpr);
470
471    const Vertex n = makeVertex(TypeId::Not, nullptr, G, notCond);
472
473    add_edge(nullptr, c, n, G);
474
475    const Vertex y = makeVertex(TypeId::And, nullptr, G, falseExpr);
476    add_edge(nullptr, n, y, G);
477    add_edge(nullptr, f, y, G);
478
479    const Vertex u = makeVertex(TypeId::Or, stmt, C, S, M, G);
480    add_edge(nullptr, x, u, G);
481    add_edge(nullptr, y, u, G);
482
483    return u;
484}
485
486/** ------------------------------------------------------------------------------------------------------------- *
487 * @brief summarizeAST
488 *
489 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
490 * are "flattened" (i.e., allowed to have any number of inputs.)
491 ** ------------------------------------------------------------------------------------------------------------- */
492void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
493
494    StatementMap S;
495
496    VertexMap M;
497
498    // Compute the base def-use graph ...
499    for (Statement * stmt : *mBlock) {
500        if (LLVM_UNLIKELY(isa<Sel>(stmt))) {
501
502            const Vertex u = transcribeSel(cast<Sel>(stmt), C, S, M, G);
503
504            resolveNestedUsages(stmt, u, C, S, M, G, stmt);
505
506        } else {
507
508
509            const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, C, S, M, G);
510            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
511                PabloAST * const op = stmt->getOperand(i);
512                if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
513                    add_edge(op, makeVertex(TypeId::Var, op, C, S, M, G), u, G);
514                }
515            }
516            if (LLVM_UNLIKELY(isa<If>(stmt))) {
517                for (Var * def : cast<If>(stmt)->getEscaped()) {
518                    const Vertex v = makeVertex(TypeId::Var, def, C, S, M, G);
519                    add_edge(def, u, v, G);
520                    resolveNestedUsages(def, v, C, S, M, G, stmt);
521                }
522                continue;
523            } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
524                // To keep G a DAG, we need to do a bit of surgery on loop variants because
525                // the next variables it produces can be used within the condition. Instead,
526                // we make the loop dependent on the original value of each Next node and
527                // the Next node dependent on the loop.
528                for (Var * var : cast<While>(stmt)->getEscaped()) {
529                    const Vertex v = makeVertex(TypeId::Var, var, C, S, M, G);
530                    assert (in_degree(v, G) == 1);
531                    auto e = first(in_edges(v, G));
532                    add_edge(G[e], source(e, G), u, G);
533                    remove_edge(v, u, G);
534                    add_edge(var, u, v, G);
535                    resolveNestedUsages(var, v, C, S, M, G, stmt);
536                }
537                continue;
538            } else {
539                resolveNestedUsages(stmt, u, C, S, M, G, stmt);
540            }
541        }
542
543    }
544
545    if (redistributeGraph(C, M, G)) {
546        factorGraph(G);
547        rewriteAST(G);
548        mModified = true;
549    }
550
551}
552
553/** ------------------------------------------------------------------------------------------------------------- *
554 * @brief resolveNestedUsages
555 ** ------------------------------------------------------------------------------------------------------------- */
556void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
557                                                   CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G,
558                                                   const Statement * const ignoreIfThis) const {
559    assert ("Cannot resolve nested usages of a null expression!" && expr);
560    for (PabloAST * user : expr->users()) { assert (user);
561        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
562            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
563            if (LLVM_UNLIKELY(parent != mBlock)) {
564                for (;;) {
565                    if (parent->getPredecessor() == mBlock) {
566                        Statement * const branch = parent->getBranch();
567                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
568                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
569                            const Vertex v = makeVertex(TypeId::Var, user, C, S, M, G);
570                            add_edge(expr, u, v, G);
571                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
572                            add_edge(user, v, w, G);
573                        }
574                        break;
575                    }
576                    parent = parent->getPredecessor();
577                    if (LLVM_UNLIKELY(parent == nullptr)) {
578                        assert (isa<Assign>(expr));
579                        break;
580                    }
581                }
582            }
583        }
584    }
585}
586
587
588
589/** ------------------------------------------------------------------------------------------------------------- *
590 * @brief enumerateBicliques
591 *
592 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
593 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
594 * bipartition A and their adjacencies to be in B.
595  ** ------------------------------------------------------------------------------------------------------------- */
596template <class Graph>
597static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
598    using IntersectionSets = std::set<VertexSet>;
599
600    IntersectionSets B1;
601    for (auto u : A) {
602        if (in_degree(u, G) > 0) {
603            VertexSet adjacencies;
604            adjacencies.reserve(in_degree(u, G));
605            for (auto e : make_iterator_range(in_edges(u, G))) {
606                adjacencies.push_back(source(e, G));
607            }
608            std::sort(adjacencies.begin(), adjacencies.end());
609            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
610            B1.insert(std::move(adjacencies));
611        }
612    }
613
614    IntersectionSets B(B1);
615
616    IntersectionSets Bi;
617
618    VertexSet clique;
619    for (auto i = B1.begin(); i != B1.end(); ++i) {
620        for (auto j = i; ++j != B1.end(); ) {
621            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
622            if (clique.size() > 0) {
623                if (B.count(clique) == 0) {
624                    Bi.insert(clique);
625                }
626                clique.clear();
627            }
628        }
629    }
630
631    for (;;) {
632        if (Bi.empty()) {
633            break;
634        }
635        B.insert(Bi.begin(), Bi.end());
636        IntersectionSets Bk;
637        for (auto i = B1.begin(); i != B1.end(); ++i) {
638            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
639                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
640                if (clique.size() > 0) {
641                    if (B.count(clique) == 0) {
642                        Bk.insert(clique);
643                    }
644                    clique.clear();
645                }
646            }
647        }
648        Bi.swap(Bk);
649    }
650
651    BicliqueSet cliques;
652    cliques.reserve(B.size());
653    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
654        VertexSet Ai(A);
655        for (const Vertex u : *Bi) {
656            VertexSet Aj;
657            Aj.reserve(out_degree(u, G));
658            for (auto e : make_iterator_range(out_edges(u, G))) {
659                Aj.push_back(target(e, G));
660            }
661            std::sort(Aj.begin(), Aj.end());
662            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
663            VertexSet Ak;
664            Ak.reserve(std::min(Ai.size(), Aj.size()));
665            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
666            Ai.swap(Ak);
667        }
668        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
669        cliques.emplace_back(std::move(Ai), std::move(*Bi));
670    }
671    return std::move(cliques);
672}
673
674/** ------------------------------------------------------------------------------------------------------------- *
675 * @brief independentCliqueSets
676 ** ------------------------------------------------------------------------------------------------------------- */
677template <unsigned side>
678inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
679    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
680
681    const auto l = cliques.size();
682    IndependentSetGraph I(l);
683
684    // Initialize our weights
685    for (unsigned i = 0; i != l; ++i) {
686        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
687    }
688
689    // Determine our constraints
690    for (unsigned i = 0; i != l; ++i) {
691        for (unsigned j = i + 1; j != l; ++j) {
692            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
693                add_edge(i, j, I);
694            }
695        }
696    }
697
698    // Use the greedy algorithm to choose our independent set
699    VertexSet selected;
700    for (;;) {
701        unsigned w = 0;
702        Vertex u = 0;
703        for (unsigned i = 0; i != l; ++i) {
704            if (I[i] > w) {
705                w = I[i];
706                u = i;
707            }
708        }
709        if (w < minimum) break;
710        selected.push_back(u);
711        I[u] = 0;
712        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
713            I[v] = 0;
714        }
715    }
716
717    // Sort the selected list and then remove the unselected cliques
718    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
719    auto end = cliques.end();
720    for (const unsigned offset : selected) {
721        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
722    }
723    cliques.erase(cliques.begin(), end);
724
725    return std::move(cliques);
726}
727
728/** ------------------------------------------------------------------------------------------------------------- *
729 * @brief removeUnhelpfulBicliques
730 *
731 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
732 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
733 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
734 ** ------------------------------------------------------------------------------------------------------------- */
735static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
736    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
737        const auto cardinalityA = std::get<0>(*ci).size();
738        VertexSet & B = std::get<1>(*ci);
739        for (auto bi = B.begin(); bi != B.end(); ) {
740            if (out_degree(H[*bi], G) == cardinalityA) {
741                ++bi;
742            } else {
743                bi = B.erase(bi);
744            }
745        }
746        if (B.size() > 1) {
747            ++ci;
748        } else {
749            ci = cliques.erase(ci);
750        }
751    }
752    return std::move(cliques);
753}
754
755/** ------------------------------------------------------------------------------------------------------------- *
756 * @brief safeDistributionSets
757 ** ------------------------------------------------------------------------------------------------------------- */
758static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
759
760    VertexSet sinks;
761    for (const Vertex u : make_iterator_range(vertices(H))) {
762        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
763            sinks.push_back(u);
764        }
765    }
766    std::sort(sinks.begin(), sinks.end());
767
768    DistributionSets T;
769    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
770    for (Biclique & lower : lowerSet) {
771        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
772        for (Biclique & upper : upperSet) {
773            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
774        }
775    }
776    return std::move(T);
777}
778
779/** ------------------------------------------------------------------------------------------------------------- *
780 * @brief getVertex
781 ** ------------------------------------------------------------------------------------------------------------- */
782template<typename ValueType, typename GraphType, typename MapType>
783static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
784    const auto f = M.find(value);
785    if (f != M.end()) {
786        return f->second;
787    }
788    const auto u = add_vertex(value, G);
789    M.insert(std::make_pair(value, u));
790    return u;
791}
792
793/** ------------------------------------------------------------------------------------------------------------- *
794 * @brief generateDistributionGraph
795 ** ------------------------------------------------------------------------------------------------------------- */
796void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
797    DistributionMap M;
798    for (const Vertex u : make_iterator_range(vertices(G))) {
799        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
800            continue;
801        } else if (isDistributive(G[u])) {
802            TypeId outerTypeId = getType(G[u]);
803            TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
804            flat_set<Vertex> distributable;
805            for (auto e : make_iterator_range(in_edges(u, G))) {
806                const Vertex v = source(e, G);
807                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
808                    bool safe = true;
809                    for (const auto e : make_iterator_range(out_edges(v, G))) {
810                        if (getType(G[target(e, G)]) != outerTypeId) {
811                            safe = false;
812                            break;
813                        }
814                    }
815                    if (safe) {
816                        distributable.insert(v);
817                    }
818                }
819            }
820            if (LLVM_LIKELY(distributable.size() > 1)) {
821                flat_set<Vertex> observed;
822                for (const Vertex v : distributable) {
823                    for (auto e : make_iterator_range(in_edges(v, G))) {
824                        const auto v = source(e, G);
825                        observed.insert(v);
826                    }
827                }
828                for (const Vertex w : observed) {
829                    for (auto e : make_iterator_range(out_edges(w, G))) {
830                        const Vertex v = target(e, G);
831                        if (distributable.count(v)) {
832                            const Vertex y = getVertex(v, H, M);
833                            boost::add_edge(y, getVertex(u, H, M), H);
834                            boost::add_edge(getVertex(w, H, M), y, H);
835                        }
836                    }
837                }
838            }
839        }
840    }
841}
842
843/** ------------------------------------------------------------------------------------------------------------- *
844 * @brief recomputeDefinition
845 ** ------------------------------------------------------------------------------------------------------------- */
846inline Z3_ast BooleanReassociationPass::computeDefinition(TypeId typeId, const Vertex u, Graph & G, const bool use_expensive_minimization) const {
847    const unsigned n = in_degree(u, G);
848    Z3_ast operands[n];
849    unsigned k = 0;
850    for (const auto e : make_iterator_range(in_edges(u, G))) {
851        const auto v = source(e, G);
852        if (LLVM_UNLIKELY(getDefinition(G[v]) == nullptr)) {
853            throw std::runtime_error("No definition for " + std::to_string(v));
854        }
855        operands[k++] = getDefinition(G[v]);
856    }
857    assert (k == n);
858    Z3_ast const node = (typeId == TypeId::And) ? Z3_mk_and(mContext, n, operands) : Z3_mk_or(mContext, n, operands);
859    return simplify(node, use_expensive_minimization);
860}
861
862/** ------------------------------------------------------------------------------------------------------------- *
863 * @brief updateDefinition
864 *
865 * Apply the distribution law to reduce computations whenever possible.
866 ** ------------------------------------------------------------------------------------------------------------- */
867Vertex BooleanReassociationPass::updateIntermediaryDefinition(TypeId typeId, const Vertex u, VertexMap & M, Graph & G) {
868
869    Z3_ast def = computeDefinition(typeId, u, G);
870    Z3_ast orig = getDefinition(G[u]); assert (orig);
871
872    Z3_dec_ref(mContext, orig);
873
874    const auto g = M.find(orig);
875    if (LLVM_LIKELY(g != M.end())) {
876        M.erase(g);
877    }
878
879    const auto f = std::find(mRefs.rbegin(), mRefs.rend(), orig);
880    assert (f != mRefs.rend());
881    *f = def;
882
883    const auto h = M.find(def);
884    if (LLVM_UNLIKELY(h != M.end())) {
885        const auto v = h->second;
886        if (v != u) {
887            for (auto e : make_iterator_range(out_edges(u, G))) {
888                add_edge(G[e], v, target(e, G), G);
889            }
890            removeVertex(u, G);
891            return v;
892        }
893    }
894
895    getDefinition(G[u]) = def;
896    M.emplace(def, u);
897    return u;
898}
899
900/** ------------------------------------------------------------------------------------------------------------- *
901 * @brief updateDefinition
902 *
903 * Apply the distribution law to reduce computations whenever possible.
904 ** ------------------------------------------------------------------------------------------------------------- */
905Vertex BooleanReassociationPass::updateSinkDefinition(TypeId typeId, const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G) {
906
907    Z3_ast const def = computeDefinition(typeId, u, G);
908
909    auto f = M.find(def);
910
911    if (LLVM_UNLIKELY(f != M.end())) {
912        Z3_dec_ref(mContext, def);
913        Vertex v = f->second; assert (v != u);
914        for (auto e : make_iterator_range(out_edges(u, G))) {
915            add_edge(G[e], v, target(e, G), G);
916        }
917        removeVertex(u, G);
918        return v;
919    } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
920        PabloAST * const factor = C.predecessor()->findKey(def);
921        if (LLVM_UNLIKELY(factor != nullptr)) {
922            getValue(G[u]) = factor;
923            getType(G[u]) = TypeId::Var;
924            clear_in_edges(u, G);
925        }
926    }
927
928    getDefinition(G[u]) = def;
929    mRefs.push_back(def);
930
931    graph_traits<Graph>::in_edge_iterator begin, end;
932
933restart:
934
935    if (in_degree(u, G) > 1) {
936        std::tie(begin, end) = in_edges(u, G);
937        for (auto i = begin; ++i != end; ) {
938            const auto v = source(*i, G);
939            for (auto j = begin; j != i; ++j) {
940                const auto w = source(*j, G);
941                Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
942                Z3_ast test = nullptr;
943                switch (typeId) {
944                    case TypeId::And:
945                        test = Z3_mk_and(mContext, 2, operands); break;
946                    case TypeId::Or:
947                        test = Z3_mk_or(mContext, 2, operands); break;
948                    case TypeId::Xor:
949                        test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
950                    default:
951                        llvm_unreachable("impossible type id");
952                }
953                test = simplify(test, true);
954
955                bool replacement = false;
956                Vertex x = 0;
957                const auto f = M.find(test);
958                if (LLVM_UNLIKELY(f != M.end())) {
959                    x = f->second;
960                    Z3_ast orig = getDefinition(G[x]);
961                    if (LLVM_UNLIKELY(orig != test)) {
962                        std::string tmp;
963                        raw_string_ostream out(tmp);
964                        out << "vertex " << x << " is mapped to:\n"
965                            << Z3_ast_to_string(mContext, test)
966                            << "\n\nBut is recorded as:\n\n";
967                        if (orig) {
968                            out << Z3_ast_to_string(mContext, orig);
969                        } else {
970                            out << "<null>";
971                        }
972                        throw std::runtime_error(out.str());
973                    }
974                    Z3_dec_ref(mContext, test);
975                    replacement = true;
976                } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
977                    PabloAST * const factor = C.predecessor()->findKey(test);
978                    if (LLVM_UNLIKELY(factor != nullptr)) {
979                        x = makeVertex(TypeId::Var, factor, G, test);
980                        M.emplace(test, x);
981                        replacement = true;
982                        mRefs.push_back(test);
983                    }
984                }
985
986                if (LLVM_UNLIKELY(replacement)) {
987
988                    assert (G[*i] == nullptr);
989                    assert (G[*j] == nullptr);
990
991                    remove_edge(*i, G);
992                    remove_edge(*j, G);
993
994                    add_edge(nullptr, x, u, G);
995
996                    goto restart;
997                }
998
999                Z3_dec_ref(mContext, test);
1000            }
1001        }
1002    }
1003
1004    M.emplace(def, u);
1005
1006    return u;
1007}
1008
1009/** ------------------------------------------------------------------------------------------------------------- *
1010 * @brief redistributeAST
1011 *
1012 * Apply the distribution law to reduce computations whenever possible.
1013 ** ------------------------------------------------------------------------------------------------------------- */
1014bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
1015
1016    bool modified = false;
1017
1018//    errs() << "=====================================================\n";
1019
1020    DistributionGraph H;
1021
1022    for (;;) {
1023
1024        contractGraph(M, G);
1025
1026//        printGraph(G, "G");
1027
1028        generateDistributionGraph(G, H);
1029
1030        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
1031        if (num_vertices(H) == 0) {
1032            break;
1033        }
1034
1035        const DistributionSets distributionSets = safeDistributionSets(G, H);
1036
1037        if (LLVM_UNLIKELY(distributionSets.empty())) {
1038            break;
1039        }
1040
1041        modified = true;
1042
1043        mRefs.reserve(distributionSets.size() * 2);
1044
1045        for (const DistributionSet & set : distributionSets) {
1046
1047            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
1048            const VertexSet & sources = std::get<0>(set);
1049            const VertexSet & intermediary = std::get<1>(set);
1050            const VertexSet & sinks = std::get<2>(set);
1051
1052            TypeId outerTypeId = getType(G[H[sinks.front()]]);
1053            assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
1054            TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
1055
1056            const Vertex x = makeVertex(outerTypeId, nullptr, G);
1057            const Vertex y = makeVertex(innerTypeId, nullptr, G);
1058
1059            // Update G to reflect the distributed operations (including removing the subgraph of
1060            // the to-be distributed edges.)
1061
1062            add_edge(nullptr, x, y, G);
1063
1064            for (const Vertex i : sources) {
1065                const auto u = H[i];
1066                for (const Vertex j : intermediary) {
1067                    const auto v = H[j];
1068                    assert (getType(G[v]) == innerTypeId);
1069                    const auto e = edge(u, v, G); assert (e.second);
1070                    remove_edge(e.first, G);
1071                }
1072                add_edge(nullptr, u, y, G);
1073            }
1074
1075            for (const Vertex i : intermediary) {
1076
1077                const auto u = updateIntermediaryDefinition(innerTypeId, H[i], M, G);
1078
1079                for (const Vertex j : sinks) {
1080                    const auto v = H[j];
1081                    assert (getType(G[v]) == outerTypeId);
1082                    const auto e = edge(u, v, G); assert (e.second);
1083                    add_edge(G[e.first], y, v, G);
1084                    remove_edge(e.first, G);
1085                }
1086                add_edge(nullptr, u, x, G);
1087            }
1088
1089            updateSinkDefinition(outerTypeId, x, C, M, G);
1090
1091            updateSinkDefinition(innerTypeId, y, C, M, G);
1092
1093        }
1094
1095        H.clear();
1096
1097    }
1098
1099    return modified;
1100}
1101
1102/** ------------------------------------------------------------------------------------------------------------- *
1103 * @brief isNonEscaping
1104 ** ------------------------------------------------------------------------------------------------------------- */
1105inline bool isNonEscaping(const VertexData & data) {
1106    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
1107    switch (getType(data)) {
1108        case TypeId::Assign:
1109        case TypeId::If:
1110        case TypeId::While:
1111        case TypeId::Count:
1112            return false;
1113        default:
1114            return true;
1115    }
1116}
1117
1118/** ------------------------------------------------------------------------------------------------------------- *
1119 * @brief unique_source
1120 ** ------------------------------------------------------------------------------------------------------------- */
1121inline bool has_unique_source(const Vertex u, const Graph & G) {
1122    if (in_degree(u, G) > 0) {
1123        graph_traits<Graph>::in_edge_iterator i, end;
1124        std::tie(i, end) = in_edges(u, G);
1125        const Vertex v = source(*i, G);
1126        while (++i != end) {
1127            if (source(*i, G) != v) {
1128                return false;
1129            }
1130        }
1131        return true;
1132    }
1133    return false;
1134}
1135
1136/** ------------------------------------------------------------------------------------------------------------- *
1137 * @brief unique_target
1138 ** ------------------------------------------------------------------------------------------------------------- */
1139inline bool has_unique_target(const Vertex u, const Graph & G) {
1140    if (out_degree(u, G) > 0) {
1141        graph_traits<Graph>::out_edge_iterator i, end;
1142        std::tie(i, end) = out_edges(u, G);
1143        const Vertex v = target(*i, G);
1144        while (++i != end) {
1145            if (target(*i, G) != v) {
1146                return false;
1147            }
1148        }
1149        return true;
1150    }
1151    return false;
1152}
1153
1154/** ------------------------------------------------------------------------------------------------------------- *
1155 * @brief contractGraph
1156 ** ------------------------------------------------------------------------------------------------------------- */
1157bool BooleanReassociationPass::contractGraph(VertexMap & M, Graph & G) const {
1158
1159    bool contracted = false;
1160
1161    circular_buffer<Vertex> ordering(num_vertices(G));
1162
1163    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
1164
1165    // first contract the graph
1166    for (const Vertex u : ordering) {
1167        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
1168            continue;
1169        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
1170            if (isAssociative(G[u])) {
1171                if (LLVM_UNLIKELY(has_unique_source(u, G))) {
1172                    // We have a redundant node here that'll simply end up being a duplicate
1173                    // of the input value. Remove it and add any of its outgoing edges to its
1174                    // input node.
1175                    const auto ei = first(in_edges(u, G));
1176                    const Vertex v = source(ei, G);
1177                    for (auto ej : make_iterator_range(out_edges(u, G))) {
1178                        add_edge(G[ej], v, target(ej, G), G);
1179                    }
1180                    removeVertex(u, M, G);
1181                    contracted = true;
1182                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
1183                    // Otherwise if we have a single user, we have a similar case as above but
1184                    // we can only merge this vertex into the outgoing instruction if they are
1185                    // of the same type.
1186                    const auto ei = first(out_edges(u, G));
1187                    const Vertex v = target(ei, G);
1188                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
1189                        for (auto ej : make_iterator_range(in_edges(u, G))) {
1190                            add_edge(G[ej], source(ej, G), v, G);
1191                        }
1192                        removeVertex(u, M, G);
1193                        contracted = true;
1194                    }
1195                }
1196            }
1197        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
1198            removeVertex(u, M, G);
1199            contracted = true;
1200        }
1201    }
1202    return contracted;
1203}
1204
1205/** ------------------------------------------------------------------------------------------------------------- *
1206 * @brief factorGraph
1207 ** ------------------------------------------------------------------------------------------------------------- */
1208bool BooleanReassociationPass::factorGraph(TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1209
1210    if (LLVM_UNLIKELY(factors.empty())) {
1211        return false;
1212    }
1213
1214    std::vector<Vertex> I, J, K;
1215
1216    bool modified = false;
1217
1218    for (unsigned i = 1; i < factors.size(); ++i) {
1219        assert (getType(G[factors[i]]) == typeId);
1220        for (auto ei : make_iterator_range(in_edges(factors[i], G))) {
1221            I.push_back(source(ei, G));
1222        }
1223        std::sort(I.begin(), I.end());
1224        for (unsigned j = 0; j < i; ++j) {
1225            for (auto ej : make_iterator_range(in_edges(factors[j], G))) {
1226                J.push_back(source(ej, G));
1227            }
1228            std::sort(J.begin(), J.end());
1229            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1230            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1231            assert (std::is_sorted(K.begin(), K.end()));
1232            // if the intersection contains at least two elements
1233            const auto n = K.size();
1234            if (n > 1) {
1235                Vertex a = factors[i];
1236                Vertex b = factors[j];
1237                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1238                    if (in_degree(a, G) != n) {
1239                        assert (in_degree(b, G) == n);
1240                        std::swap(a, b);
1241                    }
1242                    assert (in_degree(a, G) == n);
1243                    if (in_degree(b, G) == n) {
1244                        for (auto e : make_iterator_range(out_edges(b, G))) {
1245                            add_edge(G[e], a, target(e, G), G);
1246                        }
1247                        removeVertex(b, G);
1248                    } else {
1249                        for (auto u : K) {
1250                            remove_edge(u, b, G);
1251                        }
1252                        add_edge(nullptr, a, b, G);
1253                    }
1254                } else {
1255                    Vertex v = makeVertex(typeId, nullptr, G);
1256                    for (auto u : K) {
1257                        remove_edge(u, a, G);
1258                        remove_edge(u, b, G);
1259                        add_edge(nullptr, u, v, G);
1260                    }
1261                    add_edge(nullptr, v, a, G);
1262                    add_edge(nullptr, v, b, G);
1263                    factors.push_back(v);
1264                }
1265                modified = true;
1266            }
1267            K.clear();
1268            J.clear();
1269        }
1270        I.clear();
1271    }
1272    return modified;
1273}
1274
1275/** ------------------------------------------------------------------------------------------------------------- *
1276 * @brief factorGraph
1277 ** ------------------------------------------------------------------------------------------------------------- */
1278bool BooleanReassociationPass::factorGraph(Graph & G) const {
1279    // factor the associative vertices.
1280    std::vector<Vertex> factors;
1281    bool factored = false;
1282    for (unsigned i = 0; i < 3; ++i) {
1283        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1284        for (auto j : make_iterator_range(vertices(G))) {
1285            if (getType(G[j]) == typeId[i]) {
1286                factors.push_back(j);
1287            }
1288        }
1289        if (factorGraph(typeId[i], G, factors)) {
1290            factored = true;
1291        }
1292        factors.clear();
1293    }
1294    return factored;
1295}
1296
1297/** ------------------------------------------------------------------------------------------------------------- *
1298 * @brief isMutable
1299 ** ------------------------------------------------------------------------------------------------------------- */
1300inline bool isMutable(const Vertex u, const Graph & G) {
1301    return getType(G[u]) != TypeId::Var;
1302}
1303
1304/** ------------------------------------------------------------------------------------------------------------- *
1305 * @brief rewriteAST
1306 ** ------------------------------------------------------------------------------------------------------------- */
1307bool BooleanReassociationPass::rewriteAST(Graph & G) {
1308
1309    using line_t = long long int;
1310
1311    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1312
1313    // errs() << "---------------------------------------------------------\n";
1314
1315    // printGraph(G, "X");
1316
1317    Z3_config cfg = Z3_mk_config();
1318    Z3_set_param_value(cfg, "model", "true");
1319    Z3_context ctx = Z3_mk_context(cfg);
1320    Z3_del_config(cfg);
1321    Z3_solver solver = Z3_mk_solver(ctx);
1322    Z3_solver_inc_ref(ctx, solver);
1323
1324    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1325
1326    flat_map<PabloAST *, Z3_ast> V;
1327
1328    // Generate the variables
1329    const auto ty = Z3_mk_int_sort(ctx);
1330    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1331
1332    for (const Vertex u : make_iterator_range(vertices(G))) {
1333        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1334        Z3_ast constraint = nullptr;
1335        if (in_degree(u, G) > 0) {
1336            constraint = Z3_mk_gt(ctx, var, ZERO);
1337            Z3_solver_assert(ctx, solver, constraint);
1338        }       
1339        PabloAST * const expr = getValue(G[u]);
1340        if (inCurrentBlock(expr, mBlock)) {
1341            V.emplace(expr, var);
1342        }
1343        mapping[u] = var;
1344    }
1345
1346    // Add in the dependency constraints
1347    for (const Vertex u : make_iterator_range(vertices(G))) {
1348        Z3_ast const t = mapping[u];
1349        for (auto e : make_iterator_range(in_edges(u, G))) {
1350            Z3_ast const s = mapping[source(e, G)];
1351            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1352        }
1353    }
1354
1355    // Compute the soft ordering constraints
1356    std::vector<Z3_ast> ordering(0);
1357    ordering.reserve(V.size() - 1);
1358
1359    Z3_ast prior = nullptr;
1360    unsigned gap = 1;
1361    for (Statement * stmt : *mBlock) {
1362        auto f = V.find(stmt);
1363        if (f != V.end()) {
1364            Z3_ast const node = f->second;
1365            if (prior) {
1366//                ordering.push_back(Z3_mk_lt(ctx, prior, node)); // increases the cost by 6 - 10x
1367                Z3_ast ops[2] = { node, prior };
1368                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1369            } else {
1370                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1371            }
1372            prior = node;
1373            gap = 0;
1374        }
1375        ++gap;
1376    }
1377
1378    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1379        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1380    }
1381
1382    Z3_model model = Z3_solver_get_model(ctx, solver);
1383    Z3_model_inc_ref(ctx, model);
1384
1385    std::vector<Vertex> S(0);
1386    S.reserve(num_vertices(G));
1387
1388    std::vector<line_t> L(num_vertices(G));
1389
1390    for (const Vertex u : make_iterator_range(vertices(G))) {
1391        line_t line = LoadEarly ? 0 : MAX_INT;
1392        if (isMutable(u, G)) {
1393            Z3_ast value;
1394            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1395                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1396            }
1397            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1398                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1399            }
1400            S.push_back(u);
1401        }
1402        L[u] = line;
1403    }
1404
1405    Z3_model_dec_ref(ctx, model);
1406    Z3_solver_dec_ref(ctx, solver);
1407    Z3_del_context(ctx);
1408
1409    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1410
1411    mBlock->setInsertPoint(nullptr);
1412
1413    std::vector<Vertex> T;
1414
1415    line_t count = 1;
1416
1417    for (auto u : S) {
1418        PabloAST *& stmt = getValue(G[u]);
1419
1420        assert (isMutable(u, G));
1421        assert (L[u] > 0 && L[u] < MAX_INT);
1422
1423        bool append = true;
1424
1425        if (isAssociative(G[u])) {
1426
1427            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1428                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1429            }
1430
1431            Statement * ip = mBlock->getInsertPoint();
1432            ip = ip ? ip->getNextNode() : mBlock->front();
1433
1434            const auto typeId = getType(G[u]);
1435
1436            T.clear();
1437            T.reserve(in_degree(u, G));
1438            for (const auto e : make_iterator_range(in_edges(u, G))) {
1439                T.push_back(source(e, G));
1440            }
1441
1442            // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1443            std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1444
1445            if (LoadEarly) {
1446                mBlock->setInsertPoint(nullptr);
1447            }
1448
1449            PabloAST * expr = nullptr;
1450            PabloAST * join = nullptr;
1451
1452            for (auto v : T) {
1453                expr = getValue(G[v]);
1454                if (LLVM_UNLIKELY(expr == nullptr)) {
1455                    throw std::runtime_error("Vertex " + std::to_string(v) + " does not have an expression!");
1456                }
1457                if (join) {
1458
1459                    if (in_degree(v, G) > 0) {
1460
1461                        assert (L[v] > 0 && L[v] < MAX_INT);
1462
1463                        Statement * dom = cast<Statement>(expr);
1464                        for (;;) {
1465                            PabloBlock * const parent = dom->getParent();
1466                            if (parent == mBlock) {
1467                                break;
1468                            }
1469                            dom = parent->getBranch(); assert(dom);
1470                        }
1471                        mBlock->setInsertPoint(dom);
1472
1473                        assert (dominates(join, expr));
1474                        assert (dominates(expr, ip));
1475                        assert (dominates(dom, ip));
1476                    }
1477
1478                    switch (typeId) {
1479                        case TypeId::And:
1480                            expr = mBlock->createAnd(join, expr); break;
1481                        case TypeId::Or:
1482                            expr = mBlock->createOr(join, expr); break;
1483                        case TypeId::Xor:
1484                            expr = mBlock->createXor(join, expr); break;
1485                        default:
1486                            llvm_unreachable("Invalid TypeId!");
1487                    }
1488                }
1489                join = expr;
1490            }
1491
1492            assert (expr);
1493
1494            Statement * const currIP = mBlock->getInsertPoint();
1495
1496            mBlock->setInsertPoint(ip->getPrevNode());
1497
1498            stmt = expr;
1499
1500            // If the insertion point isn't the statement we just attempted to create,
1501            // we must have unexpectidly reused a prior statement (or var.)
1502            if (LLVM_UNLIKELY(expr != currIP)) {
1503                L[u] = 0;
1504                // If that statement happened to be in the current block, locate which
1505                // statement it "duplicated" (if any) and set the duplicate's line # to
1506                // the line of the original statement to maintain a consistent view of
1507                // the program.
1508                if (inCurrentBlock(expr, mBlock)) {
1509                    for (auto v : make_iterator_range(vertices(G))) {
1510                        if (LLVM_UNLIKELY(getValue(G[v]) == expr)) {
1511                            L[u] = L[v];
1512                            break;
1513                        }
1514                    }
1515                }
1516                append = false;
1517            }
1518        } else if (stmt == nullptr) {
1519            assert (getType(G[u]) == TypeId::Not);
1520            assert (in_degree(u, G) == 1);
1521            PabloAST * op = getValue(G[source(first(in_edges(u, G)), G)]); assert (op);
1522            stmt = mBlock->createNot(op);
1523        } else if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1524            for (auto e : make_iterator_range(out_edges(u, G))) {
1525                const auto v = target(e, G);
1526                assert (L[v] == std::numeric_limits<line_t>::max());
1527                L[v] = count;
1528            }
1529        }
1530
1531        for (auto e : make_iterator_range(out_edges(u, G))) {
1532            if (G[e]) {
1533                if (PabloAST * user = getValue(G[target(e, G)])) {
1534                    cast<Statement>(user)->replaceUsesOfWith(G[e], stmt);
1535                }
1536            }
1537        }
1538
1539        if (LLVM_LIKELY(append)) {
1540            mBlock->insert(cast<Statement>(stmt));
1541            L[u] = count++; // update the line count with the actual one.
1542        }
1543    }
1544
1545    Statement * const end = mBlock->getInsertPoint(); assert (end);
1546    for (;;) {
1547        Statement * const next = end->getNextNode();
1548        if (next == nullptr) {
1549            break;
1550        }
1551
1552        #ifndef NDEBUG
1553        for (PabloAST * user : next->users()) {
1554            if (isa<Statement>(user) && dominates(user, next)) {
1555                std::string tmp;
1556                raw_string_ostream out(tmp);
1557                out << "Erasing ";
1558                PabloPrinter::print(next, out);
1559                out << " erroneously modifies live statement ";
1560                PabloPrinter::print(cast<Statement>(user), out);
1561                throw std::runtime_error(out.str());
1562            }
1563        }
1564        #endif
1565        next->eraseFromParent(true);
1566    }
1567
1568    #ifndef NDEBUG
1569    PabloVerifier::verify(mFunction, "mid-reassociation");
1570    #endif
1571
1572    return true;
1573}
1574
1575/** ------------------------------------------------------------------------------------------------------------- *
1576* @brief makeVertex
1577** ------------------------------------------------------------------------------------------------------------- */
1578Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
1579    assert (expr);
1580    const auto f = S.find(expr);
1581    if (f != S.end()) {
1582        assert (getValue(G[f->second]) == expr);
1583        return f->second;
1584    }
1585    const auto node = C.get(expr);   
1586    const Vertex u = makeVertex(typeId, expr, G, node);
1587    S.emplace(expr, u);
1588    if (node) {
1589        M.emplace(node, u);
1590    }
1591    return u;
1592}
1593
1594/** ------------------------------------------------------------------------------------------------------------- *
1595 * @brief makeVertex
1596 ** ------------------------------------------------------------------------------------------------------------- */
1597Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1598    assert (expr);
1599    const auto f = M.find(expr);
1600    if (f != M.end()) {
1601        assert (getValue(G[f->second]) == expr);
1602        return f->second;
1603    }
1604    const Vertex u = makeVertex(typeId, expr, G, node);
1605    M.emplace(expr, u);
1606    return u;
1607}
1608
1609/** ------------------------------------------------------------------------------------------------------------- *
1610 * @brief makeVertex
1611 ** ------------------------------------------------------------------------------------------------------------- */
1612Vertex BooleanReassociationPass::makeVertex(TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1613//    for (Vertex u : make_iterator_range(vertices(G))) {
1614//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1615//            std::get<0>(G[u]) = typeId;
1616//            std::get<1>(G[u]) = expr;
1617//            return u;
1618//        }
1619//    }
1620    return add_vertex(std::make_tuple(typeId, expr, node), G);
1621}
1622
1623/** ------------------------------------------------------------------------------------------------------------- *
1624 * @brief removeVertex
1625 ** ------------------------------------------------------------------------------------------------------------- */
1626inline void BooleanReassociationPass::removeVertex(const Vertex u, VertexMap & M, Graph & G) const {
1627    VertexData & ref = G[u];
1628    Z3_ast def = getDefinition(ref); assert (def);
1629    auto f = M.find(def); assert (f != M.end());
1630    M.erase(f);
1631    removeVertex(u, G);
1632}
1633
1634/** ------------------------------------------------------------------------------------------------------------- *
1635 * @brief removeVertex
1636 ** ------------------------------------------------------------------------------------------------------------- */
1637inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1638    VertexData & ref = G[u];
1639    clear_vertex(u, G);
1640    std::get<0>(ref) = TypeId::Var;
1641    std::get<1>(ref) = nullptr;
1642    std::get<2>(ref) = nullptr;
1643}
1644
1645/** ------------------------------------------------------------------------------------------------------------- *
1646 * @brief make
1647 ** ------------------------------------------------------------------------------------------------------------- */
1648inline Z3_ast BooleanReassociationPass::makeVar() {
1649    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1650    Z3_inc_ref(mContext, node);
1651    mRefs.push_back(node);
1652    return node;
1653}
1654
1655/** ------------------------------------------------------------------------------------------------------------- *
1656 * @brief simplify
1657 ** ------------------------------------------------------------------------------------------------------------- */
1658inline Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
1659    assert (node);
1660    Z3_inc_ref(mContext, node);
1661    Z3_ast result = nullptr;
1662    if (use_expensive_minimization) {
1663
1664        Z3_goal g = Z3_mk_goal(mContext, true, false, false); assert (g);
1665        Z3_goal_inc_ref(mContext, g);
1666        Z3_goal_assert(mContext, g, node);
1667
1668        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g); assert (r);
1669        Z3_apply_result_inc_ref(mContext, r);
1670        Z3_goal_dec_ref(mContext, g);
1671
1672        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
1673
1674        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0); assert (h);
1675        Z3_goal_inc_ref(mContext, h);
1676        Z3_apply_result_dec_ref(mContext, r);
1677
1678        const unsigned n = Z3_goal_size(mContext, h);
1679
1680        if (n == 1) {
1681            result = Z3_goal_formula(mContext, h, 0); assert (result);
1682            Z3_inc_ref(mContext, result);
1683        } else if (n > 1) {
1684            Z3_ast operands[n];
1685            for (unsigned i = 0; i < n; ++i) {
1686                operands[i] = Z3_goal_formula(mContext, h, i);
1687                Z3_inc_ref(mContext, operands[i]);
1688            }
1689            result = Z3_mk_and(mContext, n, operands); assert (result);
1690            Z3_inc_ref(mContext, result);
1691            for (unsigned i = 0; i < n; ++i) {
1692                Z3_dec_ref(mContext, operands[i]);
1693            }
1694        } else {
1695            result = Z3_mk_true(mContext); assert (result);
1696        }
1697        Z3_goal_dec_ref(mContext, h);
1698
1699    } else {       
1700        result = Z3_simplify_ex(mContext, node, mParams); assert (result);
1701        Z3_inc_ref(mContext, result);
1702    }   
1703    Z3_dec_ref(mContext, node);
1704    assert (result);
1705    return result;
1706}
1707
1708/** ------------------------------------------------------------------------------------------------------------- *
1709 * @brief add
1710 ** ------------------------------------------------------------------------------------------------------------- */
1711inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1712    assert (expr && node);
1713    mForward.emplace(expr, node);
1714    if (!forwardOnly) {
1715        mBackward.emplace(node, expr);
1716    }
1717    return node;
1718}
1719
1720/** ------------------------------------------------------------------------------------------------------------- *
1721 * @brief get
1722 ** ------------------------------------------------------------------------------------------------------------- */
1723inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1724    assert (expr);
1725    auto f = mForward.find(expr);
1726    if (LLVM_UNLIKELY(f == mForward.end())) {
1727        if (mPredecessor == nullptr) {
1728            return nullptr;
1729        }
1730        return mPredecessor->get(expr);
1731    }
1732    return f->second;
1733}
1734
1735/** ------------------------------------------------------------------------------------------------------------- *
1736 * @brief get
1737 ** ------------------------------------------------------------------------------------------------------------- */
1738inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1739    assert (node);
1740    auto f = mBackward.find(node);
1741    if (LLVM_UNLIKELY(f == mBackward.end())) {
1742        if (mPredecessor == nullptr) {
1743            return nullptr;
1744        }
1745        return mPredecessor->findKey(node);
1746    }
1747    return f->second;
1748}
1749
1750/** ------------------------------------------------------------------------------------------------------------- *
1751 * @brief constructor
1752 ** ------------------------------------------------------------------------------------------------------------- */
1753inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_params params, Z3_tactic tactic, PabloFunction & f)
1754: mContext(ctx)
1755, mParams(params)
1756, mTactic(tactic)
1757, mFunction(f)
1758, mModified(false) {
1759
1760}
1761
1762}
Note: See TracBrowser for help on using the repository browser.