source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp @ 5119

Last change on this file since 5119 was 5119, checked in by nmedfort, 3 years ago

Work on multiplexing using a fixed window.

File size: 62.8 KB
Line 
1#include "pablo_automultiplexing.hpp"
2#include <pablo/builder.hpp>
3#include <pablo/function.h>
4#include <pablo/printer_pablos.h>
5#include <boost/container/flat_set.hpp>
6#include <boost/numeric/ublas/matrix.hpp>
7#include <boost/circular_buffer.hpp>
8#include <boost/graph/topological_sort.hpp>
9#include <boost/range/iterator_range.hpp>
10#include <pablo/analysis/pabloverifier.hpp>
11#include <pablo/optimizers/pablo_simplifier.hpp>
12#include <pablo/builder.hpp>
13#include <stack>
14#include <queue>
15#include <unordered_set>
16#include <functional>
17#include <llvm/Support/CommandLine.h>
18
19using namespace llvm;
20using namespace boost;
21using namespace boost::container;
22using namespace boost::numeric::ublas;
23
24static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
25
26#ifdef NDEBUG
27#define INITIAL_SEED_VALUE (std::random_device()())
28#else
29#define INITIAL_SEED_VALUE (83234827342)
30#endif
31
32static cl::opt<std::mt19937::result_type> Seed("multiplexing-seed", cl::init(INITIAL_SEED_VALUE),
33                                        cl::desc("randomization seed used when performing any non-deterministic operations."),
34                                        cl::cat(MultiplexingOptions));
35
36#undef INITIAL_SEED_VALUE
37
38static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(100),
39                                        cl::desc("maximum sequence distance to consider for candidate set."),
40                                        cl::cat(MultiplexingOptions));
41
42
43namespace pablo {
44
45Z3_bool maxsat(Z3_context ctx, Z3_solver solver, std::vector<Z3_ast> & soft);
46
47using TypeId = PabloAST::ClassTypeId;
48
49/** ------------------------------------------------------------------------------------------------------------- *
50 * @brief optimize
51 * @param function the function to optimize
52 ** ------------------------------------------------------------------------------------------------------------- */
53bool MultiplexingPass::optimize(PabloFunction & function) {
54
55    PabloVerifier::verify(function, "pre-multiplexing");
56
57    errs() << "PRE-MULTIPLEXING\n==============================================\n";
58    PabloPrinter::print(function, errs());
59
60    Z3_config cfg = Z3_mk_config();
61    Z3_context ctx = Z3_mk_context_rc(cfg);
62    Z3_del_config(cfg);
63    Z3_solver solver = Z3_mk_solver(ctx);
64    Z3_solver_inc_ref(ctx, solver);
65
66    MultiplexingPass mp(function, Seed, ctx, solver);
67
68    mp.optimize();
69
70    Z3_solver_dec_ref(ctx, solver);
71    Z3_del_context(ctx);
72
73    PabloVerifier::verify(function, "post-multiplexing");
74
75    Simplifier::optimize(function);
76
77    errs() << "POST-MULTIPLEXING\n==============================================\n";
78    PabloPrinter::print(function, errs());
79
80    return true;
81}
82
83/** ------------------------------------------------------------------------------------------------------------- *
84 * @brief characterize
85 * @param function the function to optimize
86 ** ------------------------------------------------------------------------------------------------------------- */
87void MultiplexingPass::optimize() {
88    // Map the constants and input variables
89
90    add(PabloBlock::createZeroes(), Z3_mk_false(mContext));
91    add(PabloBlock::createOnes(), Z3_mk_true(mContext));
92    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
93        make(mFunction.getParameter(i));
94    }
95
96    optimize(mFunction.getEntryBlock());
97}
98
99/** ------------------------------------------------------------------------------------------------------------- *
100 * @brief characterize
101 ** ------------------------------------------------------------------------------------------------------------- */
102void MultiplexingPass::optimize(PabloBlock * const block) {
103    Statement * begin = block->front();
104    Statement * end = initialize(begin);
105    for (Statement * stmt : *block) {
106        if (LLVM_UNLIKELY(stmt == end)) {
107            Statement * const next = stmt->getNextNode();
108            multiplex(block, begin, stmt);
109            if (isa<If>(stmt)) {
110                optimize(cast<If>(stmt)->getBody());
111            } else if (isa<While>(stmt)) {
112                for (const Next * var : cast<While>(stmt)->getVariants()) {
113                    Z3_inc_ref(mContext, get(var->getInitial()));
114                }
115                optimize(cast<While>(stmt)->getBody());
116                // since we cannot be certain that we'll always execute at least one iteration of a loop, we must
117                // assume that the variants could either be their initial or resulting value.
118                for (const Next * var : cast<While>(stmt)->getVariants()) {
119                    Z3_ast v0 = get(var->getInitial());
120                    Z3_ast & v1 = get(var);
121                    Z3_ast merge[2] = { v0, v1 };
122                    Z3_ast r = Z3_mk_or(mContext, 2, merge);
123                    Z3_inc_ref(mContext, r);
124                    Z3_dec_ref(mContext, v0);
125                    Z3_dec_ref(mContext, v1);
126                    v1 = r;
127                    assert (get(var) == r);
128                }
129            }
130            end = initialize(begin = next);
131        } else {
132            characterize(stmt);
133        }
134    }
135    multiplex(block, begin, nullptr);
136}
137
138/** ------------------------------------------------------------------------------------------------------------- *
139 * @brief multiplex
140 ** ------------------------------------------------------------------------------------------------------------- */
141void MultiplexingPass::multiplex(PabloBlock * const block, Statement * const begin, Statement * const end) {
142    if (generateCandidateSets(begin, end)) {
143        selectMultiplexSetsGreedy();
144        eliminateSubsetConstraints();
145        multiplexSelectedSets(block, begin, end);
146    }
147}
148
149/** ------------------------------------------------------------------------------------------------------------- *
150 * @brief equals
151 ** ------------------------------------------------------------------------------------------------------------- */
152inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
153    Z3_solver_push(mContext, mSolver);
154    Z3_ast test = Z3_mk_eq(mContext, a, b); // try using check assumption instead?
155    Z3_inc_ref(mContext, test);
156    Z3_solver_assert(mContext, mSolver, test);
157    const auto r = Z3_solver_check(mContext, mSolver);
158    Z3_dec_ref(mContext, test);
159    Z3_solver_pop(mContext, mSolver, 1);
160    return (r == Z3_L_TRUE);
161}
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief handle_unexpected_statement
165 ** ------------------------------------------------------------------------------------------------------------- */
166static void handle_unexpected_statement(Statement * const stmt) {
167    std::string tmp;
168    raw_string_ostream err(tmp);
169    err << "Unexpected statement type: ";
170    PabloPrinter::print(stmt, err);
171    throw std::runtime_error(err.str());
172}
173
174/** ------------------------------------------------------------------------------------------------------------- *
175 * @brief characterize
176 ** ------------------------------------------------------------------------------------------------------------- */
177inline Z3_ast MultiplexingPass::characterize(Statement * const stmt) {
178
179    const size_t n = stmt->getNumOperands(); assert (n > 0);
180    Z3_ast operands[n] = {};
181    for (size_t i = 0; i < n; ++i) {
182        PabloAST * op = stmt->getOperand(i);
183        if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
184            operands[i] = get(op, true);
185        }
186    }
187
188    Z3_ast node = operands[0];
189    switch (stmt->getClassTypeId()) {
190        case TypeId::Assign:
191        case TypeId::Next:
192        case TypeId::AtEOF:
193        case TypeId::InFile:
194            node = operands[0]; break;
195        case TypeId::And:
196            node = Z3_mk_and(mContext, n, operands); break;
197        case TypeId::Or:
198            node = Z3_mk_or(mContext, n, operands); break;
199        case TypeId::Xor:
200            node = Z3_mk_xor(mContext, operands[0], operands[1]);
201            Z3_inc_ref(mContext, node);
202            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
203                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
204                Z3_inc_ref(mContext, temp);
205                Z3_dec_ref(mContext, node);
206                node = temp;
207            }
208            return add(stmt, node);
209        case TypeId::Not:
210            node = Z3_mk_not(mContext, node);
211            break;
212        case TypeId::Sel:
213            node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
214            break;
215        case TypeId::Advance:
216            return characterize(cast<Advance>(stmt), operands[0]);
217        case TypeId::ScanThru:
218            // ScanThru(c, m) := (c + m) ∧ ¬m. Thus we can conservatively represent this statement using the BDD
219            // for ¬m --- provided no derivative of this statement is negated in any fashion.
220        case TypeId::MatchStar:
221        case TypeId::Count:
222            return make(stmt);
223        default:
224            handle_unexpected_statement(stmt);
225    }
226    Z3_inc_ref(mContext, node);
227    return add(stmt, node);
228}
229
230
231/** ------------------------------------------------------------------------------------------------------------- *
232 * @brief characterize
233 ** ------------------------------------------------------------------------------------------------------------- */
234inline Z3_ast MultiplexingPass::characterize(Advance * const adv, Z3_ast Ik) {
235    const auto k = mNegatedAdvance.size();
236
237    assert (adv);
238    assert (mConstraintGraph[k] == adv);
239
240    bool unconstrained[k] = {};
241
242    Z3_solver_push(mContext, mSolver);
243
244    for (size_t i = 0; i < k; ++i) {
245
246        // Have we already proven that they are unconstrained by their subset relationship?
247        if (unconstrained[i]) continue;
248
249        // If these Advances are mutually exclusive, in the same scope, transitively independent, and shift their
250        // values by the same amount, we can safely multiplex them. Otherwise mark the constraint in the graph.
251        const Advance * const ithAdv = mConstraintGraph[i];
252        if (ithAdv->getOperand(1) == adv->getOperand(1)) {
253
254            Z3_ast Ii = get(ithAdv->getOperand(0));
255
256            // Is there any satisfying truth assignment? If not, these streams are mutually exclusive.
257
258            Z3_solver_push(mContext, mSolver);
259            Z3_ast conj[2] = { Ii, Ik };
260            Z3_ast IiIk = Z3_mk_and(mContext, 2, conj);
261            Z3_inc_ref(mContext, IiIk);
262            Z3_solver_assert(mContext, mSolver, IiIk);
263            if (Z3_solver_check(mContext, mSolver) == Z3_L_FALSE) {
264                // If Ai ∩ Ak = ∅ and Aj ⊂ Ai, Aj ∩ Ak = ∅.
265                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
266                    unconstrained[source(e, mSubsetGraph)] = true;
267                }
268                unconstrained[i] = true;
269            } else if (equals(Ii, IiIk)) {
270                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
271                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
272                // the same multiplexing set.
273                add_edge(i, k, mSubsetGraph);
274                // If Ai ⊂ Ak and Aj ⊂ Ai, Aj ⊂ Ak.
275                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
276                    const auto j = source(e, mSubsetGraph);
277                    add_edge(j, k, mSubsetGraph);
278                    unconstrained[j] = true;
279                }
280                unconstrained[i] = true;
281            } else if (equals(Ik, IiIk)) {
282                // If Ik = Ii ∩ Ik then Ik ⊆ Ii. Record this in the subset graph with the arc (k, i).
283                add_edge(k, i, mSubsetGraph);
284                // If Ak ⊂ Ai and Ai ⊂ Aj, Ak ⊂ Aj.
285                for (auto e : make_iterator_range(out_edges(i, mSubsetGraph))) {
286                    const auto j = target(e, mSubsetGraph);
287                    add_edge(k, j, mSubsetGraph);
288                    unconstrained[j] = true;
289                }
290                unconstrained[i] = true;
291            }
292            Z3_dec_ref(mContext, IiIk);
293            Z3_solver_pop(mContext, mSolver, 1);
294        }
295    }
296
297    Z3_solver_pop(mContext, mSolver, 1);
298
299    Z3_ast Ak0 = make(adv);
300    Z3_inc_ref(mContext, Ak0);
301    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
302    Z3_inc_ref(mContext, Nk);
303
304    Z3_ast vars[k + 1];
305    vars[0] = Ak0;
306
307    unsigned m = 1;
308    for (unsigned i = 0; i < k; ++i) {
309        if (unconstrained[i]) {
310            // This algorithm deems two streams mutually exclusive if and only if their conjuntion is a contradiction.
311            // To generate a contradiction when comparing Advances, the BDD of each Advance is represented by the conjunction of
312            // variables representing the k-th Advance and the negation of all variables for the Advances whose inputs are mutually
313            // exclusive with the k-th input.
314
315            // For example, if the input of the i-th Advance is mutually exclusive with the input of the j-th and k-th Advance, the
316            // BDD of the i-th Advance is Ai ∧ ¬Aj ∧ ¬Ak. Similarly, the j- and k-th Advance is Aj ∧ ¬Ai and Ak ∧ ¬Ai, respectively
317            // (assuming that the j-th and k-th Advance are not mutually exclusive.)
318
319            Z3_ast & Ai0 = get(mConstraintGraph[i]);
320            Z3_ast conj[2] = { Ai0, Nk };
321            Z3_ast Ai = Z3_mk_and(mContext, 2, conj);
322            Z3_inc_ref(mContext, Ai);
323            Z3_dec_ref(mContext, Ai0);
324            Ai0 = Ai;
325            assert (get(mConstraintGraph[i]) == Ai);
326
327            vars[m++] = mNegatedAdvance[i];
328
329            continue; // note: if these Advances are transitively dependent, an edge will still exist.
330        }
331        add_edge(i, k, mConstraintGraph);
332    }
333    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
334    mNegatedAdvance.emplace_back(Nk);
335    Z3_ast Ak = Z3_mk_and(mContext, m, vars);
336    if (LLVM_UNLIKELY(Ak != Ak0)) {
337        Z3_inc_ref(mContext, Ak);
338        Z3_dec_ref(mContext, Ak0);
339    }
340    return add(adv, Ak);
341}
342
343/** ------------------------------------------------------------------------------------------------------------- *
344 * @brief initialize
345 ** ------------------------------------------------------------------------------------------------------------- */
346Statement * MultiplexingPass::initialize(Statement * const initial) {
347
348    // clean up any unneeded refs / characterizations.
349    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
350        const CharacterizationRef & r = std::get<1>(*i);
351        const auto e = i++;
352        if (LLVM_UNLIKELY(std::get<1>(r) == 0)) {
353            Z3_dec_ref(mContext, std::get<0>(r));
354            mCharacterization.erase(e);
355        }
356    }
357
358    for (Z3_ast var : mNegatedAdvance) {
359        Z3_dec_ref(mContext, var);
360    }
361    mNegatedAdvance.clear();
362
363    // Scan through and count all the advances and statements ...
364    unsigned statements = 0, advances = 0;
365    Statement * last = nullptr;
366    for (Statement * stmt = initial; stmt; stmt = stmt->getNextNode()) {
367        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
368            last = stmt;
369            break;
370        } else if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
371            ++advances;
372        }
373        ++statements;
374    }
375
376    flat_map<const PabloAST *, unsigned> M;
377    M.reserve(statements);
378    matrix<bool> G(statements, advances, false);
379    for (unsigned i = 0; i != advances; ++i) {
380        G(i, i) = true;
381    }
382
383    mConstraintGraph = ConstraintGraph(advances);
384    unsigned n = advances;
385    unsigned k = 0;
386    for (Statement * stmt = initial; stmt != last; stmt = stmt->getNextNode()) {
387        assert (!isa<If>(stmt) && !isa<While>(stmt));
388        unsigned u = 0;
389        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
390            mConstraintGraph[k] = cast<Advance>(stmt);
391            u = k++;
392        } else {
393            u = n++;
394        }
395        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
396            const PabloAST * const op = stmt->getOperand(i);
397            if (LLVM_LIKELY(isa<Statement>(op))) {
398                auto f = M.find(op);
399                if (f != M.end()) {
400                    const unsigned v = std::get<1>(*f);
401                    for (unsigned w = 0; w != k; ++w) {
402                        G(u, w) |= G(v, w);
403                    }
404                }
405            }
406        }
407        M.emplace(stmt, u);
408    }
409
410    assert (k == advances);
411
412    // Initialize the base constraint graph by transposing G and removing reflective loops
413    for (unsigned i = 0; i != advances; ++i) {
414        for (unsigned j = 0; j < i; ++j) {
415            if (G(i, j)) {
416                add_edge(j, i, mConstraintGraph);
417            }
418        }
419        for (unsigned j = i + 1; j < advances; ++j) {
420            if (G(i, j)) {
421                add_edge(j, i, mConstraintGraph);
422            }
423        }
424    }
425
426    mSubsetGraph = SubsetGraph(advances);
427    mNegatedAdvance.reserve(advances);
428
429    return last;
430}
431
432
433/** ------------------------------------------------------------------------------------------------------------- *
434 * @brief generateCandidateSets
435 ** ------------------------------------------------------------------------------------------------------------- */
436bool MultiplexingPass::generateCandidateSets(Statement * const begin, Statement * const end) {
437
438    const auto n = mNegatedAdvance.size();
439    if (LLVM_UNLIKELY(n < 3)) {
440        return false;
441    }
442    assert (num_vertices(mConstraintGraph) == n);
443
444    // The naive way to handle this would be to compute a DNF formula consisting of the
445    // powerset of all independent (candidate) sets of G, assign a weight to each, and
446    // try to maximally satisfy the clauses. However, this would be extremely costly to
447    // compute let alone solve as we could easily generate O(2^100) clauses for a complex
448    // problem. Further the vast majority of clauses would be false in the end.
449
450    // Moreover, for every set that can Advance is contained in would need a unique
451    // variable and selector. In other words:
452
453    // Suppose Advance A has a selector variable I. If I is true, then A must be in ONE set.
454    // Assume A could be in m sets. To enforce this, there are m(m - 1)/2 clauses:
455
456    //   (¬A_1 √ ¬A_2 √ ¬I), (¬A_1 √ ¬A_3 √ ¬I), ..., (¬A_m-1 √ ¬A_m √ ¬I)
457
458    // m here is be equivalent to number of independent sets in the constraint graph G
459    // that contains A.
460
461    // If two sets have a DEPENDENCY constraint between them, it will introduce a cyclic
462    // relationship even if those sets are legal on their own. Thus we'd also need need
463    // hard constraints between all constrained variables related to the pair of Advances.
464
465    // Instead, we only try to solve for one set at a time. This eliminate the need for
466    // the above constraints and computing m but this process to be closer to a simple
467    // greedy search.
468
469    // We do want to weight whether to include or exclude an item in a set but what should
470    // this be? The weight must be related to the elements already in the set. If our goal
471    // is to balance the perturbation of the AST with the reduction in # of Advances, the
472    // cost of inclusion / exclusion could be proportional to the # of instructions that
473    // it increases / decreases the span by --- but how many statements is an Advance worth?
474
475    // What if instead we maintain a queue of advances and discard any that are outside of
476    // the current window?
477
478    mCandidateGraph = CandidateGraph(n);
479
480    Z3_config cfg = Z3_mk_config();
481    Z3_set_param_value(cfg, "MODEL", "true");
482    Z3_context ctx = Z3_mk_context(cfg);
483    Z3_del_config(cfg);
484    Z3_solver solver = Z3_mk_solver(ctx);
485    Z3_solver_inc_ref(ctx, solver);
486    std::vector<Z3_ast> N(n);
487    for (unsigned i = 0; i < n; ++i) {
488        N[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (N[i]);
489    }
490    std::vector<std::pair<unsigned, unsigned>> S;
491    S.reserve(n);
492
493    unsigned line = 0;
494    unsigned i = 0;
495    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) {
496        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
497            assert (S.empty() || line > std::get<0>(S.back()));
498            assert (cast<Advance>(stmt) == mConstraintGraph[i]);
499            if (S.size() > 0 && (line - std::get<0>(S.front())) > WindowSize) {
500                // try to compute a maximal set for this given set of Advances
501                if (S.size() > 2) {
502                    generateCandidateSets(ctx, solver, S, N);
503                }
504                // erase any that preceed our window
505                for (auto i = S.begin();;) {
506                    if (++i == S.end() || (line - std::get<0>(*i)) <= WindowSize) {
507                        S.erase(S.begin(), i);
508                        break;
509                    }
510                }
511            }
512            for (unsigned j : make_iterator_range(adjacent_vertices(i, mConstraintGraph))) {
513                Z3_ast disj[2] = { Z3_mk_not(ctx, N[j]), Z3_mk_not(ctx, N[i]) };
514                Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, disj));
515            }
516            S.emplace_back(line, i++);
517        }
518        ++line;
519    }
520    if (S.size() > 2) {
521        generateCandidateSets(ctx, solver, S, N);
522    }
523
524    Z3_solver_dec_ref(ctx, solver);
525    Z3_del_context(ctx);
526
527    return num_vertices(mCandidateGraph) > n;
528}
529
530/** ------------------------------------------------------------------------------------------------------------- *
531 * @brief generateCandidateSets
532 ** ------------------------------------------------------------------------------------------------------------- */
533void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & N) {
534    assert (S.size() > 2);
535    assert (std::get<0>(S.front()) < std::get<0>(S.back()));
536    assert ((std::get<0>(S.back()) - std::get<0>(S.front())) <= WindowSize);
537    Z3_solver_push(ctx, solver);
538    const auto n = N.size();
539    std::vector<Z3_ast> assumptions(S.size());
540    for (unsigned i = 0, j = 0; i < n; ++i) {
541        if (LLVM_UNLIKELY(j < S.size() && std::get<1>(S[j]) == i)) { // in our window range
542            assumptions[j++] = N[i];
543        } else {
544            Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, N[i]));
545        }
546    }
547    if (maxsat(ctx, solver, assumptions) != Z3_L_FALSE) {
548        Z3_model m = Z3_solver_get_model(ctx, solver);
549        Z3_model_inc_ref(ctx, m);
550        const auto k = add_vertex(mCandidateGraph); assert(k >= N.size());
551        Z3_ast TRUE = Z3_mk_true(ctx);
552        Z3_ast FALSE = Z3_mk_false(ctx);
553        for (const auto i : S) {
554            Z3_ast value;
555            if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, N[std::get<1>(i)], 1, &value) != Z3_TRUE)) {
556                throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
557            }
558            if (value == TRUE) {
559                add_edge(std::get<1>(i), k, mCandidateGraph);
560            } else if (LLVM_UNLIKELY(value != FALSE)) {
561                throw std::runtime_error("Unexpected Z3 error constraint model value is a non-terminal!");
562            }
563        }
564        Z3_model_dec_ref(ctx, m);
565    }
566    Z3_solver_pop(ctx, solver, 1);
567}
568
569/** ------------------------------------------------------------------------------------------------------------- *
570 * @brief is_power_of_2
571 * @param n an integer
572 ** ------------------------------------------------------------------------------------------------------------- */
573static inline bool is_power_of_2(const size_t n) {
574    return ((n & (n - 1)) == 0);
575}
576
577/** ------------------------------------------------------------------------------------------------------------- *
578 * @brief log2_plus_one
579 ** ------------------------------------------------------------------------------------------------------------- */
580static inline size_t log2_plus_one(const size_t n) {
581    return std::log2<size_t>(n) + 1;
582}
583
584/** ------------------------------------------------------------------------------------------------------------- *
585 * @brief selectMultiplexSetsGreedy
586 *
587 * This algorithm is simply computes a greedy set cover. We want an exact max-weight set cover but can generate new
588 * sets by taking a subset of any existing set. With a few modifications, the greedy approach seems to work well
589 * enough but can be shown to produce a suboptimal solution if there are three candidate sets labelled A, B and C,
590 * in which A ∩ B = ∅, |A| ≀ |B| < |C|, and C ⊂ (A ∪ B).
591 ** ------------------------------------------------------------------------------------------------------------- */
592void MultiplexingPass::selectMultiplexSetsGreedy() {
593
594    using AdjIterator = graph_traits<CandidateGraph>::adjacency_iterator;
595    using degree_t = CandidateGraph::degree_size_type;
596    using vertex_t = CandidateGraph::vertex_descriptor;
597
598    const size_t m = num_vertices(mConstraintGraph);
599    const size_t n = num_vertices(mCandidateGraph) - m;
600
601    bool chosen[n] = {};
602
603    for (;;) {
604
605        // Choose the set with the greatest number of vertices not already included in some other set.
606        vertex_t u = 0;
607        degree_t w = 0;
608        for (vertex_t i = 0; i != n; ++i) {
609            if (chosen[i]) continue;
610            const auto t = i + m;
611            degree_t r = degree(t, mCandidateGraph);
612            if (LLVM_LIKELY(r >= 3)) { // if this set has at least 3 elements.
613                if (w < r) {
614                    u = t;
615                    w = r;
616                }
617            } else if (r) {
618                clear_vertex(t, mCandidateGraph);
619            }
620        }
621
622        // Multiplexing requires 3 or more elements; if no set contains at least 3, abort.
623        if (LLVM_UNLIKELY(w == 0)) {
624            break;
625        }
626
627        chosen[u - m] = true;
628
629        // If this contains 2^n elements for any n, discard the member that is most likely to be added
630        // to some future set.
631        if (LLVM_UNLIKELY(is_power_of_2(degree(u, mCandidateGraph)))) {
632            vertex_t x = 0;
633            degree_t w = 0;
634            for (const auto v : make_iterator_range(adjacent_vertices(u, mCandidateGraph))) {
635                if (degree(v, mCandidateGraph) > w) {
636                    x = v;
637                    w = degree(v, mCandidateGraph);
638                }
639            }
640            remove_edge(u, x, mCandidateGraph);
641        }
642
643        AdjIterator begin, end;
644        std::tie(begin, end) = adjacent_vertices(u, mCandidateGraph);
645        for (auto vi = begin; vi != end; ) {
646            const auto v = *vi++;
647            clear_vertex(v, mCandidateGraph);
648            add_edge(v, u, mCandidateGraph);
649        }
650
651    }
652
653    #ifndef NDEBUG
654    for (unsigned i = 0; i != m; ++i) {
655        assert (degree(i, mCandidateGraph) <= 1);
656    }
657    for (unsigned i = m; i != (m + n); ++i) {
658        assert (degree(i, mCandidateGraph) == 0 || degree(i, mCandidateGraph) >= 3);
659    }
660    #endif
661}
662
663/** ------------------------------------------------------------------------------------------------------------- *
664 * @brief eliminateSubsetConstraints
665 ** ------------------------------------------------------------------------------------------------------------- */
666void MultiplexingPass::eliminateSubsetConstraints() {
667    using SubsetEdgeIterator = graph_traits<SubsetGraph>::edge_iterator;
668    // If Ai ⊂ Aj then the subset graph will contain the arc (i, j). Remove all arcs corresponding to vertices
669    // that are not elements of the same multiplexing set.
670    SubsetEdgeIterator ei, ei_end, ei_next;
671    std::tie(ei, ei_end) = edges(mSubsetGraph);
672    for (ei_next = ei; ei != ei_end; ei = ei_next) {
673        ++ei_next;
674        const auto u = source(*ei, mSubsetGraph);
675        const auto v = target(*ei, mSubsetGraph);
676        if (degree(u, mCandidateGraph) != 0 && degree(v, mCandidateGraph) != 0) {
677            assert (degree(u, mCandidateGraph) == 1);
678            assert (degree(v, mCandidateGraph) == 1);
679            const auto su = *(adjacent_vertices(u, mCandidateGraph).first);
680            const auto sv = *(adjacent_vertices(v, mCandidateGraph).first);
681            if (su == sv) {
682                continue;
683            }
684        }
685        remove_edge(*ei, mSubsetGraph);
686    }
687
688    if (num_edges(mSubsetGraph) != 0) {
689
690        // At least one subset constraint exists; perform a transitive reduction on the graph to ensure that
691        // we perform the minimum number of AST modifications for the selected multiplexing sets.
692
693        doTransitiveReductionOfSubsetGraph();
694
695        // Afterwards modify the AST to ensure that multiplexing algorithm can ignore any subset constraints
696        for (auto e : make_iterator_range(edges(mSubsetGraph))) {
697            Advance * const adv1 = mConstraintGraph[source(e, mSubsetGraph)];
698            Advance * const adv2 = mConstraintGraph[target(e, mSubsetGraph)];
699            assert (adv1->getParent() == adv2->getParent());
700            PabloBlock * const pb = adv1->getParent();
701            pb->setInsertPoint(adv2->getPrevNode());
702            adv2->setOperand(0, pb->createAnd(adv2->getOperand(0), pb->createNot(adv1->getOperand(0)), "subset"));
703            pb->setInsertPoint(adv2);
704            adv2->replaceAllUsesWith(pb->createOr(adv1, adv2, "merge"));
705        }
706
707    }
708}
709
710///** ------------------------------------------------------------------------------------------------------------- *
711// * Topologically sort the sequence of instructions whilst trying to adhere as best as possible to the original
712// * program sequence.
713// ** ------------------------------------------------------------------------------------------------------------- */
714//inline bool topologicalSort(Z3_context ctx, Z3_solver solver, const std::vector<Z3_ast> & nodes, const int limit) {
715//    const auto n = nodes.size();
716//    if (LLVM_UNLIKELY(n == 0)) {
717//        return true;
718//    }
719//    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
720//        return false;
721//    }
722
723//    Z3_ast aux_vars[n];
724//    Z3_ast assumptions[n];
725//    Z3_ast ordering[n];
726//    int increments[n];
727
728//    Z3_sort boolTy = Z3_mk_bool_sort(ctx);
729//    Z3_sort intTy = Z3_mk_int_sort(ctx);
730//    Z3_ast one = Z3_mk_int(ctx, 1, intTy);
731
732//    for (unsigned i = 0; i < n; ++i) {
733//        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, boolTy);
734//        assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
735//        Z3_ast num = one;
736//        if (i > 0) {
737//            Z3_ast prior_plus_one[2] = { nodes[i - 1], one };
738//            num = Z3_mk_add(ctx, 2, prior_plus_one);
739//        }
740//        ordering[i] = Z3_mk_eq(ctx, nodes[i], num);
741//        increments[i] = 1;
742//    }
743
744//    unsigned unsat = 0;
745
746//    for (;;) {
747//        Z3_solver_push(ctx, solver);
748//        for (unsigned i = 0; i < n; ++i) {
749//            Z3_ast constraint[2] = {ordering[i], aux_vars[i]};
750//            Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, constraint));
751//        }
752//        if (LLVM_UNLIKELY(Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE)) {
753//            errs() << " SATISFIABLE!  (" << unsat << " of " << n << ")\n";
754//            return true; // done
755//        }
756//        Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver); assert (core);
757//        unsigned m = Z3_ast_vector_size(ctx, core); assert (m > 0);
758
759//        errs() << " UNSATISFIABLE " << m << "  (" << unsat << " of " << n <<")\n";
760
761//        for (unsigned j = 0; j < m; j++) {
762//            // check whether assumption[i] is in the core or not
763//            bool not_found = true;
764//            for (unsigned i = 0; i < n; i++) {
765//                if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
766
767//                    const auto k = increments[i];
768
769//                    errs() << " -- " << i << " @k=" << k << "\n";
770
771//                    if (k < limit) {
772//                        Z3_ast gap = Z3_mk_int(ctx, 1UL << k, intTy);
773//                        Z3_ast num = gap;
774//                        if (LLVM_LIKELY(i > 0)) {
775//                            Z3_ast prior_plus_gap[2] = { nodes[i - 1], gap };
776//                            num = Z3_mk_add(ctx, 2, prior_plus_gap);
777//                        }
778//                        Z3_dec_ref(ctx, ordering[i]);
779//                        ordering[i] = Z3_mk_le(ctx, num, nodes[i]);
780//                    } else if (k == limit && i > 0) {
781//                        ordering[i] = Z3_mk_le(ctx, nodes[i - 1], nodes[i]);
782//                    } else {
783//                        assumptions[i] = aux_vars[i]; // <- trivially satisfiable
784//                        ++unsat;
785//                    }
786//                    increments[i] = k + 1;
787//                    not_found = false;
788//                    break;
789//                }
790//            }
791//            if (LLVM_UNLIKELY(not_found)) {
792//                throw std::runtime_error("Unexpected Z3 failure when attempting to locate unsatisfiable ordering constraint!");
793//            }
794//        }
795//        Z3_solver_pop(ctx, solver, 1);
796//    }
797//    llvm_unreachable("maxsat wrongly reported this being unsatisfiable despite being able to satisfy the hard constraints!");
798//    return false;
799//}
800
801///** ------------------------------------------------------------------------------------------------------------- *
802// * Topologically sort the sequence of instructions whilst trying to adhere as best as possible to the original
803// * program sequence.
804// ** ------------------------------------------------------------------------------------------------------------- */
805//inline bool topologicalSort(Z3_context ctx, Z3_solver solver, const std::vector<Z3_ast> & nodes, const int limit) {
806//    const auto n = nodes.size();
807//    if (LLVM_UNLIKELY(n == 0)) {
808//        return true;
809//    }
810//    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
811//        return false;
812//    }
813
814//    Z3_ast aux_vars[n];
815//    Z3_ast assumptions[n];
816
817//    Z3_sort boolTy = Z3_mk_bool_sort(ctx);
818//    Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
819
820//    for (unsigned i = 0; i < n; ++i) {
821//        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, boolTy);
822//        assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
823//        Z3_ast num = one;
824//        if (i > 0) {
825//            Z3_ast prior_plus_one[2] = { nodes[i - 1], one };
826//            num = Z3_mk_add(ctx, 2, prior_plus_one);
827//        }
828//        Z3_ast ordering = Z3_mk_eq(ctx, nodes[i], num);
829//        Z3_ast constraint[2] = {ordering, aux_vars[i]};
830//        Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, constraint));
831//    }
832
833//    for (unsigned k = 0; k < n; ) {
834//        if (LLVM_UNLIKELY(Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE)) {
835//            errs() << " SATISFIABLE!\n";
836//            return true; // done
837//        }
838//        Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver); assert (core);
839//        unsigned m = Z3_ast_vector_size(ctx, core); assert (m > 0);
840
841//        k += m;
842
843//        errs() << " UNSATISFIABLE " << m << " (" << k << ")\n";
844
845//        for (unsigned j = 0; j < m; j++) {
846//            // check whether assumption[i] is in the core or not
847//            bool not_found = true;
848//            for (unsigned i = 0; i < n; i++) {
849//                if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
850//                    assumptions[i] = aux_vars[i];
851//                    not_found = false;
852//                    break;
853//                }
854//            }
855//            if (LLVM_UNLIKELY(not_found)) {
856//                throw std::runtime_error("Unexpected Z3 failure when attempting to locate unsatisfiable ordering constraint!");
857//            }
858//        }
859//    }
860//    llvm_unreachable("maxsat wrongly reported this being unsatisfiable despite being able to satisfy the hard constraints!");
861//    return false;
862//}
863
864
865/** ------------------------------------------------------------------------------------------------------------- *
866 * @brief addWithHardConstraints
867 ** ------------------------------------------------------------------------------------------------------------- */
868Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, Statement * const stmt, flat_map<Statement *, Z3_ast> & M) {
869    assert (M.count(stmt) == 0 && stmt->getParent() == block);
870    // compute the hard dependency constraints
871    Z3_ast node = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_int_sort(ctx)); assert (node);
872    // we want all numbers to be positive so that the soft assertion that the first statement ought to stay at the first location
873    // whenever possible isn't satisfied by making preceeding numbers negative.
874    Z3_solver_assert(ctx, solver, Z3_mk_gt(ctx, node, Z3_mk_int(ctx, 0, Z3_mk_int_sort(ctx))));
875    for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
876        PabloAST * const op = stmt->getOperand(i);
877        if (isa<Statement>(op) && cast<Statement>(op)->getParent() == block) {
878            const auto f = M.find(cast<Statement>(op));
879            if (f != M.end()) {
880                Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, f->second, node));
881            }
882        }
883    }
884    M.emplace(stmt, node);
885    return node;
886}
887
888/** ------------------------------------------------------------------------------------------------------------- *
889 * @brief dominates
890 *
891 * does Statement a dominate Statement b?
892 ** ------------------------------------------------------------------------------------------------------------- */
893bool dominates(const Statement * const a, const Statement * const b) {
894    assert (a);
895    if (LLVM_UNLIKELY(b == nullptr)) {
896        return false;
897    }
898    assert (a->getParent() == b->getParent());
899    for (const Statement * t : *a->getParent()) {
900        if (t == a) {
901            return true;
902        } else if (t == b) {
903            return false;
904        }
905    }
906    llvm_unreachable("Neither a nor b are in their reported block!");
907    return false;
908}
909
910/** ------------------------------------------------------------------------------------------------------------- *
911 * @brief addWithHardConstraints
912 ** ------------------------------------------------------------------------------------------------------------- */
913Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, PabloAST * expr, flat_map<Statement *, Z3_ast> & M, Statement * const ip) {
914    if (isa<Statement>(expr)) {
915        Statement * const stmt = cast<Statement>(expr);
916        if (stmt->getParent() == block) {
917            const auto f = M.find(stmt);
918            if (LLVM_UNLIKELY(f != M.end())) {
919                return f->second;
920            } else if (!dominates(stmt, ip)) {
921                for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
922                    addWithHardConstraints(ctx, solver, block, stmt->getOperand(i), M, ip);
923                }
924                return addWithHardConstraints(ctx, solver, block, stmt, M);
925            }
926        }
927    }
928    return nullptr;
929}
930
931/** ------------------------------------------------------------------------------------------------------------- *
932 * @brief multiplexSelectedSets
933 ** ------------------------------------------------------------------------------------------------------------- */
934inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
935
936    assert ("begin cannot be null!" && begin);
937    assert (begin->getParent() == block);
938    assert (!end || end->getParent() == block);
939    assert (!end || isa<If>(end) || isa<While>(end));
940
941    Z3_config cfg = Z3_mk_config();
942    Z3_set_param_value(cfg, "MODEL", "true");
943    Z3_context ctx = Z3_mk_context(cfg);
944    Z3_del_config(cfg);
945    Z3_solver solver = Z3_mk_solver(ctx);
946    Z3_solver_inc_ref(ctx, solver);
947
948    const auto first_set = num_vertices(mConstraintGraph);
949    const auto last_set = num_vertices(mCandidateGraph);
950
951    for (auto idx = first_set; idx != last_set; ++idx) {
952        const size_t n = degree(idx, mCandidateGraph);
953        if (n) {
954            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
955            Advance * input[n];
956            PabloAST * muxed[m];
957            PabloAST * muxed_n[m];
958
959            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
960            unsigned i = 0;
961            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
962                input[i] = mConstraintGraph[u];
963                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
964                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
965                assert ("Inputs are not in sequential order!" && (i == 0 || (i > 0 && dominates(input[i - 1], input[i]))));
966                ++i;
967            }
968
969            Statement * const A1 = input[0];
970            Statement * const An = input[n - 1]->getNextNode();
971
972            Statement * const ip = A1->getPrevNode(); // save our insertion point prior to modifying the AST
973
974            Z3_solver_push(ctx, solver);
975
976            // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
977            flat_map<Statement *, Z3_ast> M;
978
979            Z3_ast prior = nullptr;
980            Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
981            std::vector<Z3_ast> ordering;
982//            std::vector<Z3_ast> nodes;
983
984            for (Statement * stmt = A1; stmt != An; stmt = stmt->getNextNode()) { assert (stmt != ip);
985                Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
986                // compute the soft ordering constraints
987                Z3_ast num = one;
988                if (prior) {
989                    Z3_ast prior_plus_one[2] = { prior, one };
990                    num = Z3_mk_add(ctx, 2, prior_plus_one);
991                }
992                ordering.push_back(Z3_mk_eq(ctx, node, num));
993                if (prior) {
994                    ordering.push_back(Z3_mk_lt(ctx, prior, node));
995                }
996
997
998//                for (Z3_ast prior : nodes) {
999//                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, Z3_mk_eq(ctx, prior, node)));
1000//                }
1001 //               nodes.push_back(node);
1002
1003
1004                prior = node;
1005            }
1006
1007            // assert (nodes.size() <= WindowSize);
1008
1009            block->setInsertPoint(block->back()); // <- necessary for domination check!
1010
1011            circular_buffer<PabloAST *> Q(n);
1012
1013            /// Perform n-to-m Multiplexing
1014            for (size_t j = 0; j != m; ++j) {
1015                std::ostringstream prefix;
1016                prefix << "mux" << n << "to" << m << '.' << (j);
1017                assert (Q.empty());
1018                for (size_t i = 0; i != n; ++i) {
1019                    if (((i + 1) & (1UL << j)) != 0) {
1020                        Q.push_back(input[i]->getOperand(0));
1021                    }
1022                }
1023                while (Q.size() > 1) {
1024                    PabloAST * a = Q.front(); Q.pop_front();
1025                    PabloAST * b = Q.front(); Q.pop_front();
1026                    PabloAST * expr = block->createOr(a, b);
1027                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
1028                    Q.push_back(expr);
1029                }
1030                PabloAST * const muxing = Q.front(); Q.clear();
1031                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
1032                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
1033                muxed_n[j] = block->createNot(muxed[j]);
1034                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
1035            }
1036
1037            /// Perform m-to-n Demultiplexing
1038            for (size_t i = 0; i != n; ++i) {
1039                // Construct the demuxed values and replaces all the users of the original advances with them.
1040                assert (Q.empty());
1041                for (size_t j = 0; j != m; ++j) {
1042                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
1043                }
1044                Z3_ast replacement = nullptr;
1045                while (Q.size() > 1) {
1046                    PabloAST * const a = Q.front(); Q.pop_front();
1047                    PabloAST * const b = Q.front(); Q.pop_front();
1048                    PabloAST * expr = block->createAnd(a, b);
1049                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
1050                    Q.push_back(expr);
1051                }
1052                assert (replacement);
1053                PabloAST * const demuxed = Q.front(); Q.clear();
1054
1055                const auto f = M.find(input[i]);
1056                assert (f != M.end());
1057                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
1058                M.erase(f);
1059
1060                input[i]->replaceWith(demuxed);
1061                assert (M.count(input[i]) == 0);
1062            }
1063
1064            assert (M.count(ip) == 0);
1065
1066            if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) != Z3_L_TRUE)) {
1067                throw std::runtime_error("Unexpected Z3 failure when attempting to topologically sort the AST!");
1068            }
1069
1070            Z3_model model = Z3_solver_get_model(ctx, solver);
1071            Z3_model_inc_ref(ctx, model);
1072
1073            std::vector<std::pair<long long int, Statement *>> I;
1074
1075            for (const auto i : M) {
1076                Z3_ast value;
1077                if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
1078                    throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1079                }
1080                long long int line;
1081                if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
1082                    throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1083                }
1084                I.emplace_back(line, std::get<0>(i));
1085            }
1086
1087            Z3_model_dec_ref(ctx, model);
1088
1089            std::sort(I.begin(), I.end());
1090
1091            block->setInsertPoint(ip);
1092            for (auto i : I) {
1093                block->insert(std::get<1>(i));
1094            }
1095
1096            Z3_solver_pop(ctx, solver, 1);
1097        }
1098    }
1099
1100    Z3_solver_dec_ref(ctx, solver);
1101    Z3_del_context(ctx);
1102
1103}
1104
1105///** ------------------------------------------------------------------------------------------------------------- *
1106// * @brief multiplexSelectedSets
1107// ** ------------------------------------------------------------------------------------------------------------- */
1108//inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
1109
1110//    assert ("begin cannot be null!" && begin);
1111//    assert (begin->getParent() == block);
1112//    assert (!end || end->getParent() == block);
1113//    assert (!end || isa<If>(end) || isa<While>(end));
1114
1115//    Statement * const ip = begin->getPrevNode(); // save our insertion point prior to modifying the AST
1116
1117//    Z3_config cfg = Z3_mk_config();
1118//    Z3_set_param_value(cfg, "MODEL", "true");
1119//    Z3_context ctx = Z3_mk_context(cfg);
1120//    Z3_del_config(cfg);
1121//    Z3_solver solver = Z3_mk_solver(ctx);
1122//    Z3_solver_inc_ref(ctx, solver);
1123
1124//    const auto first_set = num_vertices(mConstraintGraph);
1125//    const auto last_set = num_vertices(mCandidateGraph);
1126
1127//    // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
1128//    flat_map<Statement *, Z3_ast> M;
1129
1130//    Z3_ast prior = nullptr;
1131//    Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
1132//    std::vector<Z3_ast> ordering;
1133
1134//    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) { assert (stmt != ip);
1135//        Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
1136//        // compute the soft ordering constraints
1137//        Z3_ast num = one;
1138//        if (prior) {
1139//            Z3_ast prior_plus_one[2] = { prior, one };
1140//            num = Z3_mk_add(ctx, 2, prior_plus_one);
1141//        }
1142//        ordering.push_back(Z3_mk_eq(ctx, node, num));
1143//        prior = node;
1144//    }
1145
1146//    block->setInsertPoint(block->back()); // <- necessary for domination check!
1147
1148//    errs() << "---------------------------------------------\n";
1149
1150//    for (auto idx = first_set; idx != last_set; ++idx) {
1151//        const size_t n = degree(idx, mCandidateGraph);
1152//        if (n) {
1153//            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
1154//            Advance * input[n];
1155//            PabloAST * muxed[m];
1156//            PabloAST * muxed_n[m];
1157
1158//            errs() << n << " -> " << m << "\n";
1159
1160//            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
1161//            unsigned i = 0;
1162//            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
1163//                input[i] = mConstraintGraph[u];
1164//                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
1165//                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
1166//                ++i;
1167//            }
1168
1169//            circular_buffer<PabloAST *> Q(n);
1170
1171//            /// Perform n-to-m Multiplexing
1172//            for (size_t j = 0; j != m; ++j) {
1173//                std::ostringstream prefix;
1174//                prefix << "mux" << n << "to" << m << '.' << (j);
1175//                assert (Q.empty());
1176//                for (size_t i = 0; i != n; ++i) {
1177//                    if (((i + 1) & (1UL << j)) != 0) {
1178//                        Q.push_back(input[i]->getOperand(0));
1179//                    }
1180//                }
1181//                while (Q.size() > 1) {
1182//                    PabloAST * a = Q.front(); Q.pop_front();
1183//                    PabloAST * b = Q.front(); Q.pop_front();
1184//                    PabloAST * expr = block->createOr(a, b);
1185//                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
1186//                    Q.push_back(expr);
1187//                }
1188//                PabloAST * const muxing = Q.front(); Q.clear();
1189//                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
1190//                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
1191//                muxed_n[j] = block->createNot(muxed[j]);
1192//                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
1193//            }
1194
1195//            /// Perform m-to-n Demultiplexing
1196//            for (size_t i = 0; i != n; ++i) {
1197//                // Construct the demuxed values and replaces all the users of the original advances with them.
1198//                assert (Q.empty());
1199//                for (size_t j = 0; j != m; ++j) {
1200//                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
1201//                }
1202//                Z3_ast replacement = nullptr;
1203//                while (Q.size() > 1) {
1204//                    PabloAST * const a = Q.front(); Q.pop_front();
1205//                    PabloAST * const b = Q.front(); Q.pop_front();
1206//                    PabloAST * expr = block->createAnd(a, b);
1207//                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
1208//                    Q.push_back(expr);
1209//                }
1210//                assert (replacement);
1211//                PabloAST * const demuxed = Q.front(); Q.clear();
1212
1213//                const auto f = M.find(input[i]);
1214//                assert (f != M.end());
1215//                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
1216//                M.erase(f);
1217
1218//                input[i]->replaceWith(demuxed);
1219//                assert (M.count(input[i]) == 0);
1220//            }
1221//        }
1222//    }
1223
1224//    assert (M.count(ip) == 0);
1225
1226//    // if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) == Z3_L_FALSE)) {
1227//    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) != Z3_L_TRUE)) {
1228//        throw std::runtime_error("Unexpected Z3 failure when attempting to topologically sort the AST!");
1229//    }
1230
1231//    Z3_model m = Z3_solver_get_model(ctx, solver);
1232//    Z3_model_inc_ref(ctx, m);
1233
1234//    std::vector<std::pair<long long int, Statement *>> Q;
1235
1236//    errs() << "-----------------------------------------------------------\n";
1237
1238//    for (const auto i : M) {
1239//        Z3_ast value;
1240//        if (Z3_model_eval(ctx, m, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE) {
1241//            throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
1242//        }
1243//        long long int line;
1244//        if (Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE) {
1245//            throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
1246//        }
1247//        Q.emplace_back(line, std::get<0>(i));
1248//    }
1249
1250//    Z3_model_dec_ref(ctx, m);
1251//    Z3_solver_dec_ref(ctx, solver);
1252//    Z3_del_context(ctx);
1253
1254//    std::sort(Q.begin(), Q.end());
1255
1256//    block->setInsertPoint(ip);
1257//    for (auto i : Q) {
1258//        block->insert(std::get<1>(i));
1259//    }
1260//}
1261
1262/** ------------------------------------------------------------------------------------------------------------- *
1263 * @brief doTransitiveReductionOfSubsetGraph
1264 ** ------------------------------------------------------------------------------------------------------------- */
1265void MultiplexingPass::doTransitiveReductionOfSubsetGraph() {
1266    std::vector<SubsetGraph::vertex_descriptor> Q;
1267    for (auto u : make_iterator_range(vertices(mSubsetGraph))) {
1268        if (in_degree(u, mSubsetGraph) == 0 && out_degree(u, mSubsetGraph) != 0) {
1269            Q.push_back(u);
1270        }
1271    }
1272    flat_set<SubsetGraph::vertex_descriptor> targets;
1273    flat_set<SubsetGraph::vertex_descriptor> visited;
1274    do {
1275        const auto u = Q.back(); Q.pop_back();
1276        for (auto ei : make_iterator_range(out_edges(u, mSubsetGraph))) {
1277            for (auto ej : make_iterator_range(out_edges(target(ei, mSubsetGraph), mSubsetGraph))) {
1278                targets.insert(target(ej, mSubsetGraph));
1279            }
1280        }
1281        for (auto v : targets) {
1282            remove_edge(u, v, mSubsetGraph);
1283        }
1284        for (auto e : make_iterator_range(out_edges(u, mSubsetGraph))) {
1285            const auto v = target(e, mSubsetGraph);
1286            if (visited.insert(v).second) {
1287                Q.push_back(v);
1288            }
1289        }
1290    } while (!Q.empty());
1291}
1292
1293/** ------------------------------------------------------------------------------------------------------------- *
1294 * @brief get
1295 ** ------------------------------------------------------------------------------------------------------------- */
1296inline Z3_ast & MultiplexingPass::get(const PabloAST * const expr, const bool deref) {
1297    assert (expr);
1298    auto f = mCharacterization.find(expr);
1299    assert (f != mCharacterization.end());
1300    auto & val = f->second;
1301    if (deref) {
1302        unsigned & refs = std::get<1>(val);
1303        assert (refs > 0);
1304        --refs;
1305    }
1306    return std::get<0>(val);
1307}
1308
1309/** ------------------------------------------------------------------------------------------------------------- *
1310 * @brief make
1311 ** ------------------------------------------------------------------------------------------------------------- */
1312inline Z3_ast MultiplexingPass::make(const PabloAST * const expr) {
1313    assert (expr);
1314    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
1315    Z3_inc_ref(mContext, node);
1316    return add(expr, node);
1317}
1318
1319/** ------------------------------------------------------------------------------------------------------------- *
1320 * @brief add
1321 ** ------------------------------------------------------------------------------------------------------------- */
1322inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node) {   
1323    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, expr->getNumUses())));
1324    return node;
1325}
1326
1327/** ------------------------------------------------------------------------------------------------------------- *
1328 * @brief constructor
1329 ** ------------------------------------------------------------------------------------------------------------- */
1330inline MultiplexingPass::MultiplexingPass(PabloFunction & f, const RNG::result_type seed, Z3_context context, Z3_solver solver)
1331: mContext(context)
1332, mSolver(solver)
1333, mFunction(f)
1334, mRNG(seed)
1335, mConstraintGraph(0)
1336{
1337
1338}
1339
1340
1341inline Z3_ast mk_binary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
1342    Z3_ast args[2] = { in_1, in_2 };
1343    return Z3_mk_or(ctx, 2, args);
1344}
1345
1346inline Z3_ast mk_ternary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2, Z3_ast in_3) {
1347    Z3_ast args[3] = { in_1, in_2, in_3 };
1348    return Z3_mk_or(ctx, 3, args);
1349}
1350
1351inline Z3_ast mk_binary_and(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
1352    Z3_ast args[2] = { in_1, in_2 };
1353    return Z3_mk_and(ctx, 2, args);
1354}
1355
1356///**
1357//   \brief Create a full adder with inputs \c in_1, \c in_2 and \c cin.
1358//   The output of the full adder is stored in \c out, and the carry in \c c_out.
1359//*/
1360//inline std::pair<Z3_ast, Z3_ast> mk_full_adder(Z3_context ctx, Z3_ast in_1, Z3_ast in_2, Z3_ast cin) {
1361//    Z3_ast out = Z3_mk_xor(ctx, Z3_mk_xor(ctx, in_1, in_2), cin);
1362//    Z3_ast cout = mk_ternary_or(ctx, mk_binary_and(ctx, in_1, in_2), mk_binary_and(ctx, in_1, cin), mk_binary_and(ctx, in_2, cin));
1363//    return std::make_pair(out, cout);
1364//}
1365
1366/**
1367   \brief Create an adder for inputs of size \c num_bits.
1368   The arguments \c in1 and \c in2 are arrays of bits of size \c num_bits.
1369
1370   \remark \c result must be an array of size \c num_bits + 1.
1371*/
1372void mk_adder(Z3_context ctx, const unsigned num_bits, Z3_ast * in_1, Z3_ast * in_2, Z3_ast * result) {
1373    Z3_ast cin = Z3_mk_false(ctx);
1374    for (unsigned i = 0; i < num_bits; i++) {
1375        result[i] = Z3_mk_xor(ctx, Z3_mk_xor(ctx, in_1[i], in_2[i]), cin);
1376        cin = mk_ternary_or(ctx, mk_binary_and(ctx, in_1[i], in_2[i]), mk_binary_and(ctx, in_1[i], cin), mk_binary_and(ctx, in_2[i], cin));
1377    }
1378    result[num_bits] = cin;
1379}
1380
1381/**
1382   \brief Given \c num_ins "numbers" of size \c num_bits stored in \c in.
1383   Create floor(num_ins/2) adder circuits. Each circuit is adding two consecutive "numbers".
1384   The numbers are stored one after the next in the array \c in.
1385   That is, the array \c in has size num_bits * num_ins.
1386   Return an array of bits containing \c ceil(num_ins/2) numbers of size \c (num_bits + 1).
1387   If num_ins/2 is not an integer, then the last "number" in the output, is the last "number" in \c in with an appended "zero".
1388*/
1389unsigned mk_adder_pairs(Z3_context ctx, const unsigned num_bits, const unsigned num_ins, Z3_ast * in, Z3_ast * out) {
1390    unsigned out_num_bits = num_bits + 1;
1391    Z3_ast * _in          = in;
1392    Z3_ast * _out         = out;
1393    unsigned out_num_ins  = (num_ins % 2 == 0) ? (num_ins / 2) : (num_ins / 2) + 1;
1394    for (unsigned i = 0; i < num_ins / 2; i++) {
1395        mk_adder(ctx, num_bits, _in, _in + num_bits, _out);
1396        _in  += num_bits;
1397        _in  += num_bits;
1398        _out += out_num_bits;
1399    }
1400    if (num_ins % 2 != 0) {
1401        for (unsigned i = 0; i < num_bits; i++) {
1402            _out[i] = _in[i];
1403        }
1404        _out[num_bits] = Z3_mk_false(ctx);
1405    }
1406    return out_num_ins;
1407}
1408
1409/**
1410   \brief Return the \c idx bit of \c val.
1411*/
1412inline bool get_bit(unsigned val, unsigned idx) {
1413    return (val & (1U << (idx & 31))) != 0;
1414}
1415
1416/**
1417   \brief Given an integer val encoded in n bits (boolean variables), assert the constraint that val <= k.
1418*/
1419void assert_le_one(Z3_context ctx, Z3_solver s, unsigned n, Z3_ast * val)
1420{
1421    Z3_ast i1, i2;
1422    Z3_ast not_val = Z3_mk_not(ctx, val[0]);
1423    assert (get_bit(1, 0));
1424    Z3_ast out = Z3_mk_true(ctx);
1425    for (unsigned i = 1; i < n; i++) {
1426        not_val = Z3_mk_not(ctx, val[i]);
1427        if (get_bit(1, i)) {
1428            i1 = not_val;
1429            i2 = out;
1430        }
1431        else {
1432            i1 = Z3_mk_false(ctx);
1433            i2 = Z3_mk_false(ctx);
1434        }
1435        out = mk_ternary_or(ctx, i1, i2, mk_binary_and(ctx, not_val, out));
1436    }
1437    Z3_solver_assert(ctx, s, out);
1438}
1439
1440/**
1441   \brief Create a counter circuit to count the number of "ones" in lits.
1442   The function returns an array of bits (i.e. boolean expressions) containing the output of the circuit.
1443   The size of the array is stored in out_sz.
1444*/
1445void mk_counter_circuit(Z3_context ctx, Z3_solver solver, unsigned n, Z3_ast * lits) {
1446    unsigned k = 1;
1447    assert (n != 0);
1448    Z3_ast aux_array_1[n + 1];
1449    Z3_ast aux_array_2[n + 1];
1450    Z3_ast * aux_1 = aux_array_1;
1451    Z3_ast * aux_2 = aux_array_2;
1452    std::memcpy(aux_1, lits, sizeof(Z3_ast) * n);
1453    while (n > 1) {
1454        assert (aux_1 != aux_2);
1455        n = mk_adder_pairs(ctx, k++, n, aux_1, aux_2);
1456        std::swap(aux_1, aux_2);
1457    }
1458    assert_le_one(ctx, solver, k, aux_1);
1459}
1460
1461/** ------------------------------------------------------------------------------------------------------------- *
1462 * Fu & Malik procedure for MaxSAT. This procedure is based on unsat core extraction and the at-most-one constraint.
1463 ** ------------------------------------------------------------------------------------------------------------- */
1464Z3_bool maxsat(Z3_context ctx, Z3_solver solver, std::vector<Z3_ast> & soft) {
1465    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
1466        return Z3_L_FALSE;
1467    }
1468    if (LLVM_UNLIKELY(soft.empty())) {
1469        return true;
1470    }
1471
1472    const auto n = soft.size();
1473    const auto ty = Z3_mk_bool_sort(ctx);
1474    Z3_ast aux_vars[n];
1475    Z3_ast assumptions[n];
1476
1477    for (unsigned i = 0; i < n; ++i) {
1478        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, ty);
1479        Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], aux_vars[i]));
1480    }
1481
1482    for (;;) {
1483        // create assumptions
1484        for (unsigned i = 0; i < n; i++) {
1485            // Recall that we asserted (soft_cnstrs[i] \/ aux_vars[i])
1486            // So using (NOT aux_vars[i]) as an assumption we are actually forcing the soft_cnstrs[i] to be considered.
1487            assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
1488        }
1489        if (Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE) {
1490            return Z3_L_TRUE; // done
1491        } else {
1492            Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver);
1493            unsigned m = Z3_ast_vector_size(ctx, core);
1494            Z3_ast block_vars[m];
1495            unsigned k = 0;
1496            // update soft-constraints and aux_vars
1497            for (unsigned i = 0; i < n; i++) {
1498                // check whether assumption[i] is in the core or not
1499                for (unsigned j = 0; j < m; j++) {
1500                    if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
1501                        // assumption[i] is in the unsat core... so soft_cnstrs[i] is in the unsat core
1502                        Z3_ast block_var = Z3_mk_fresh_const(ctx, nullptr, ty);
1503                        Z3_ast new_aux_var = Z3_mk_fresh_const(ctx, nullptr, ty);
1504                        soft[i] = mk_binary_or(ctx, soft[i], block_var);
1505                        aux_vars[i] = new_aux_var;
1506                        block_vars[k] = block_var;
1507                        ++k;
1508                        // Add new constraint containing the block variable.
1509                        // Note that we are using the new auxiliary variable to be able to use it as an assumption.
1510                        Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], new_aux_var) );
1511                        break;
1512                    }
1513                }
1514
1515            }
1516            if (k > 1) {
1517                mk_counter_circuit(ctx, solver, k, block_vars);
1518            }
1519        }
1520    }
1521    llvm_unreachable("unreachable");
1522    return Z3_L_FALSE;
1523}
1524
1525} // end of namespace pablo
Note: See TracBrowser for help on using the repository browser.