source: icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp @ 5160

Last change on this file since 5160 was 5160, checked in by nmedfort, 3 years ago

Initial work for incorporating Types into Pablo AST.

File size: 66.3 KB
Line 
1#include "booleanreassociationpass.h"
2#include <boost/container/flat_set.hpp>
3#include <boost/container/flat_map.hpp>
4#include <boost/circular_buffer.hpp>
5#include <boost/graph/adjacency_list.hpp>
6#include <boost/graph/topological_sort.hpp>
7#include <boost/graph/strong_components.hpp>
8#include <pablo/optimizers/pablo_simplifier.hpp>
9#include <pablo/analysis/pabloverifier.hpp>
10#include <algorithm>
11#include <numeric> // std::iota
12#include <pablo/printer_pablos.h>
13#include <iostream>
14#include <llvm/Support/CommandLine.h>
15#include "maxsat.hpp"
16
17//  noif-dist-mult-dist-50 \p{Cham}(?<!\p{Mc})
18
19using namespace boost;
20using namespace boost::container;
21
22
23static cl::OptionCategory ReassociationOptions("Reassociation Optimization Options", "These options control the Pablo Reassociation optimization pass.");
24
25static cl::opt<unsigned> LoadEarly("Reassociation-Load-Early", cl::init(false),
26                                  cl::desc("When recomputing an Associative operation, load values from preceeding blocks at the beginning of the "
27                                           "Scope Block rather than at the point of first use."),
28                                  cl::cat(ReassociationOptions));
29
30namespace pablo {
31
32using TypeId = PabloAST::ClassTypeId;
33using Graph = BooleanReassociationPass::Graph;
34using Vertex = Graph::vertex_descriptor;
35using VertexData = BooleanReassociationPass::VertexData;
36using DistributionGraph = BooleanReassociationPass::DistributionGraph;
37using DistributionMap = flat_map<Graph::vertex_descriptor, DistributionGraph::vertex_descriptor>;
38using VertexSet = std::vector<Vertex>;
39using VertexSets = std::vector<VertexSet>;
40using Biclique = std::pair<VertexSet, VertexSet>;
41using BicliqueSet = std::vector<Biclique>;
42using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
43using DistributionSets = std::vector<DistributionSet>;
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief helper functions
47 ** ------------------------------------------------------------------------------------------------------------- */
48template<typename Iterator>
49inline Graph::edge_descriptor first(const std::pair<Iterator, Iterator> & range) {
50    assert (range.first != range.second);
51    return *range.first;
52}
53
54static inline bool inCurrentBlock(const Statement * stmt, const PabloBlock * const block) {
55    return stmt->getParent() == block;
56}
57
58static inline bool inCurrentBlock(const PabloAST * expr, const PabloBlock * const block) {
59    return expr ? isa<Statement>(expr) && inCurrentBlock(cast<Statement>(expr), block) : true;
60}
61
62inline TypeId & getType(VertexData & data) {
63    return std::get<0>(data);
64}
65
66inline TypeId getType(const VertexData & data) {
67    return std::get<0>(data);
68}
69
70inline Z3_ast & getDefinition(VertexData & data) {
71    return std::get<2>(data);
72}
73
74inline PabloAST * getValue(const VertexData & data) {
75    return std::get<1>(data);
76}
77
78inline PabloAST *& getValue(VertexData & data) {
79    return std::get<1>(data);
80}
81
82inline bool isAssociative(const VertexData & data) {
83    switch (getType(data)) {
84        case TypeId::And:
85        case TypeId::Or:
86        case TypeId::Xor:
87            return true;
88        default:
89            return false;
90    }
91}
92
93inline bool isDistributive(const VertexData & data) {
94    switch (getType(data)) {
95        case TypeId::And:
96        case TypeId::Or:
97            return true;
98        default:
99            return false;
100    }
101}
102
103void add_edge(PabloAST * expr, const Vertex u, const Vertex v, Graph & G) {
104    // Make sure each edge is unique
105    assert (u < num_vertices(G) && v < num_vertices(G));
106    assert (u != v);
107
108    // Just because we've supplied an expr doesn't mean it's useful. Check it.
109    if (expr) {
110        if (isAssociative(G[v])) {
111            expr = nullptr;
112        } else {
113            bool clear = true;
114            if (const Statement * dest = dyn_cast_or_null<Statement>(getValue(G[v]))) {
115                for (unsigned i = 0; i < dest->getNumOperands(); ++i) {
116                    if (dest->getOperand(i) == expr) {
117                        clear = false;
118                        break;
119                    }
120                }
121            }
122            if (LLVM_LIKELY(clear)) {
123                expr = nullptr;
124            }
125        }
126    }
127
128    for (auto e : make_iterator_range(out_edges(u, G))) {
129        if (LLVM_UNLIKELY(target(e, G) == v)) {
130            if (expr) {
131                if (G[e] == nullptr) {
132                    G[e] = expr;
133                } else if (G[e] != expr) {
134                    continue;
135                }
136            }
137            return;
138        }
139    }
140    G[boost::add_edge(u, v, G).first] = expr;
141}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief intersects
145 ** ------------------------------------------------------------------------------------------------------------- */
146template <class Type>
147inline bool intersects(const Type & A, const Type & B) {
148    auto first1 = A.begin(), last1 = A.end();
149    auto first2 = B.begin(), last2 = B.end();
150    while (first1 != last1 && first2 != last2) {
151        if (*first1 < *first2) {
152            ++first1;
153        } else if (*first2 < *first1) {
154            ++first2;
155        } else {
156            return true;
157        }
158    }
159    return false;
160}
161
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief printGraph
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void printGraph(const Graph & G, const std::string name) {
167    raw_os_ostream out(std::cerr);
168
169    std::vector<unsigned> c(num_vertices(G));
170    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
171
172    out << "digraph " << name << " {\n";
173    for (auto u : make_iterator_range(vertices(G))) {
174        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
175            continue;
176        }
177        out << "v" << u << " [label=\"" << u << ": ";
178        TypeId typeId;
179        PabloAST * expr;
180        Z3_ast node;
181        std::tie(typeId, expr, node) = G[u];
182        bool temporary = false;
183        bool error = false;
184        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
185            temporary = true;
186            switch (typeId) {
187                case TypeId::And:
188                    out << "And";
189                    break;
190                case TypeId::Or:
191                    out << "Or";
192                    break;
193                case TypeId::Xor:
194                    out << "Xor";
195                    break;
196                case TypeId::Not:
197                    out << "Not";
198                    break;
199                default:
200                    out << "???";
201                    error = true;
202                    break;
203            }
204            if (expr) {
205                out << " ("; PabloPrinter::print(expr, out); out << ")";
206            }
207        } else {
208            PabloPrinter::print(expr, out);
209        }
210        if (node == nullptr) {
211            out << " (*)";
212            error = true;
213        }
214        out << "\"";
215        if (typeId == TypeId::Var) {
216            out << " style=dashed";
217        }
218        if (error) {
219            out << " color=red";
220        } else if (temporary) {
221            out << " color=blue";
222        }
223        out << "];\n";
224    }
225    for (auto e : make_iterator_range(edges(G))) {
226        const auto s = source(e, G);
227        const auto t = target(e, G);
228        out << "v" << s << " -> v" << t;
229        bool cyclic = (c[s] == c[t]);
230        if (G[e] || cyclic) {
231            out << " [";
232             if (G[e]) {
233                out << "label=\"";
234                PabloPrinter::print(G[e], out);
235                out << "\" ";
236             }
237             if (cyclic) {
238                out << "color=red ";
239             }
240             out << "]";
241        }
242        out << ";\n";
243    }
244
245    if (num_vertices(G) > 0) {
246
247        out << "{ rank=same;";
248        for (auto u : make_iterator_range(vertices(G))) {
249            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
250                out << " v" << u;
251            }
252        }
253        out << "}\n";
254
255        out << "{ rank=same;";
256        for (auto u : make_iterator_range(vertices(G))) {
257            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
258                out << " v" << u;
259            }
260        }
261        out << "}\n";
262
263    }
264
265    out << "}\n\n";
266    out.flush();
267}
268
269/** ------------------------------------------------------------------------------------------------------------- *
270 * @brief optimize
271 ** ------------------------------------------------------------------------------------------------------------- */
272bool BooleanReassociationPass::optimize(PabloFunction & function) {
273
274    Z3_config cfg = Z3_mk_config();
275    Z3_context ctx = Z3_mk_context_rc(cfg);
276    Z3_del_config(cfg);
277
278    Z3_params params = Z3_mk_params(ctx);
279    Z3_params_inc_ref(ctx, params);
280    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "pull_cheap_ite"), true);
281    Z3_params_set_bool(ctx, params, Z3_mk_string_symbol(ctx, "local_ctx"), true);
282
283    Z3_tactic ctx_solver_simplify = Z3_mk_tactic(ctx, "ctx-solver-simplify");
284    Z3_tactic_inc_ref(ctx, ctx_solver_simplify);
285
286    BooleanReassociationPass brp(ctx, params, ctx_solver_simplify, function);
287    brp.processScopes(function);
288
289    Z3_params_dec_ref(ctx, params);
290    Z3_tactic_dec_ref(ctx, ctx_solver_simplify);
291    Z3_del_context(ctx);
292
293    PabloVerifier::verify(function, "post-reassociation");
294
295    Simplifier::optimize(function);
296
297    return true;
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief processScopes
302 ** ------------------------------------------------------------------------------------------------------------- */
303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
304    CharacterizationMap C;
305    PabloBlock * const entry = function.getEntryBlock();
306    // Map the constants and input variables
307    C.add(entry->createZeroes(), Z3_mk_false(mContext));
308    C.add(entry->createOnes(), Z3_mk_true(mContext));
309    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
310        C.add(mFunction.getParameter(i), makeVar());
311    }
312    mInFile = makeVar();
313    processScopes(entry, C);
314    for (auto i = mRefs.begin(); i != mRefs.end(); ++i) {
315        Z3_dec_ref(mContext, *i);
316    }
317    mRefs.clear();
318    return mModified;
319}
320
321/** ------------------------------------------------------------------------------------------------------------- *
322 * @brief processScopes
323 ** ------------------------------------------------------------------------------------------------------------- */
324void BooleanReassociationPass::processScopes(PabloBlock * const block, CharacterizationMap & C) {
325    const auto offset = mRefs.size();
326    for (Statement * stmt = block->front(); stmt; ) {
327        if (LLVM_UNLIKELY(isa<If>(stmt))) {
328            if (LLVM_UNLIKELY(isa<Zeroes>(cast<If>(stmt)->getCondition()))) {
329                stmt = stmt->eraseFromParent(true);
330            } else {
331                CharacterizationMap nested(C);
332                processScopes(cast<If>(stmt)->getBody(), nested);
333                for (Assign * def : cast<If>(stmt)->getDefined()) {
334                    C.add(def, makeVar());
335                }
336                stmt = stmt->getNextNode();
337            }
338        } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
339            if (LLVM_UNLIKELY(isa<Zeroes>(cast<While>(stmt)->getCondition()))) {
340                stmt = stmt->eraseFromParent(true);
341            } else {
342                CharacterizationMap nested(C);
343                processScopes(cast<While>(stmt)->getBody(), nested);
344                for (Next * var : cast<While>(stmt)->getVariants()) {
345                    C.add(var, makeVar());
346                }
347                stmt = stmt->getNextNode();
348            }
349        } else { // characterize this statement then check whether it is equivalent to any existing one.
350            PabloAST * const folded = Simplifier::fold(stmt, block);
351            if (LLVM_UNLIKELY(folded != nullptr)) {
352                stmt = stmt->replaceWith(folded);
353            } else {
354                stmt = characterize(stmt, C);
355            }
356        }
357    }   
358    distributeScope(block, C);
359    for (auto i = mRefs.begin() + offset; i != mRefs.end(); ++i) {
360        Z3_dec_ref(mContext, *i);
361    }
362    mRefs.erase(mRefs.begin() + offset, mRefs.end());
363}
364
365/** ------------------------------------------------------------------------------------------------------------- *
366 * @brief characterize
367 ** ------------------------------------------------------------------------------------------------------------- */
368inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & C) {
369
370    Z3_ast node = nullptr;
371    const size_t n = stmt->getNumOperands(); assert (n > 0);
372    bool use_expensive_simplification = false;
373    if (isa<Variadic>(stmt)) {
374        Z3_ast operands[n];
375        for (size_t i = 0; i < n; ++i) {
376            PabloAST * const op = stmt->getOperand(i);
377            if (isa<Not>(op)) {
378                use_expensive_simplification = true;
379            }
380            operands[i] = C.get(op); assert (operands[i]);
381        }
382        if (isa<And>(stmt)) {
383            node = Z3_mk_and(mContext, n, operands);
384        } else if (isa<Or>(stmt)) {
385            node = Z3_mk_or(mContext, n, operands);
386        } else if (isa<Xor>(stmt)) {
387            node = Z3_mk_xor(mContext, operands[0], operands[1]);
388            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
389                Z3_inc_ref(mContext, node);
390                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
391                Z3_inc_ref(mContext, temp);
392                Z3_dec_ref(mContext, node);
393                node = temp;
394            }
395        }
396    } else if (isa<Not>(stmt)) {
397        Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
398        node = Z3_mk_not(mContext, op);
399    } else if (isa<Sel>(stmt)) {
400        Z3_ast operands[3];
401        for (size_t i = 0; i < 3; ++i) {
402            operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
403        }
404        node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
405    } else if (LLVM_UNLIKELY(isa<InFile>(stmt) || isa<AtEOF>(stmt))) {
406        assert (stmt->getNumOperands() == 1);
407        Z3_ast check[2];
408        check[0] = C.get(stmt->getOperand(0)); assert (check[0]);
409        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
410        node = Z3_mk_and(mContext, 2, check);
411    } else if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
412        return stmt->getNextNode();
413    }  else {
414        C.add(stmt, makeVar());
415        return stmt->getNextNode();
416    }
417    node = simplify(node, use_expensive_simplification);
418    PabloAST * const replacement = C.findKey(node);
419    if (LLVM_LIKELY(replacement == nullptr)) {
420        C.add(stmt, node);
421        mRefs.push_back(node);
422        return stmt->getNextNode();
423    } else {
424        Z3_dec_ref(mContext, node);
425        return stmt->replaceWith(replacement);
426    }
427}
428
429/** ------------------------------------------------------------------------------------------------------------- *
430 * @brief processScope
431 ** ------------------------------------------------------------------------------------------------------------- */
432inline void BooleanReassociationPass::distributeScope(PabloBlock * const block, CharacterizationMap & C) {
433    Graph G;
434    try {
435        mBlock = block;
436        transformAST(C, G);
437    } catch (std::runtime_error err) {
438        printGraph(G, "E");
439        throw err;
440    } catch (std::exception err) {
441        printGraph(G, "E");
442        throw err;
443    }
444}
445
446/** ------------------------------------------------------------------------------------------------------------- *
447 * @brief summarizeAST
448 *
449 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
450 * are "flattened" (i.e., allowed to have any number of inputs.)
451 ** ------------------------------------------------------------------------------------------------------------- */
452Vertex BooleanReassociationPass::transcribeSel(Sel * const stmt, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
453
454    Z3_ast args[2];
455
456    const Vertex c = makeVertex(TypeId::Var, stmt->getCondition(), C, S, M, G);
457    const Vertex t = makeVertex(TypeId::Var, cast<Sel>(stmt)->getTrueExpr(), C, S, M, G);
458    const Vertex f = makeVertex(TypeId::Var, cast<Sel>(stmt)->getFalseExpr(), C, S, M, G);
459
460    args[0] = getDefinition(G[c]);
461    args[1] = getDefinition(G[t]);
462
463    Z3_ast trueExpr = Z3_mk_and(mContext, 2, args);
464    Z3_inc_ref(mContext, trueExpr);
465    mRefs.push_back(trueExpr);
466
467    const Vertex x = makeVertex(TypeId::And, nullptr, G, trueExpr);
468    add_edge(nullptr, c, x, G);
469    add_edge(nullptr, t, x, G);
470
471    Z3_ast notCond = Z3_mk_not(mContext, args[0]);
472    Z3_inc_ref(mContext, notCond);
473    mRefs.push_back(notCond);
474
475    args[0] = notCond;
476    args[1] = getDefinition(G[f]);
477
478    Z3_ast falseExpr = Z3_mk_and(mContext, 2, args);
479    Z3_inc_ref(mContext, falseExpr);
480    mRefs.push_back(falseExpr);
481
482    const Vertex n = makeVertex(TypeId::Not, nullptr, G, notCond);
483
484    add_edge(nullptr, c, n, G);
485
486    const Vertex y = makeVertex(TypeId::And, nullptr, G, falseExpr);
487    add_edge(nullptr, n, y, G);
488    add_edge(nullptr, f, y, G);
489
490    const Vertex u = makeVertex(TypeId::Or, stmt, C, S, M, G);
491    add_edge(nullptr, x, u, G);
492    add_edge(nullptr, y, u, G);
493
494    return u;
495}
496
497/** ------------------------------------------------------------------------------------------------------------- *
498 * @brief summarizeAST
499 *
500 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
501 * are "flattened" (i.e., allowed to have any number of inputs.)
502 ** ------------------------------------------------------------------------------------------------------------- */
503void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
504
505    StatementMap S;
506
507    VertexMap M;
508
509    // Compute the base def-use graph ...
510    for (Statement * stmt : *mBlock) {
511        if (LLVM_UNLIKELY(isa<Sel>(stmt))) {
512
513            const Vertex u = transcribeSel(cast<Sel>(stmt), C, S, M, G);
514
515            resolveNestedUsages(stmt, u, C, S, M, G, stmt);
516
517        } else {
518
519
520            const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, C, S, M, G);
521            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
522                PabloAST * const op = stmt->getOperand(i);
523                if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
524                    add_edge(op, makeVertex(TypeId::Var, op, C, S, M, G), u, G);
525                }
526            }
527            if (LLVM_UNLIKELY(isa<If>(stmt))) {
528                for (Assign * def : cast<const If>(stmt)->getDefined()) {
529                    const Vertex v = makeVertex(TypeId::Var, def, C, S, M, G);
530                    add_edge(def, u, v, G);
531                    resolveNestedUsages(def, v, C, S, M, G, stmt);
532                }
533                continue;
534            } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
535                // To keep G a DAG, we need to do a bit of surgery on loop variants because
536                // the next variables it produces can be used within the condition. Instead,
537                // we make the loop dependent on the original value of each Next node and
538                // the Next node dependent on the loop.
539                for (Next * var : cast<const While>(stmt)->getVariants()) {
540                    const Vertex v = makeVertex(TypeId::Var, var, C, S, M, G);
541                    assert (in_degree(v, G) == 1);
542                    auto e = first(in_edges(v, G));
543                    add_edge(G[e], source(e, G), u, G);
544                    remove_edge(v, u, G);
545                    add_edge(var, u, v, G);
546                    resolveNestedUsages(var, v, C, S, M, G, stmt);
547                }
548                continue;
549            } else {
550                resolveNestedUsages(stmt, u, C, S, M, G, stmt);
551            }
552        }
553
554    }
555
556    if (redistributeGraph(C, M, G)) {
557        factorGraph(G);
558        rewriteAST(G);
559        mModified = true;
560    }
561
562}
563
564/** ------------------------------------------------------------------------------------------------------------- *
565 * @brief resolveNestedUsages
566 ** ------------------------------------------------------------------------------------------------------------- */
567void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
568                                                   CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G,
569                                                   const Statement * const ignoreIfThis) const {
570    assert ("Cannot resolve nested usages of a null expression!" && expr);
571    for (PabloAST * user : expr->users()) { assert (user);
572        if (LLVM_LIKELY(user != expr && isa<Statement>(user))) {
573            PabloBlock * parent = cast<Statement>(user)->getParent(); assert (parent);
574            if (LLVM_UNLIKELY(parent != mBlock)) {
575                for (;;) {
576                    if (parent->getPredecessor () == mBlock) {
577                        Statement * const branch = parent->getBranch();
578                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
579                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
580                            const Vertex v = makeVertex(TypeId::Var, user, C, S, M, G);
581                            add_edge(expr, u, v, G);
582                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
583                            add_edge(user, v, w, G);
584                        }
585                        break;
586                    }
587                    parent = parent->getPredecessor ();
588                    if (LLVM_UNLIKELY(parent == nullptr)) {
589                        assert (isa<Assign>(expr) || isa<Next>(expr));
590                        break;
591                    }
592                }
593            }
594        }
595    }
596}
597
598
599
600/** ------------------------------------------------------------------------------------------------------------- *
601 * @brief enumerateBicliques
602 *
603 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
604 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
605 * bipartition A and their adjacencies to be in B.
606  ** ------------------------------------------------------------------------------------------------------------- */
607template <class Graph>
608static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
609    using IntersectionSets = std::set<VertexSet>;
610
611    IntersectionSets B1;
612    for (auto u : A) {
613        if (in_degree(u, G) > 0) {
614            VertexSet adjacencies;
615            adjacencies.reserve(in_degree(u, G));
616            for (auto e : make_iterator_range(in_edges(u, G))) {
617                adjacencies.push_back(source(e, G));
618            }
619            std::sort(adjacencies.begin(), adjacencies.end());
620            assert(std::unique(adjacencies.begin(), adjacencies.end()) == adjacencies.end());
621            B1.insert(std::move(adjacencies));
622        }
623    }
624
625    IntersectionSets B(B1);
626
627    IntersectionSets Bi;
628
629    VertexSet clique;
630    for (auto i = B1.begin(); i != B1.end(); ++i) {
631        for (auto j = i; ++j != B1.end(); ) {
632            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
633            if (clique.size() > 0) {
634                if (B.count(clique) == 0) {
635                    Bi.insert(clique);
636                }
637                clique.clear();
638            }
639        }
640    }
641
642    for (;;) {
643        if (Bi.empty()) {
644            break;
645        }
646        B.insert(Bi.begin(), Bi.end());
647        IntersectionSets Bk;
648        for (auto i = B1.begin(); i != B1.end(); ++i) {
649            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
650                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
651                if (clique.size() > 0) {
652                    if (B.count(clique) == 0) {
653                        Bk.insert(clique);
654                    }
655                    clique.clear();
656                }
657            }
658        }
659        Bi.swap(Bk);
660    }
661
662    BicliqueSet cliques;
663    cliques.reserve(B.size());
664    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
665        VertexSet Ai(A);
666        for (const Vertex u : *Bi) {
667            VertexSet Aj;
668            Aj.reserve(out_degree(u, G));
669            for (auto e : make_iterator_range(out_edges(u, G))) {
670                Aj.push_back(target(e, G));
671            }
672            std::sort(Aj.begin(), Aj.end());
673            assert(std::unique(Aj.begin(), Aj.end()) == Aj.end());
674            VertexSet Ak;
675            Ak.reserve(std::min(Ai.size(), Aj.size()));
676            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
677            Ai.swap(Ak);
678        }
679        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
680        cliques.emplace_back(std::move(Ai), std::move(*Bi));
681    }
682    return std::move(cliques);
683}
684
685/** ------------------------------------------------------------------------------------------------------------- *
686 * @brief independentCliqueSets
687 ** ------------------------------------------------------------------------------------------------------------- */
688template <unsigned side>
689inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
690    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
691
692    const auto l = cliques.size();
693    IndependentSetGraph I(l);
694
695    // Initialize our weights
696    for (unsigned i = 0; i != l; ++i) {
697        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
698    }
699
700    // Determine our constraints
701    for (unsigned i = 0; i != l; ++i) {
702        for (unsigned j = i + 1; j != l; ++j) {
703            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
704                add_edge(i, j, I);
705            }
706        }
707    }
708
709    // Use the greedy algorithm to choose our independent set
710    VertexSet selected;
711    for (;;) {
712        unsigned w = 0;
713        Vertex u = 0;
714        for (unsigned i = 0; i != l; ++i) {
715            if (I[i] > w) {
716                w = I[i];
717                u = i;
718            }
719        }
720        if (w < minimum) break;
721        selected.push_back(u);
722        I[u] = 0;
723        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
724            I[v] = 0;
725        }
726    }
727
728    // Sort the selected list and then remove the unselected cliques
729    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
730    auto end = cliques.end();
731    for (const unsigned offset : selected) {
732        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
733    }
734    cliques.erase(cliques.begin(), end);
735
736    return std::move(cliques);
737}
738
739/** ------------------------------------------------------------------------------------------------------------- *
740 * @brief removeUnhelpfulBicliques
741 *
742 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
743 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
744 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
745 ** ------------------------------------------------------------------------------------------------------------- */
746static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G, DistributionGraph & H) {
747    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
748        const auto cardinalityA = std::get<0>(*ci).size();
749        VertexSet & B = std::get<1>(*ci);
750        for (auto bi = B.begin(); bi != B.end(); ) {
751            if (out_degree(H[*bi], G) == cardinalityA) {
752                ++bi;
753            } else {
754                bi = B.erase(bi);
755            }
756        }
757        if (B.size() > 1) {
758            ++ci;
759        } else {
760            ci = cliques.erase(ci);
761        }
762    }
763    return std::move(cliques);
764}
765
766/** ------------------------------------------------------------------------------------------------------------- *
767 * @brief safeDistributionSets
768 ** ------------------------------------------------------------------------------------------------------------- */
769static DistributionSets safeDistributionSets(const Graph & G, DistributionGraph & H) {
770
771    VertexSet sinks;
772    for (const Vertex u : make_iterator_range(vertices(H))) {
773        if (out_degree(u, H) == 0 && in_degree(u, H) != 0) {
774            sinks.push_back(u);
775        }
776    }
777    std::sort(sinks.begin(), sinks.end());
778
779    DistributionSets T;
780    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(H, sinks), G, H), 1);
781    for (Biclique & lower : lowerSet) {
782        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(H, std::get<1>(lower)), 2);
783        for (Biclique & upper : upperSet) {
784            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
785        }
786    }
787    return std::move(T);
788}
789
790/** ------------------------------------------------------------------------------------------------------------- *
791 * @brief getVertex
792 ** ------------------------------------------------------------------------------------------------------------- */
793template<typename ValueType, typename GraphType, typename MapType>
794static inline Vertex getVertex(const ValueType value, GraphType & G, MapType & M) {
795    const auto f = M.find(value);
796    if (f != M.end()) {
797        return f->second;
798    }
799    const auto u = add_vertex(value, G);
800    M.insert(std::make_pair(value, u));
801    return u;
802}
803
804/** ------------------------------------------------------------------------------------------------------------- *
805 * @brief generateDistributionGraph
806 ** ------------------------------------------------------------------------------------------------------------- */
807void generateDistributionGraph(const Graph & G, DistributionGraph & H) {
808    DistributionMap M;
809    for (const Vertex u : make_iterator_range(vertices(G))) {
810        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
811            continue;
812        } else if (isDistributive(G[u])) {
813            const TypeId outerTypeId = getType(G[u]);
814            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
815            flat_set<Vertex> distributable;
816            for (auto e : make_iterator_range(in_edges(u, G))) {
817                const Vertex v = source(e, G);
818                if (LLVM_UNLIKELY(getType(G[v]) == innerTypeId)) {
819                    bool safe = true;
820                    for (const auto e : make_iterator_range(out_edges(v, G))) {
821                        if (getType(G[target(e, G)]) != outerTypeId) {
822                            safe = false;
823                            break;
824                        }
825                    }
826                    if (safe) {
827                        distributable.insert(v);
828                    }
829                }
830            }
831            if (LLVM_LIKELY(distributable.size() > 1)) {
832                flat_set<Vertex> observed;
833                for (const Vertex v : distributable) {
834                    for (auto e : make_iterator_range(in_edges(v, G))) {
835                        const auto v = source(e, G);
836                        observed.insert(v);
837                    }
838                }
839                for (const Vertex w : observed) {
840                    for (auto e : make_iterator_range(out_edges(w, G))) {
841                        const Vertex v = target(e, G);
842                        if (distributable.count(v)) {
843                            const Vertex y = getVertex(v, H, M);
844                            boost::add_edge(y, getVertex(u, H, M), H);
845                            boost::add_edge(getVertex(w, H, M), y, H);
846                        }
847                    }
848                }
849            }
850        }
851    }
852}
853
854/** ------------------------------------------------------------------------------------------------------------- *
855 * @brief recomputeDefinition
856 ** ------------------------------------------------------------------------------------------------------------- */
857inline Z3_ast BooleanReassociationPass::computeDefinition(const TypeId typeId, const Vertex u, Graph & G, const bool use_expensive_minimization) const {
858    const unsigned n = in_degree(u, G);
859    Z3_ast operands[n];
860    unsigned k = 0;
861    for (const auto e : make_iterator_range(in_edges(u, G))) {
862        const auto v = source(e, G);
863        if (LLVM_UNLIKELY(getDefinition(G[v]) == nullptr)) {
864            throw std::runtime_error("No definition for " + std::to_string(v));
865        }
866        operands[k++] = getDefinition(G[v]);
867    }
868    assert (k == n);
869    Z3_ast const node = (typeId == TypeId::And) ? Z3_mk_and(mContext, n, operands) : Z3_mk_or(mContext, n, operands);
870    return simplify(node, use_expensive_minimization);
871}
872
873/** ------------------------------------------------------------------------------------------------------------- *
874 * @brief updateDefinition
875 *
876 * Apply the distribution law to reduce computations whenever possible.
877 ** ------------------------------------------------------------------------------------------------------------- */
878Vertex BooleanReassociationPass::updateIntermediaryDefinition(const TypeId typeId, const Vertex u, VertexMap & M, Graph & G) {
879
880    Z3_ast def = computeDefinition(typeId, u, G);
881    Z3_ast orig = getDefinition(G[u]); assert (orig);
882
883    Z3_dec_ref(mContext, orig);
884
885    const auto g = M.find(orig);
886    if (LLVM_LIKELY(g != M.end())) {
887        M.erase(g);
888    }
889
890    const auto f = std::find(mRefs.rbegin(), mRefs.rend(), orig);
891    assert (f != mRefs.rend());
892    *f = def;
893
894    const auto h = M.find(def);
895    if (LLVM_UNLIKELY(h != M.end())) {
896        const auto v = h->second;
897        if (v != u) {
898            for (auto e : make_iterator_range(out_edges(u, G))) {
899                add_edge(G[e], v, target(e, G), G);
900            }
901            removeVertex(u, G);
902            return v;
903        }
904    }
905
906    getDefinition(G[u]) = def;
907    M.emplace(def, u);
908    return u;
909}
910
911/** ------------------------------------------------------------------------------------------------------------- *
912 * @brief updateDefinition
913 *
914 * Apply the distribution law to reduce computations whenever possible.
915 ** ------------------------------------------------------------------------------------------------------------- */
916Vertex BooleanReassociationPass::updateSinkDefinition(const TypeId typeId, const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G) {
917
918    Z3_ast const def = computeDefinition(typeId, u, G);
919
920    auto f = M.find(def);
921
922    if (LLVM_UNLIKELY(f != M.end())) {
923        Z3_dec_ref(mContext, def);
924        Vertex v = f->second; assert (v != u);
925        for (auto e : make_iterator_range(out_edges(u, G))) {
926            add_edge(G[e], v, target(e, G), G);
927        }
928        removeVertex(u, G);
929        return v;
930    } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
931        PabloAST * const factor = C.predecessor()->findKey(def);
932        if (LLVM_UNLIKELY(factor != nullptr)) {
933            getValue(G[u]) = factor;
934            getType(G[u]) = TypeId::Var;
935            clear_in_edges(u, G);
936        }
937    }
938
939    getDefinition(G[u]) = def;
940    mRefs.push_back(def);
941
942    graph_traits<Graph>::in_edge_iterator begin, end;
943
944restart:
945
946    if (in_degree(u, G) > 1) {
947        std::tie(begin, end) = in_edges(u, G);
948        for (auto i = begin; ++i != end; ) {
949            const auto v = source(*i, G);
950            for (auto j = begin; j != i; ++j) {
951                const auto w = source(*j, G);
952                Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
953                Z3_ast test = nullptr;
954                switch (typeId) {
955                    case TypeId::And:
956                        test = Z3_mk_and(mContext, 2, operands); break;
957                    case TypeId::Or:
958                        test = Z3_mk_or(mContext, 2, operands); break;
959                    case TypeId::Xor:
960                        test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
961                    default:
962                        llvm_unreachable("impossible type id");
963                }
964                test = simplify(test, true);
965
966                bool replacement = false;
967                Vertex x = 0;
968                const auto f = M.find(test);
969                if (LLVM_UNLIKELY(f != M.end())) {
970                    x = f->second;
971                    Z3_ast orig = getDefinition(G[x]);
972                    if (LLVM_UNLIKELY(orig != test)) {
973                        std::string tmp;
974                        raw_string_ostream out(tmp);
975                        out << "vertex " << x << " is mapped to:\n"
976                            << Z3_ast_to_string(mContext, test)
977                            << "\n\nBut is recorded as:\n\n";
978                        if (orig) {
979                            out << Z3_ast_to_string(mContext, orig);
980                        } else {
981                            out << "<null>";
982                        }
983                        throw std::runtime_error(out.str());
984                    }
985                    Z3_dec_ref(mContext, test);
986                    replacement = true;
987                } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
988                    PabloAST * const factor = C.predecessor()->findKey(test);
989                    if (LLVM_UNLIKELY(factor != nullptr)) {
990                        x = makeVertex(TypeId::Var, factor, G, test);
991                        M.emplace(test, x);
992                        replacement = true;
993                        mRefs.push_back(test);
994                    }
995                }
996
997                if (LLVM_UNLIKELY(replacement)) {
998
999                    assert (G[*i] == nullptr);
1000                    assert (G[*j] == nullptr);
1001
1002                    remove_edge(*i, G);
1003                    remove_edge(*j, G);
1004
1005                    add_edge(nullptr, x, u, G);
1006
1007                    goto restart;
1008                }
1009
1010                Z3_dec_ref(mContext, test);
1011            }
1012        }
1013    }
1014
1015    M.emplace(def, u);
1016
1017    return u;
1018}
1019
1020/** ------------------------------------------------------------------------------------------------------------- *
1021 * @brief redistributeAST
1022 *
1023 * Apply the distribution law to reduce computations whenever possible.
1024 ** ------------------------------------------------------------------------------------------------------------- */
1025bool BooleanReassociationPass::redistributeGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
1026
1027    bool modified = false;
1028
1029//    errs() << "=====================================================\n";
1030
1031    DistributionGraph H;
1032
1033    for (;;) {
1034
1035        contractGraph(M, G);
1036
1037//        printGraph(G, "G");
1038
1039        generateDistributionGraph(G, H);
1040
1041        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
1042        if (num_vertices(H) == 0) {
1043            break;
1044        }
1045
1046        const DistributionSets distributionSets = safeDistributionSets(G, H);
1047
1048        if (LLVM_UNLIKELY(distributionSets.empty())) {
1049            break;
1050        }
1051
1052        modified = true;
1053
1054        mRefs.reserve(distributionSets.size() * 2);
1055
1056        for (const DistributionSet & set : distributionSets) {
1057
1058            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
1059            const VertexSet & sources = std::get<0>(set);
1060            const VertexSet & intermediary = std::get<1>(set);
1061            const VertexSet & sinks = std::get<2>(set);
1062
1063            const TypeId outerTypeId = getType(G[H[sinks.front()]]);
1064            assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
1065            const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
1066
1067            const Vertex x = makeVertex(outerTypeId, nullptr, G);
1068            const Vertex y = makeVertex(innerTypeId, nullptr, G);
1069
1070            // Update G to reflect the distributed operations (including removing the subgraph of
1071            // the to-be distributed edges.)
1072
1073            add_edge(nullptr, x, y, G);
1074
1075            for (const Vertex i : sources) {
1076                const auto u = H[i];
1077                for (const Vertex j : intermediary) {
1078                    const auto v = H[j];
1079                    assert (getType(G[v]) == innerTypeId);
1080                    const auto e = edge(u, v, G); assert (e.second);
1081                    remove_edge(e.first, G);
1082                }
1083                add_edge(nullptr, u, y, G);
1084            }
1085
1086            for (const Vertex i : intermediary) {
1087
1088                const auto u = updateIntermediaryDefinition(innerTypeId, H[i], M, G);
1089
1090                for (const Vertex j : sinks) {
1091                    const auto v = H[j];
1092                    assert (getType(G[v]) == outerTypeId);
1093                    const auto e = edge(u, v, G); assert (e.second);
1094                    add_edge(G[e.first], y, v, G);
1095                    remove_edge(e.first, G);
1096                }
1097                add_edge(nullptr, u, x, G);
1098            }
1099
1100            updateSinkDefinition(outerTypeId, x, C, M, G);
1101
1102            updateSinkDefinition(innerTypeId, y, C, M, G);
1103
1104        }
1105
1106        H.clear();
1107
1108    }
1109
1110    return modified;
1111}
1112
1113/** ------------------------------------------------------------------------------------------------------------- *
1114 * @brief isNonEscaping
1115 ** ------------------------------------------------------------------------------------------------------------- */
1116inline bool isNonEscaping(const VertexData & data) {
1117    // If these are redundant, the Simplifier pass will eliminate them. Trust that they're necessary.
1118    switch (getType(data)) {
1119        case TypeId::Assign:
1120        case TypeId::Next:
1121        case TypeId::If:
1122        case TypeId::While:
1123        case TypeId::Count:
1124            return false;
1125        default:
1126            return true;
1127    }
1128}
1129
1130/** ------------------------------------------------------------------------------------------------------------- *
1131 * @brief unique_source
1132 ** ------------------------------------------------------------------------------------------------------------- */
1133inline bool has_unique_source(const Vertex u, const Graph & G) {
1134    if (in_degree(u, G) > 0) {
1135        graph_traits<Graph>::in_edge_iterator i, end;
1136        std::tie(i, end) = in_edges(u, G);
1137        const Vertex v = source(*i, G);
1138        while (++i != end) {
1139            if (source(*i, G) != v) {
1140                return false;
1141            }
1142        }
1143        return true;
1144    }
1145    return false;
1146}
1147
1148/** ------------------------------------------------------------------------------------------------------------- *
1149 * @brief unique_target
1150 ** ------------------------------------------------------------------------------------------------------------- */
1151inline bool has_unique_target(const Vertex u, const Graph & G) {
1152    if (out_degree(u, G) > 0) {
1153        graph_traits<Graph>::out_edge_iterator i, end;
1154        std::tie(i, end) = out_edges(u, G);
1155        const Vertex v = target(*i, G);
1156        while (++i != end) {
1157            if (target(*i, G) != v) {
1158                return false;
1159            }
1160        }
1161        return true;
1162    }
1163    return false;
1164}
1165
1166/** ------------------------------------------------------------------------------------------------------------- *
1167 * @brief contractGraph
1168 ** ------------------------------------------------------------------------------------------------------------- */
1169bool BooleanReassociationPass::contractGraph(VertexMap & M, Graph & G) const {
1170
1171    bool contracted = false;
1172
1173    circular_buffer<Vertex> ordering(num_vertices(G));
1174
1175    topological_sort(G, std::back_inserter(ordering)); // reverse topological ordering
1176
1177    // first contract the graph
1178    for (const Vertex u : ordering) {
1179        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
1180            continue;
1181        } else if (LLVM_LIKELY(out_degree(u, G) > 0)) {
1182            if (isAssociative(G[u])) {
1183                if (LLVM_UNLIKELY(has_unique_source(u, G))) {
1184                    // We have a redundant node here that'll simply end up being a duplicate
1185                    // of the input value. Remove it and add any of its outgoing edges to its
1186                    // input node.
1187                    const auto ei = first(in_edges(u, G));
1188                    const Vertex v = source(ei, G);
1189                    for (auto ej : make_iterator_range(out_edges(u, G))) {
1190                        add_edge(G[ej], v, target(ej, G), G);
1191                    }
1192                    removeVertex(u, M, G);
1193                    contracted = true;
1194                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
1195                    // Otherwise if we have a single user, we have a similar case as above but
1196                    // we can only merge this vertex into the outgoing instruction if they are
1197                    // of the same type.
1198                    const auto ei = first(out_edges(u, G));
1199                    const Vertex v = target(ei, G);
1200                    if (LLVM_UNLIKELY(getType(G[v]) == getType(G[u]))) {
1201                        for (auto ej : make_iterator_range(in_edges(u, G))) {
1202                            add_edge(G[ej], source(ej, G), v, G);
1203                        }
1204                        removeVertex(u, M, G);
1205                        contracted = true;
1206                    }
1207                }
1208            }
1209        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
1210            removeVertex(u, M, G);
1211            contracted = true;
1212        }
1213    }
1214    return contracted;
1215}
1216
1217/** ------------------------------------------------------------------------------------------------------------- *
1218 * @brief factorGraph
1219 ** ------------------------------------------------------------------------------------------------------------- */
1220bool BooleanReassociationPass::factorGraph(const TypeId typeId, Graph & G, std::vector<Vertex> & factors) const {
1221
1222    if (LLVM_UNLIKELY(factors.empty())) {
1223        return false;
1224    }
1225
1226    std::vector<Vertex> I, J, K;
1227
1228    bool modified = false;
1229
1230    for (unsigned i = 1; i < factors.size(); ++i) {
1231        assert (getType(G[factors[i]]) == typeId);
1232        for (auto ei : make_iterator_range(in_edges(factors[i], G))) {
1233            I.push_back(source(ei, G));
1234        }
1235        std::sort(I.begin(), I.end());
1236        for (unsigned j = 0; j < i; ++j) {
1237            for (auto ej : make_iterator_range(in_edges(factors[j], G))) {
1238                J.push_back(source(ej, G));
1239            }
1240            std::sort(J.begin(), J.end());
1241            // get the pairwise intersection of each set of inputs (i.e., their common subexpression)
1242            std::set_intersection(I.begin(), I.end(), J.begin(), J.end(), std::back_inserter(K));
1243            assert (std::is_sorted(K.begin(), K.end()));
1244            // if the intersection contains at least two elements
1245            const auto n = K.size();
1246            if (n > 1) {
1247                Vertex a = factors[i];
1248                Vertex b = factors[j];
1249                if (LLVM_UNLIKELY(in_degree(a, G) == n || in_degree(b, G) == n)) {
1250                    if (in_degree(a, G) != n) {
1251                        assert (in_degree(b, G) == n);
1252                        std::swap(a, b);
1253                    }
1254                    assert (in_degree(a, G) == n);
1255                    if (in_degree(b, G) == n) {
1256                        for (auto e : make_iterator_range(out_edges(b, G))) {
1257                            add_edge(G[e], a, target(e, G), G);
1258                        }
1259                        removeVertex(b, G);
1260                    } else {
1261                        for (auto u : K) {
1262                            remove_edge(u, b, G);
1263                        }
1264                        add_edge(nullptr, a, b, G);
1265                    }
1266                } else {
1267                    Vertex v = makeVertex(typeId, nullptr, G);
1268                    for (auto u : K) {
1269                        remove_edge(u, a, G);
1270                        remove_edge(u, b, G);
1271                        add_edge(nullptr, u, v, G);
1272                    }
1273                    add_edge(nullptr, v, a, G);
1274                    add_edge(nullptr, v, b, G);
1275                    factors.push_back(v);
1276                }
1277                modified = true;
1278            }
1279            K.clear();
1280            J.clear();
1281        }
1282        I.clear();
1283    }
1284    return modified;
1285}
1286
1287/** ------------------------------------------------------------------------------------------------------------- *
1288 * @brief factorGraph
1289 ** ------------------------------------------------------------------------------------------------------------- */
1290bool BooleanReassociationPass::factorGraph(Graph & G) const {
1291    // factor the associative vertices.
1292    std::vector<Vertex> factors;
1293    bool factored = false;
1294    for (unsigned i = 0; i < 3; ++i) {
1295        TypeId typeId[3] = { TypeId::And, TypeId::Or, TypeId::Xor};
1296        for (auto j : make_iterator_range(vertices(G))) {
1297            if (getType(G[j]) == typeId[i]) {
1298                factors.push_back(j);
1299            }
1300        }
1301        if (factorGraph(typeId[i], G, factors)) {
1302            factored = true;
1303        }
1304        factors.clear();
1305    }
1306    return factored;
1307}
1308
1309/** ------------------------------------------------------------------------------------------------------------- *
1310 * @brief isMutable
1311 ** ------------------------------------------------------------------------------------------------------------- */
1312inline bool isMutable(const Vertex u, const Graph & G) {
1313    return getType(G[u]) != TypeId::Var;
1314}
1315
1316/** ------------------------------------------------------------------------------------------------------------- *
1317 * @brief rewriteAST
1318 ** ------------------------------------------------------------------------------------------------------------- */
1319bool BooleanReassociationPass::rewriteAST(Graph & G) {
1320
1321    using line_t = long long int;
1322
1323    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
1324
1325    // errs() << "---------------------------------------------------------\n";
1326
1327    // printGraph(G, "X");
1328
1329    Z3_config cfg = Z3_mk_config();
1330    Z3_set_param_value(cfg, "model", "true");
1331    Z3_context ctx = Z3_mk_context(cfg);
1332    Z3_del_config(cfg);
1333    Z3_solver solver = Z3_mk_solver(ctx);
1334    Z3_solver_inc_ref(ctx, solver);
1335
1336    std::vector<Z3_ast> mapping(num_vertices(G), nullptr);
1337
1338    flat_map<PabloAST *, Z3_ast> V;
1339
1340    // Generate the variables
1341    const auto ty = Z3_mk_int_sort(ctx);
1342    Z3_ast ZERO = Z3_mk_int(ctx, 0, ty);
1343
1344    for (const Vertex u : make_iterator_range(vertices(G))) {
1345        const auto var = Z3_mk_fresh_const(ctx, nullptr, ty);
1346        Z3_ast constraint = nullptr;
1347        if (in_degree(u, G) > 0) {
1348            constraint = Z3_mk_gt(ctx, var, ZERO);
1349            Z3_solver_assert(ctx, solver, constraint);
1350        }       
1351        PabloAST * const expr = getValue(G[u]);
1352        if (inCurrentBlock(expr, mBlock)) {
1353            V.emplace(expr, var);
1354        }
1355        mapping[u] = var;
1356    }
1357
1358    // Add in the dependency constraints
1359    for (const Vertex u : make_iterator_range(vertices(G))) {
1360        Z3_ast const t = mapping[u];
1361        for (auto e : make_iterator_range(in_edges(u, G))) {
1362            Z3_ast const s = mapping[source(e, G)];
1363            Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, s, t));
1364        }
1365    }
1366
1367    // Compute the soft ordering constraints
1368    std::vector<Z3_ast> ordering(0);
1369    ordering.reserve(V.size() - 1);
1370
1371    Z3_ast prior = nullptr;
1372    unsigned gap = 1;
1373    for (Statement * stmt : *mBlock) {
1374        auto f = V.find(stmt);
1375        if (f != V.end()) {
1376            Z3_ast const node = f->second;
1377            if (prior) {
1378//                ordering.push_back(Z3_mk_lt(ctx, prior, node)); // increases the cost by 6 - 10x
1379                Z3_ast ops[2] = { node, prior };
1380                ordering.push_back(Z3_mk_le(ctx, Z3_mk_sub(ctx, 2, ops), Z3_mk_int(ctx, gap, ty)));
1381            } else {
1382                ordering.push_back(Z3_mk_eq(ctx, node, Z3_mk_int(ctx, gap, ty)));
1383            }
1384            prior = node;
1385            gap = 0;
1386        }
1387        ++gap;
1388    }
1389
1390    if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) < 0)) {
1391        throw std::runtime_error("Unable to construct a topological ordering during reassociation!");
1392    }
1393
1394    Z3_model model = Z3_solver_get_model(ctx, solver);
1395    Z3_model_inc_ref(ctx, model);
1396
1397    std::vector<Vertex> S(0);
1398    S.reserve(num_vertices(G));
1399
1400    std::vector<line_t> L(num_vertices(G));
1401
1402    for (const Vertex u : make_iterator_range(vertices(G))) {
1403        line_t line = LoadEarly ? 0 : MAX_INT;
1404        if (isMutable(u, G)) {
1405            Z3_ast value;
1406            if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, mapping[u], Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1407                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1408            }
1409            if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1410                throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1411            }
1412            S.push_back(u);
1413        }
1414        L[u] = line;
1415    }
1416
1417    Z3_model_dec_ref(ctx, model);
1418    Z3_solver_dec_ref(ctx, solver);
1419    Z3_del_context(ctx);
1420
1421    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1422
1423    mBlock->setInsertPoint(nullptr);
1424
1425    std::vector<Vertex> T;
1426
1427    line_t count = 1;
1428
1429    for (auto u : S) {
1430        PabloAST *& stmt = getValue(G[u]);
1431
1432        assert (isMutable(u, G));
1433        assert (L[u] > 0 && L[u] < MAX_INT);
1434
1435        bool append = true;
1436
1437        if (isAssociative(G[u])) {
1438
1439            if (in_degree(u, G) == 0 || out_degree(u, G) == 0) {
1440                throw std::runtime_error("Vertex " + std::to_string(u) + " is either a source or sink node but marked as associative!");
1441            }
1442
1443            Statement * ip = mBlock->getInsertPoint();
1444            ip = ip ? ip->getNextNode() : mBlock->front();
1445
1446            const auto typeId = getType(G[u]);
1447
1448            T.clear();
1449            T.reserve(in_degree(u, G));
1450            for (const auto e : make_iterator_range(in_edges(u, G))) {
1451                T.push_back(source(e, G));
1452            }
1453
1454            // Then sort them by their line position (noting any incoming value will either be 0 or MAX_INT)
1455            std::sort(T.begin(), T.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
1456
1457            if (LoadEarly) {
1458                mBlock->setInsertPoint(nullptr);
1459            }
1460
1461            PabloAST * expr = nullptr;
1462            PabloAST * join = nullptr;
1463
1464            for (auto v : T) {
1465                expr = getValue(G[v]);
1466                if (LLVM_UNLIKELY(expr == nullptr)) {
1467                    throw std::runtime_error("Vertex " + std::to_string(v) + " does not have an expression!");
1468                }
1469                if (join) {
1470
1471                    if (in_degree(v, G) > 0) {
1472
1473                        assert (L[v] > 0 && L[v] < MAX_INT);
1474
1475                        Statement * dom = cast<Statement>(expr);
1476                        for (;;) {
1477                            PabloBlock * const parent = dom->getParent();
1478                            if (parent == mBlock) {
1479                                break;
1480                            }
1481                            dom = parent->getBranch(); assert(dom);
1482                        }
1483                        mBlock->setInsertPoint(dom);
1484
1485                        assert (dominates(join, expr));
1486                        assert (dominates(expr, ip));
1487                        assert (dominates(dom, ip));
1488                    }
1489
1490                    switch (typeId) {
1491                        case TypeId::And:
1492                            expr = mBlock->createAnd(join, expr); break;
1493                        case TypeId::Or:
1494                            expr = mBlock->createOr(join, expr); break;
1495                        case TypeId::Xor:
1496                            expr = mBlock->createXor(join, expr); break;
1497                        default:
1498                            llvm_unreachable("Invalid TypeId!");
1499                    }
1500                }
1501                join = expr;
1502            }
1503
1504            assert (expr);
1505
1506            Statement * const currIP = mBlock->getInsertPoint();
1507
1508            mBlock->setInsertPoint(ip->getPrevNode());
1509
1510            stmt = expr;
1511
1512            // If the insertion point isn't the statement we just attempted to create,
1513            // we must have unexpectidly reused a prior statement (or var.)
1514            if (LLVM_UNLIKELY(expr != currIP)) {
1515                L[u] = 0;
1516                // If that statement happened to be in the current block, locate which
1517                // statement it "duplicated" (if any) and set the duplicate's line # to
1518                // the line of the original statement to maintain a consistent view of
1519                // the program.
1520                if (inCurrentBlock(expr, mBlock)) {
1521                    for (auto v : make_iterator_range(vertices(G))) {
1522                        if (LLVM_UNLIKELY(getValue(G[v]) == expr)) {
1523                            L[u] = L[v];
1524                            break;
1525                        }
1526                    }
1527                }
1528                append = false;
1529            }
1530        } else if (stmt == nullptr) {
1531            assert (getType(G[u]) == TypeId::Not);
1532            assert (in_degree(u, G) == 1);
1533            PabloAST * op = getValue(G[source(first(in_edges(u, G)), G)]); assert (op);
1534            stmt = mBlock->createNot(op);
1535        } else if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
1536            for (auto e : make_iterator_range(out_edges(u, G))) {
1537                const auto v = target(e, G);
1538                assert (L[v] == std::numeric_limits<line_t>::max());
1539                L[v] = count;
1540            }
1541        }
1542
1543        for (auto e : make_iterator_range(out_edges(u, G))) {
1544            if (G[e]) {
1545                if (PabloAST * user = getValue(G[target(e, G)])) {
1546                    cast<Statement>(user)->replaceUsesOfWith(G[e], stmt);
1547                }
1548            }
1549        }
1550
1551        if (LLVM_LIKELY(append)) {
1552            mBlock->insert(cast<Statement>(stmt));
1553            L[u] = count++; // update the line count with the actual one.
1554        }
1555    }
1556
1557    Statement * const end = mBlock->getInsertPoint(); assert (end);
1558    for (;;) {
1559        Statement * const next = end->getNextNode();
1560        if (next == nullptr) {
1561            break;
1562        }
1563
1564        #ifndef NDEBUG
1565        for (PabloAST * user : next->users()) {
1566            if (isa<Statement>(user) && dominates(user, next)) {
1567                std::string tmp;
1568                raw_string_ostream out(tmp);
1569                out << "Erasing ";
1570                PabloPrinter::print(next, out);
1571                out << " erroneously modifies live statement ";
1572                PabloPrinter::print(cast<Statement>(user), out);
1573                throw std::runtime_error(out.str());
1574            }
1575        }
1576        #endif
1577        next->eraseFromParent(true);
1578    }
1579
1580    #ifndef NDEBUG
1581    PabloVerifier::verify(mFunction, "mid-reassociation");
1582    #endif
1583
1584    return true;
1585}
1586
1587/** ------------------------------------------------------------------------------------------------------------- *
1588* @brief makeVertex
1589** ------------------------------------------------------------------------------------------------------------- */
1590Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
1591    assert (expr);
1592    const auto f = S.find(expr);
1593    if (f != S.end()) {
1594        assert (getValue(G[f->second]) == expr);
1595        return f->second;
1596    }
1597    const auto node = C.get(expr);   
1598    const Vertex u = makeVertex(typeId, expr, G, node);
1599    S.emplace(expr, u);
1600    if (node) {
1601        M.emplace(node, u);
1602    }
1603    return u;
1604}
1605
1606/** ------------------------------------------------------------------------------------------------------------- *
1607 * @brief makeVertex
1608 ** ------------------------------------------------------------------------------------------------------------- */
1609Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
1610    assert (expr);
1611    const auto f = M.find(expr);
1612    if (f != M.end()) {
1613        assert (getValue(G[f->second]) == expr);
1614        return f->second;
1615    }
1616    const Vertex u = makeVertex(typeId, expr, G, node);
1617    M.emplace(expr, u);
1618    return u;
1619}
1620
1621/** ------------------------------------------------------------------------------------------------------------- *
1622 * @brief makeVertex
1623 ** ------------------------------------------------------------------------------------------------------------- */
1624Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
1625//    for (Vertex u : make_iterator_range(vertices(G))) {
1626//        if (LLVM_UNLIKELY(in_degree(u, G) == 0 && out_degree(u, G) == 0)) {
1627//            std::get<0>(G[u]) = typeId;
1628//            std::get<1>(G[u]) = expr;
1629//            return u;
1630//        }
1631//    }
1632    return add_vertex(std::make_tuple(typeId, expr, node), G);
1633}
1634
1635/** ------------------------------------------------------------------------------------------------------------- *
1636 * @brief removeVertex
1637 ** ------------------------------------------------------------------------------------------------------------- */
1638inline void BooleanReassociationPass::removeVertex(const Vertex u, VertexMap & M, Graph & G) const {
1639    VertexData & ref = G[u];
1640    Z3_ast def = getDefinition(ref); assert (def);
1641    auto f = M.find(def); assert (f != M.end());
1642    M.erase(f);
1643    removeVertex(u, G);
1644}
1645
1646/** ------------------------------------------------------------------------------------------------------------- *
1647 * @brief removeVertex
1648 ** ------------------------------------------------------------------------------------------------------------- */
1649inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
1650    VertexData & ref = G[u];
1651    clear_vertex(u, G);
1652    std::get<0>(ref) = TypeId::Var;
1653    std::get<1>(ref) = nullptr;
1654    std::get<2>(ref) = nullptr;
1655}
1656
1657/** ------------------------------------------------------------------------------------------------------------- *
1658 * @brief make
1659 ** ------------------------------------------------------------------------------------------------------------- */
1660inline Z3_ast BooleanReassociationPass::makeVar() {
1661    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1662    Z3_inc_ref(mContext, node);
1663    mRefs.push_back(node);
1664    return node;
1665}
1666
1667/** ------------------------------------------------------------------------------------------------------------- *
1668 * @brief simplify
1669 ** ------------------------------------------------------------------------------------------------------------- */
1670inline Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
1671    assert (node);
1672    Z3_inc_ref(mContext, node);
1673    Z3_ast result = nullptr;
1674    if (use_expensive_minimization) {
1675
1676        Z3_goal g = Z3_mk_goal(mContext, true, false, false); assert (g);
1677        Z3_goal_inc_ref(mContext, g);
1678        Z3_goal_assert(mContext, g, node);
1679
1680        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g); assert (r);
1681        Z3_apply_result_inc_ref(mContext, r);
1682        Z3_goal_dec_ref(mContext, g);
1683
1684        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
1685
1686        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0); assert (h);
1687        Z3_goal_inc_ref(mContext, h);
1688        Z3_apply_result_dec_ref(mContext, r);
1689
1690        const unsigned n = Z3_goal_size(mContext, h);
1691
1692        if (n == 1) {
1693            result = Z3_goal_formula(mContext, h, 0); assert (result);
1694            Z3_inc_ref(mContext, result);
1695        } else if (n > 1) {
1696            Z3_ast operands[n];
1697            for (unsigned i = 0; i < n; ++i) {
1698                operands[i] = Z3_goal_formula(mContext, h, i);
1699                Z3_inc_ref(mContext, operands[i]);
1700            }
1701            result = Z3_mk_and(mContext, n, operands); assert (result);
1702            Z3_inc_ref(mContext, result);
1703            for (unsigned i = 0; i < n; ++i) {
1704                Z3_dec_ref(mContext, operands[i]);
1705            }
1706        } else {
1707            result = Z3_mk_true(mContext); assert (result);
1708        }
1709        Z3_goal_dec_ref(mContext, h);
1710
1711    } else {       
1712        result = Z3_simplify_ex(mContext, node, mParams); assert (result);
1713        Z3_inc_ref(mContext, result);
1714    }   
1715    Z3_dec_ref(mContext, node);
1716    assert (result);
1717    return result;
1718}
1719
1720/** ------------------------------------------------------------------------------------------------------------- *
1721 * @brief add
1722 ** ------------------------------------------------------------------------------------------------------------- */
1723inline Z3_ast BooleanReassociationPass::CharacterizationMap::add(PabloAST * const expr, Z3_ast node, const bool forwardOnly) {
1724    assert (expr && node);
1725    mForward.emplace(expr, node);
1726    if (!forwardOnly) {
1727        mBackward.emplace(node, expr);
1728    }
1729    return node;
1730}
1731
1732/** ------------------------------------------------------------------------------------------------------------- *
1733 * @brief get
1734 ** ------------------------------------------------------------------------------------------------------------- */
1735inline Z3_ast BooleanReassociationPass::CharacterizationMap::get(PabloAST * const expr) const {
1736    assert (expr);
1737    auto f = mForward.find(expr);
1738    if (LLVM_UNLIKELY(f == mForward.end())) {
1739        if (mPredecessor == nullptr) {
1740            return nullptr;
1741        }
1742        return mPredecessor->get(expr);
1743    }
1744    return f->second;
1745}
1746
1747/** ------------------------------------------------------------------------------------------------------------- *
1748 * @brief get
1749 ** ------------------------------------------------------------------------------------------------------------- */
1750inline PabloAST * BooleanReassociationPass::CharacterizationMap::findKey(Z3_ast const node) const {
1751    assert (node);
1752    auto f = mBackward.find(node);
1753    if (LLVM_UNLIKELY(f == mBackward.end())) {
1754        if (mPredecessor == nullptr) {
1755            return nullptr;
1756        }
1757        return mPredecessor->findKey(node);
1758    }
1759    return f->second;
1760}
1761
1762/** ------------------------------------------------------------------------------------------------------------- *
1763 * @brief constructor
1764 ** ------------------------------------------------------------------------------------------------------------- */
1765inline BooleanReassociationPass::BooleanReassociationPass(Z3_context ctx, Z3_params params, Z3_tactic tactic, PabloFunction & f)
1766: mContext(ctx)
1767, mParams(params)
1768, mTactic(tactic)
1769, mFunction(f)
1770, mModified(false) {
1771
1772}
1773
1774}
Note: See TracBrowser for help on using the repository browser.