source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp @ 5112

Last change on this file since 5112 was 5112, checked in by nmedfort, 3 years ago

Initial work on multiplexing using Z3.

File size: 42.2 KB
Line 
1#include "pablo_automultiplexing.hpp"
2#include <pablo/builder.hpp>
3#include <pablo/function.h>
4#include <pablo/printer_pablos.h>
5#include <boost/container/flat_set.hpp>
6#include <boost/numeric/ublas/matrix.hpp>
7#include <boost/circular_buffer.hpp>
8#include <boost/graph/topological_sort.hpp>
9#include <boost/range/iterator_range.hpp>
10#include <pablo/analysis/pabloverifier.hpp>
11#include <pablo/optimizers/pablo_simplifier.hpp>
12#include <pablo/builder.hpp>
13#include <stack>
14#include <queue>
15#include <unordered_set>
16#include <functional>
17#include <llvm/Support/CommandLine.h>
18
19using namespace llvm;
20using namespace boost;
21using namespace boost::container;
22using namespace boost::numeric::ublas;
23
24/// Interesting test cases:
25/// ./icgrep -c -multiplexing '[\p{Lm}\p{Meetei_Mayek}]' -disable-if-hierarchy-strategy
26
27/// ./icgrep -c -multiplexing '\p{Imperial_Aramaic}(?<!\p{Sm})' -disable-if-hierarchy-strategy
28
29
30static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
31
32#ifdef NDEBUG
33#define INITIAL_SEED_VALUE (std::random_device()())
34#else
35#define INITIAL_SEED_VALUE (83234827342)
36#endif
37
38static cl::opt<std::mt19937::result_type> Seed("multiplexing-seed", cl::init(INITIAL_SEED_VALUE),
39                                        cl::desc("randomization seed used when performing any non-deterministic operations."),
40                                        cl::cat(MultiplexingOptions));
41
42#undef INITIAL_SEED_VALUE
43
44static cl::opt<unsigned> SetLimit("multiplexing-set-limit", cl::init(std::numeric_limits<unsigned>::max()),
45                                        cl::desc("maximum size of any candidate set."),
46                                        cl::cat(MultiplexingOptions));
47
48static cl::opt<unsigned> SelectionLimit("multiplexing-selection-limit", cl::init(100),
49                                        cl::desc("maximum number of selections from any partial candidate set."),
50                                        cl::cat(MultiplexingOptions));
51
52static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(1),
53                                        cl::desc("maximum depth difference for computing mutual exclusion of Advance nodes."),
54                                        cl::cat(MultiplexingOptions));
55
56static cl::opt<unsigned> Samples("multiplexing-samples", cl::init(1),
57                                 cl::desc("number of times the Advance constraint graph is sampled to find multiplexing opportunities."),
58                                 cl::cat(MultiplexingOptions));
59
60
61enum SelectionStrategy {Greedy, WorkingSet};
62
63static cl::opt<SelectionStrategy> Strategy(cl::desc("Choose set selection strategy:"),
64                                             cl::values(
65                                             clEnumVal(Greedy, "choose the largest multiplexing sets possible (w.r.t. the multiplexing-set-limit)."),
66                                             clEnumVal(WorkingSet, "choose multiplexing sets that share common input values."),
67                                             clEnumValEnd),
68                                           cl::init(Greedy),
69                                           cl::cat(MultiplexingOptions));
70
71namespace pablo {
72
73using TypeId = PabloAST::ClassTypeId;
74
75/** ------------------------------------------------------------------------------------------------------------- *
76 * @brief optimize
77 * @param function the function to optimize
78 ** ------------------------------------------------------------------------------------------------------------- */
79bool MultiplexingPass::optimize(PabloFunction & function) {
80
81    if (LLVM_UNLIKELY(Samples < 1)) {
82        return false;
83    }
84
85    PabloVerifier::verify(function, "pre-multiplexing");
86
87    Z3_config cfg = Z3_mk_config();
88    Z3_context ctx = Z3_mk_context_rc(cfg);
89    Z3_del_config(cfg);
90    Z3_solver solver = Z3_mk_solver(ctx);
91    Z3_solver_inc_ref(ctx, solver);
92
93    MultiplexingPass mp(function, Seed, ctx, solver);
94
95    mp.characterize(function);
96
97    Z3_solver_dec_ref(ctx, solver);
98    Z3_del_context(ctx);
99
100    PabloVerifier::verify(function, "post-multiplexing");
101
102    return true;
103}
104
105/** ------------------------------------------------------------------------------------------------------------- *
106 * @brief characterize
107 * @param function the function to optimize
108 ** ------------------------------------------------------------------------------------------------------------- */
109void MultiplexingPass::characterize(PabloFunction & function) {
110    // Map the constants and input variables
111    Z3_sort boolTy = Z3_mk_bool_sort(mContext);
112
113    Z3_ast F = Z3_mk_const(mContext, Z3_mk_int_symbol(mContext, 0), boolTy);
114    Z3_inc_ref(mContext, F);
115    add(PabloBlock::createZeroes(), F);
116
117    Z3_ast T = Z3_mk_const(mContext, Z3_mk_int_symbol(mContext, 1), boolTy);
118    Z3_inc_ref(mContext, T);
119    add(PabloBlock::createOnes(), T);
120
121    for (unsigned i = 0; i < function.getNumOfParameters(); ++i) {
122        make(function.getParameter(i));
123    }
124
125    characterize(function.getEntryBlock());
126}
127
128/** ------------------------------------------------------------------------------------------------------------- *
129 * @brief characterize
130 ** ------------------------------------------------------------------------------------------------------------- */
131void MultiplexingPass::characterize(PabloBlock * const block) {
132    Statement * end = initialize(block->front());
133    for (Statement * stmt : *block) {
134        if (LLVM_UNLIKELY(stmt == end)) {
135            Statement * const next = stmt->getNextNode();
136            multiplex(block);
137            if (isa<If>(stmt)) {
138                characterize(cast<If>(stmt)->getBody());
139            } else if (isa<While>(stmt)) {
140                for (const Next * var : cast<While>(stmt)->getVariants()) {
141                    Z3_inc_ref(mContext, get(var->getInitial()));
142                }
143                characterize(cast<While>(stmt)->getBody());
144                // since we cannot be certain that we'll always execute at least one iteration of a loop, we must
145                // assume that the variants could either be their initial value or their resulting value.
146                for (const Next * var : cast<While>(stmt)->getVariants()) {
147                    Z3_ast v0 = get(var->getInitial());
148                    Z3_ast & v1 = get(var);
149                    Z3_ast merge[2] = { v0, v1 };
150                    Z3_ast r = Z3_mk_or(mContext, 2, merge);
151                    Z3_inc_ref(mContext, r);
152                    Z3_dec_ref(mContext, v0);
153                    Z3_dec_ref(mContext, v1);
154                    v1 = r;
155                    assert (get(var) == r);
156                }
157            }
158            end = initialize(next);
159        } else {
160            characterize(stmt);
161        }
162    }
163    multiplex(block);
164}
165
166/** ------------------------------------------------------------------------------------------------------------- *
167 * @brief multiplex
168 ** ------------------------------------------------------------------------------------------------------------- */
169void MultiplexingPass::multiplex(PabloBlock * const block) {
170    if (generateCandidateSets()) {
171        selectMultiplexSetsGreedy();
172        eliminateSubsetConstraints();
173        multiplexSelectedSets(block);
174    }
175}
176
177/** ------------------------------------------------------------------------------------------------------------- *
178 * @brief equals
179 ** ------------------------------------------------------------------------------------------------------------- */
180inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
181    Z3_solver_push(mContext, mSolver);
182    Z3_ast test = Z3_mk_eq(mContext, a, b);
183    Z3_inc_ref(mContext, test);
184    Z3_solver_assert(mContext, mSolver, test);
185    const auto r = Z3_solver_check(mContext, mSolver);
186    Z3_dec_ref(mContext, test);
187    Z3_solver_pop(mContext, mSolver, 1);
188    return (r == Z3_L_TRUE);
189}
190
191/** ------------------------------------------------------------------------------------------------------------- *
192 * @brief handle_unexpected_statement
193 ** ------------------------------------------------------------------------------------------------------------- */
194static void handle_unexpected_statement(Statement * const stmt) {
195    std::string tmp;
196    raw_string_ostream err(tmp);
197    err << "Unexpected statement type: ";
198    PabloPrinter::print(stmt, err);
199    throw std::runtime_error(err.str());
200}
201
202/** ------------------------------------------------------------------------------------------------------------- *
203 * @brief characterize
204 ** ------------------------------------------------------------------------------------------------------------- */
205inline Z3_ast MultiplexingPass::characterize(Statement * const stmt) {
206
207    const size_t n = stmt->getNumOperands(); assert (n > 0);
208    Z3_ast operands[n] = {};
209    for (size_t i = 0; i < n; ++i) {
210        PabloAST * op = stmt->getOperand(i);
211        if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
212            operands[i] = get(op, true);
213        }
214    }
215
216    Z3_ast node = operands[0];
217    switch (stmt->getClassTypeId()) {
218        case TypeId::Assign:
219        case TypeId::Next:
220        case TypeId::AtEOF:
221        case TypeId::InFile:
222            node = operands[0]; break;
223        case TypeId::And:
224            node = Z3_mk_and(mContext, n, operands); break;
225        case TypeId::Or:
226            node = Z3_mk_or(mContext, n, operands); break;
227        case TypeId::Xor:
228            node = Z3_mk_xor(mContext, operands[0], operands[1]);
229            Z3_inc_ref(mContext, node);
230            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
231                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
232                Z3_inc_ref(mContext, temp);
233                Z3_dec_ref(mContext, node);
234                node = temp;
235            }
236            return add(stmt, node);
237        case TypeId::Not:
238            node = Z3_mk_not(mContext, node);
239            break;
240        case TypeId::Sel:
241            node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
242            break;
243        case TypeId::Advance:
244            return characterize(cast<Advance>(stmt), operands[0]);
245        case TypeId::ScanThru:
246            // ScanThru(c, m) := (c + m) ∧ ¬m. Thus we can conservatively represent this statement using the BDD
247            // for ¬m --- provided no derivative of this statement is negated in any fashion.
248        case TypeId::MatchStar:
249        case TypeId::Count:
250            return make(stmt);
251        default:
252            handle_unexpected_statement(stmt);
253    }
254    Z3_inc_ref(mContext, node);
255    return add(stmt, node);
256}
257
258
259/** ------------------------------------------------------------------------------------------------------------- *
260 * @brief characterize
261 ** ------------------------------------------------------------------------------------------------------------- */
262inline Z3_ast MultiplexingPass::characterize(Advance * const adv, Z3_ast Ik) {
263    const auto k = mAdvanceNegatedVariable.size();
264
265    assert (adv);
266    assert (mConstraintGraph[k] == adv);
267
268    bool unconstrained[k] = {};
269
270    Z3_solver_push(mContext, mSolver);
271
272    for (size_t i = 0; i < k; ++i) {
273
274        // Have we already proven that they are unconstrained by their subset relationship?
275        if (unconstrained[i]) continue;
276
277        // If these Advances are mutually exclusive, in the same scope, transitively independent, and shift their
278        // values by the same amount, we can safely multiplex them. Otherwise mark the constraint in the graph.
279        const Advance * const ithAdv = mConstraintGraph[i];
280        if (ithAdv->getOperand(1) == adv->getOperand(1)) {
281
282            Z3_ast Ii = get(ithAdv->getOperand(0));
283
284            // Is there any satisfying truth assignment? If not, these streams are mutually exclusive.
285
286            Z3_solver_push(mContext, mSolver);
287            Z3_ast conj[2] = { Ii, Ik };
288            Z3_ast IiIk = Z3_mk_and(mContext, 2, conj);
289            Z3_inc_ref(mContext, IiIk);
290            Z3_solver_assert(mContext, mSolver, IiIk);
291            if (Z3_solver_check(mContext, mSolver) == Z3_L_FALSE) {
292                // If Ai ∩ Ak = ∅ and Aj ⊂ Ai, Aj ∩ Ak = ∅.
293                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
294                    unconstrained[source(e, mSubsetGraph)] = true;
295                }
296                unconstrained[i] = true;
297
298            } else if (equals(Ii, IiIk)) {
299                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
300                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
301                // the same multiplexing set.
302                add_edge(i, k, mSubsetGraph);
303                // If Ai ⊂ Ak and Aj ⊂ Ai, Aj ⊂ Ak.
304                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
305                    const auto j = source(e, mSubsetGraph);
306                    add_edge(j, k, mSubsetGraph);
307                    unconstrained[j] = true;
308                }
309                unconstrained[i] = true;
310
311            } else if (equals(Ik, IiIk)) {
312                // If Ik = Ii ∩ Ik then Ik ⊆ Ii. Record this in the subset graph with the arc (k, i).
313                add_edge(k, i, mSubsetGraph);
314                // If Ak ⊂ Ai and Ai ⊂ Aj, Ak ⊂ Aj.
315                for (auto e : make_iterator_range(out_edges(i, mSubsetGraph))) {
316                    const auto j = target(e, mSubsetGraph);
317                    add_edge(k, j, mSubsetGraph);
318                    unconstrained[j] = true;
319                }
320                unconstrained[i] = true;
321            }
322
323            Z3_dec_ref(mContext, IiIk);
324            Z3_solver_pop(mContext, mSolver, 1);
325        }
326    }
327
328    Z3_solver_pop(mContext, mSolver, 1);
329
330    Z3_ast Ak0 = make(adv);
331    Z3_inc_ref(mContext, Ak0);
332    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
333    Z3_inc_ref(mContext, Nk);
334
335    Z3_ast vars[k + 1];
336    vars[0] = Ak0;
337
338    unsigned m = 1;
339    for (unsigned i = 0; i < k; ++i) {
340        if (unconstrained[i]) {
341            // This algorithm deems two streams mutually exclusive if and only if their conjuntion is a contradiction.
342            // To generate a contradiction when comparing Advances, the BDD of each Advance is represented by the conjunction of
343            // variables representing the k-th Advance and the negation of all variables for the Advances whose inputs are mutually
344            // exclusive with the k-th input.
345
346            // For example, if the input of the i-th Advance is mutually exclusive with the input of the j-th and k-th Advance, the
347            // BDD of the i-th Advance is Ai ∧ ¬Aj ∧ ¬Ak. Similarly, the j- and k-th Advance is Aj ∧ ¬Ai and Ak ∧ ¬Ai, respectively
348            // (assuming that the j-th and k-th Advance are not mutually exclusive.)
349
350            Z3_ast & Ai0 = get(mConstraintGraph[i]);
351            Z3_ast conj[2] = { Ai0, Nk };
352            Z3_ast Ai = Z3_mk_and(mContext, 2, conj);
353            Z3_inc_ref(mContext, Ai);
354            Z3_dec_ref(mContext, Ai0); // if this doesn't work, we'll have to scan from the output variables.
355            Ai0 = Ai;
356            assert (get(mConstraintGraph[i]) == Ai);
357
358            vars[m++] = mAdvanceNegatedVariable[i];
359
360            continue; // note: if these Advances aren't transtively independent, an edge will still exist.
361        }
362        add_edge(i, k, mConstraintGraph);
363    }
364    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
365    mAdvanceNegatedVariable.emplace_back(Nk);
366    Z3_ast Ak = Z3_mk_and(mContext, m, vars);
367    if (LLVM_UNLIKELY(Ak != Ak0)) {
368        Z3_inc_ref(mContext, Ak);
369        Z3_dec_ref(mContext, Ak0);
370    }
371    return add(adv, Ak);
372}
373
374/** ------------------------------------------------------------------------------------------------------------- *
375 * @brief initialize
376 ** ------------------------------------------------------------------------------------------------------------- */
377Statement * MultiplexingPass::initialize(Statement * const initial) {
378
379    // clean up any unneeded refs / characterizations.
380    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
381        const CharacterizationRef & r = std::get<1>(*i);
382        if (LLVM_UNLIKELY(std::get<1>(r) == 0)) {
383            Z3_dec_ref(mContext, std::get<0>(r));
384            auto j = i++;
385            mCharacterization.erase(j);
386        } else {
387            ++i;
388        }
389    }
390
391    for (Z3_ast var : mAdvanceNegatedVariable) {
392        Z3_dec_ref(mContext, var);
393    }
394    mAdvanceNegatedVariable.clear();
395
396    // Scan through and count all the advances and statements ...
397    unsigned statements = 0, advances = 0;
398    Statement * last = nullptr;
399    for (Statement * stmt = initial; stmt; stmt = stmt->getNextNode()) {
400        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
401            last = stmt;
402            break;
403        } else if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
404            ++advances;
405        }
406        ++statements;
407    }
408
409    flat_map<const PabloAST *, unsigned> M;
410    M.reserve(statements);
411    matrix<bool> G(statements, advances, false);
412    for (unsigned i = 0; i != advances; ++i) {
413        G(i, i) = true;
414    }
415
416    mConstraintGraph = ConstraintGraph(advances);
417    unsigned n = advances;
418    unsigned k = 0;
419    for (Statement * stmt = initial; stmt != last; stmt = stmt->getNextNode()) {
420        assert (!isa<If>(stmt) && !isa<While>(stmt));
421        unsigned u = 0;
422        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
423            mConstraintGraph[k] = cast<Advance>(stmt);
424            u = k++;
425        } else {
426            u = n++;
427        }
428        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
429            const PabloAST * const op = stmt->getOperand(i);
430            if (LLVM_LIKELY(isa<Statement>(op))) {
431                auto f = M.find(op);
432                if (f != M.end()) {
433                    const unsigned v = std::get<1>(*f);
434                    for (unsigned w = 0; w != k; ++w) {
435                        G(u, w) |= G(v, w);
436                    }
437                }
438            }
439        }
440        M.emplace(stmt, u);
441    }
442
443    assert (k == advances);
444
445    // Initialize the base constraint graph by transposing G and removing reflective loops
446    for (unsigned i = 0; i != advances; ++i) {
447        for (unsigned j = 0; j < i; ++j) {
448            if (G(i, j)) {
449                add_edge(j, i, mConstraintGraph);
450            }
451        }
452        for (unsigned j = i + 1; j < advances; ++j) {
453            if (G(i, j)) {
454                add_edge(j, i, mConstraintGraph);
455            }
456        }
457    }
458
459    mSubsetGraph = SubsetGraph(advances);
460    mAdvanceNegatedVariable.reserve(advances);
461
462    return last;
463}
464
465/** ------------------------------------------------------------------------------------------------------------- *
466 * @brief is_power_of_2
467 * @param n an integer
468 ** ------------------------------------------------------------------------------------------------------------- */
469static inline bool is_power_of_2(const size_t n) {
470    return ((n & (n - 1)) == 0);
471}
472
473/** ------------------------------------------------------------------------------------------------------------- *
474 * @brief generateCandidateSets
475 ** ------------------------------------------------------------------------------------------------------------- */
476bool MultiplexingPass::generateCandidateSets() {
477
478    const auto n = mAdvanceNegatedVariable.size();
479    if (n < 3) {
480        return false;
481    }
482    assert (num_vertices(mConstraintGraph) == n);
483
484    Constraints S;
485
486    ConstraintGraph::degree_size_type D[n];
487
488    mCandidateGraph = CandidateGraph(n);
489
490    for (unsigned r = Samples; r; --r) {
491
492        // Push all source nodes into the (initial) independent set S
493        for (const auto v : make_iterator_range(vertices(mConstraintGraph))) {
494            const auto d = in_degree(v, mConstraintGraph);
495            D[v] = d;
496            if (d == 0) {
497                S.push_back(v);
498            }
499        }
500
501        auto remaining = num_vertices(mConstraintGraph) - S.size();
502
503        for (;;) {
504            assert (S.size() > 0);
505            addCandidateSet(S);
506            if (LLVM_UNLIKELY(remaining == 0)) {
507                break;
508            }
509            for (;;) {
510                assert (S.size() > 0);
511                // Randomly choose a vertex in S and discard it.
512                const auto i = S.begin() + IntDistribution(0, S.size() - 1)(mRNG);
513                assert (i != S.end());
514                const auto u = *i;
515                S.erase(i);
516                bool checkCandidate = false;
517                for (auto e : make_iterator_range(out_edges(u, mConstraintGraph))) {
518                    const auto v = target(e, mConstraintGraph);
519                    assert ("Constraint set degree subtraction error!" && (D[v] != 0));
520                    if ((--D[v]) == 0) {
521                        assert ("Error v is already in S!" && std::count(S.begin(), S.end(), v) == 0);
522                        S.push_back(v);
523                        assert (remaining != 0);
524                        --remaining;
525                        if (LLVM_LIKELY(S.size() >= 3)) {
526                            checkCandidate = true;
527                        }
528                    }
529                }
530                if (checkCandidate || LLVM_UNLIKELY(remaining == 0)) {
531                    break;
532                }
533            }
534        }
535
536        S.clear();
537    }
538
539    return num_vertices(mCandidateGraph) > num_vertices(mConstraintGraph);
540}
541
542/** ------------------------------------------------------------------------------------------------------------- *
543 * @brief choose
544 *
545 * Compute n choose k
546 ** ------------------------------------------------------------------------------------------------------------- */
547__attribute__ ((const)) inline unsigned long choose(const unsigned n, const unsigned k) {
548    if (n < k)
549        return 0;
550    if (n == k || k == 0)
551        return 1;
552    unsigned long delta = k;
553    unsigned long max = n - k;
554    if (delta < max) {
555        std::swap(delta, max);
556    }
557    unsigned long result = delta + 1;
558    for (unsigned i = 2; i <= max; ++i) {
559        result = (result * (delta + i)) / i;
560    }
561    return result;
562}
563
564/** ------------------------------------------------------------------------------------------------------------- *
565 * @brief select
566 *
567 * James McCaffrey's algorithm for "Generating the mth Lexicographical Element of a Mathematical Combination"
568 ** ------------------------------------------------------------------------------------------------------------- */
569void MultiplexingPass::selectCandidateSet(const unsigned n, const unsigned k, const unsigned m, const Constraints & S, ConstraintVertex * const element) {
570    unsigned long a = n;
571    unsigned long b = k;
572    unsigned long x = (choose(n, k) - 1) - m;
573    for (unsigned i = 0; i != k; ++i) {
574        unsigned long y = 0;
575        while ((y = choose(--a, b)) > x);
576        x = x - y;
577        b = b - 1;
578        element[i] = S[(n - 1) - a];
579    }
580}
581
582/** ------------------------------------------------------------------------------------------------------------- *
583 * @brief updateCandidateSet
584 ** ------------------------------------------------------------------------------------------------------------- */
585void MultiplexingPass::updateCandidateSet(ConstraintVertex * const begin, ConstraintVertex * const end) {
586
587    using Vertex = CandidateGraph::vertex_descriptor;
588
589    const auto n = num_vertices(mConstraintGraph);
590    const auto m = num_vertices(mCandidateGraph);
591    const auto d = end - begin;
592
593    std::sort(begin, end);
594
595    Vertex u = 0;
596
597    for (Vertex i = n; i != m; ++i) {
598
599        if (LLVM_UNLIKELY(degree(i, mCandidateGraph) == 0)) {
600            u = i;
601            continue;
602        }
603
604        const auto adj = adjacent_vertices(i, mCandidateGraph);
605        if (degree(i, mCandidateGraph) < d) {
606            // set_i can only be a subset of the new set
607            if (LLVM_UNLIKELY(std::includes(begin, end, adj.first, adj.second))) {
608                clear_vertex(i, mCandidateGraph);
609                u = i;
610            }
611        } else if (LLVM_UNLIKELY(std::includes(adj.first, adj.second, begin, end))) {
612            // the new set is a subset of set_i; discard it.
613            return;
614        }
615
616    }
617
618    if (LLVM_LIKELY(u == 0)) { // n must be at least 3 so u is 0 if and only if we're not reusing a set vertex.
619        u = add_vertex(mCandidateGraph);
620    }
621
622    for (ConstraintVertex * i = begin; i != end; ++i) {
623        add_edge(u, *i, mCandidateGraph);
624    }
625
626}
627
628/** ------------------------------------------------------------------------------------------------------------- *
629 * @brief addCandidateSet
630 * @param S an independent set
631 ** ------------------------------------------------------------------------------------------------------------- */
632inline void MultiplexingPass::addCandidateSet(const Constraints & S) {
633    if (S.size() >= 3) {
634        const unsigned setLimit = SetLimit;
635        if (S.size() <= setLimit) {
636            ConstraintVertex E[S.size()];
637            std::copy(S.cbegin(), S.cend(), E);
638            updateCandidateSet(E, E + S.size());
639        } else {
640            assert (setLimit > 0);
641            ConstraintVertex E[setLimit];
642            const auto max = choose(S.size(), setLimit);
643            if (LLVM_UNLIKELY(max <= SelectionLimit)) {
644                for (unsigned i = 0; i != max; ++i) {
645                    selectCandidateSet(S.size(), setLimit, i, S, E);
646                    updateCandidateSet(E, E + setLimit);
647                }
648            } else { // take m random samples
649                for (unsigned i = 0; i != SelectionLimit; ++i) {
650                    selectCandidateSet(S.size(), setLimit, mRNG() % max, S, E);
651                    updateCandidateSet(E, E + setLimit);
652                }
653            }
654        }
655    }
656}
657
658/** ------------------------------------------------------------------------------------------------------------- *
659 * @brief log2_plus_one
660 ** ------------------------------------------------------------------------------------------------------------- */
661static inline size_t log2_plus_one(const size_t n) {
662    return std::log2<size_t>(n) + 1;
663}
664
665/** ------------------------------------------------------------------------------------------------------------- *
666 * @brief selectMultiplexSetsGreedy
667 *
668 * This algorithm is simply computes a greedy set cover. We want an exact max-weight set cover but can generate new
669 * sets by taking a subset of any existing set. With a few modifications, the greedy approach seems to work well
670 * enough but can be shown to produce a suboptimal solution if there are three candidate sets labelled A, B and C,
671 * in which A ∩ B = ∅, |A| ≀ |B| < |C|, and C ⊂ (A ∪ B).
672 ** ------------------------------------------------------------------------------------------------------------- */
673void MultiplexingPass::selectMultiplexSetsGreedy() {
674
675    using AdjIterator = graph_traits<CandidateGraph>::adjacency_iterator;
676    using degree_t = CandidateGraph::degree_size_type;
677    using vertex_t = CandidateGraph::vertex_descriptor;
678
679    const size_t m = num_vertices(mConstraintGraph);
680    const size_t n = num_vertices(mCandidateGraph) - m;
681
682    bool chosen[n] = {};
683
684    for (;;) {
685
686        // Choose the set with the greatest number of vertices not already included in some other set.
687        vertex_t u = 0;
688        degree_t w = 0;
689        for (vertex_t i = 0; i != n; ++i) {
690            if (chosen[i]) continue;
691            const auto t = i + m;
692            degree_t r = degree(t, mCandidateGraph);
693            if (LLVM_LIKELY(r >= 3)) { // if this set has at least 3 elements.
694                r *= r;
695                if (w < r) {
696                    u = t;
697                    w = r;
698                }
699            } else if (r) {
700                clear_vertex(t, mCandidateGraph);
701            }
702        }
703
704        // Multiplexing requires 3 or more elements; if no set contains at least 3, abort.
705        if (LLVM_UNLIKELY(w == 0)) {
706            break;
707        }
708
709        chosen[u - m] = true;
710
711        // If this contains 2^n elements for any n, discard the member that is most likely to be added
712        // to some future set.
713        if (LLVM_UNLIKELY(is_power_of_2(degree(u, mCandidateGraph)))) {
714            vertex_t x = 0;
715            degree_t w = 0;
716            for (const auto v : make_iterator_range(adjacent_vertices(u, mCandidateGraph))) {
717                if (degree(v, mCandidateGraph) > w) {
718                    x = v;
719                    w = degree(v, mCandidateGraph);
720                }
721            }
722            remove_edge(u, x, mCandidateGraph);
723        }
724
725        AdjIterator begin, end;
726        std::tie(begin, end) = adjacent_vertices(u, mCandidateGraph);
727        for (auto vi = begin; vi != end; ) {
728            const auto v = *vi++;
729            clear_vertex(v, mCandidateGraph);
730            add_edge(v, u, mCandidateGraph);
731        }
732
733    }
734
735    #ifndef NDEBUG
736    for (unsigned i = 0; i != m; ++i) {
737        assert (degree(i, mCandidateGraph) <= 1);
738    }
739    for (unsigned i = m; i != (m + n); ++i) {
740        assert (degree(i, mCandidateGraph) == 0 || degree(i, mCandidateGraph) >= 3);
741    }
742    #endif
743}
744
745/** ------------------------------------------------------------------------------------------------------------- *
746 * @brief selectMultiplexSetsWorkingSet
747 ** ------------------------------------------------------------------------------------------------------------- */
748void MultiplexingPass::selectMultiplexSetsWorkingSet() {
749
750    // The inputs to each Advance must be different; otherwise the SimplificationPass would consider all but
751    // one of the Advances redundant. However, if the input is short lived, we can ignore it in favour of its
752    // operands, which *may* be shared amongst more than one of the Advances (or may be short lived themselves,
753    // in which we can consider their operands instead.) Ideally, if we can keep the set of live values small,
754    // we may be able to reduce register pressure.
755
756
757}
758
759/** ------------------------------------------------------------------------------------------------------------- *
760 * @brief eliminateSubsetConstraints
761 ** ------------------------------------------------------------------------------------------------------------- */
762void MultiplexingPass::eliminateSubsetConstraints() {
763    using SubsetEdgeIterator = graph_traits<SubsetGraph>::edge_iterator;
764    // If Ai ⊂ Aj then the subset graph will contain the arc (i, j). Remove all arcs corresponding to vertices
765    // that are not elements of the same multiplexing set.
766    SubsetEdgeIterator ei, ei_end, ei_next;
767    std::tie(ei, ei_end) = edges(mSubsetGraph);
768    for (ei_next = ei; ei != ei_end; ei = ei_next) {
769        ++ei_next;
770        const auto u = source(*ei, mSubsetGraph);
771        const auto v = target(*ei, mSubsetGraph);
772        if (degree(u, mCandidateGraph) != 0 && degree(v, mCandidateGraph) != 0) {
773            assert (degree(u, mCandidateGraph) == 1);
774            assert (degree(v, mCandidateGraph) == 1);
775            const auto su = *(adjacent_vertices(u, mCandidateGraph).first);
776            const auto sv = *(adjacent_vertices(v, mCandidateGraph).first);
777            if (su == sv) {
778                continue;
779            }
780        }
781        remove_edge(*ei, mSubsetGraph);
782    }
783
784    if (num_edges(mSubsetGraph) != 0) {
785
786        // At least one subset constraint exists; perform a transitive reduction on the graph to ensure that
787        // we perform the minimum number of AST modifications for the selected multiplexing sets.
788
789        doTransitiveReductionOfSubsetGraph();
790
791        // Afterwards modify the AST to ensure that multiplexing algorithm can ignore any subset constraints
792        for (auto e : make_iterator_range(edges(mSubsetGraph))) {
793            Advance * const adv1 = mConstraintGraph[source(e, mSubsetGraph)];
794            Advance * const adv2 = mConstraintGraph[target(e, mSubsetGraph)];
795            assert (adv1->getParent() == adv2->getParent());
796            PabloBlock * const pb = adv1->getParent();
797            pb->setInsertPoint(adv2->getPrevNode());
798            adv2->setOperand(0, pb->createAnd(adv2->getOperand(0), pb->createNot(adv1->getOperand(0)), "subset"));
799            pb->setInsertPoint(adv2);
800            adv2->replaceAllUsesWith(pb->createOr(adv1, adv2, "merge"));
801        }
802
803    }
804}
805
806/** ------------------------------------------------------------------------------------------------------------- *
807 * @brief dominates
808 *
809 * does Statement a dominate Statement b?
810 ** ------------------------------------------------------------------------------------------------------------- */
811bool dominates(const Statement * const a, const Statement * const b) {
812
813    if (LLVM_UNLIKELY(b == nullptr)) {
814        return true;
815    } else if (LLVM_UNLIKELY(a == nullptr)) {
816        return false;
817    }
818
819    assert (a->getParent());
820    assert (b->getParent());
821
822    const PabloBlock * const parent = a->getParent();
823    if (LLVM_LIKELY(parent == b->getParent())) {
824        for (const Statement * t : *parent) {
825            if (t == a) {
826                return true;
827            } else if (t == b) {
828                break;
829            }
830        }
831        return false;
832    } else {
833        const PabloBlock * block = b->getParent();
834        for (;;) {
835            Statement * br = block->getBranch();
836            if (br == nullptr) {
837                return dominates(parent->getBranch(), b);
838            }
839            block = br->getParent();
840            if (block == parent) {
841                return dominates(a, br);
842            }
843        }
844    }
845}
846
847/** ------------------------------------------------------------------------------------------------------------- *
848 * @brief multiplexSelectedSets
849 ** ------------------------------------------------------------------------------------------------------------- */
850inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block) {
851
852
853//    Z3_config cfg = Z3_mk_config();
854//    Z3_context ctx = Z3_mk_context_rc(cfg);
855//    Z3_del_config(cfg);
856//    Z3_solver solver = Z3_mk_solver(ctx);
857//    Z3_solver_inc_ref(ctx, solver);
858
859
860    const auto first_set = num_vertices(mConstraintGraph);
861    const auto last_set = num_vertices(mCandidateGraph);
862    for (auto idx = first_set; idx != last_set; ++idx) {
863        const size_t n = degree(idx, mCandidateGraph);
864        assert (n == 0 || n > 2);
865        if (n) {
866            const size_t m = log2_plus_one(n);
867            Advance * input[n];
868            PabloAST * muxed[m];
869            PabloAST * muxed_n[m];
870            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
871            unsigned i = 0;
872            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) { // orderMultiplexSet(idx)) {
873                input[i++] = mConstraintGraph[u];
874            }
875            Advance * const adv = input[0];
876            assert (block == adv->getParent());
877
878            circular_buffer<PabloAST *> Q(n);
879
880            PabloBuilder builder(block);
881            block->setInsertPoint(nullptr);
882            /// Perform n-to-m Multiplexing           
883            for (size_t j = 0; j != m; ++j) {               
884                std::ostringstream prefix;
885                prefix << "mux" << n << "to" << m << '.' << (j);
886                assert (Q.empty());
887                for (size_t i = 0; i != n; ++i) {
888                    if (((i + 1) & (1UL << j)) != 0) {
889                        Q.push_back(input[i]->getOperand(0));
890                    }
891                }
892                while (Q.size() > 1) {
893                    PabloAST * a = Q.front(); Q.pop_front();
894                    PabloAST * b = Q.front(); Q.pop_front();
895                    Q.push_back(builder.createOr(a, b));
896                }
897                PabloAST * const muxing =  Q.front(); Q.clear();
898                muxed[j] = builder.createAdvance(muxing, adv->getOperand(1), prefix.str());
899                muxed_n[j] = builder.createNot(muxed[j]);
900            }
901            /// Perform m-to-n Demultiplexing
902            block->setInsertPoint(block->back());
903            for (size_t i = 0; i != n; ++i) {
904                // Construct the demuxed values and replaces all the users of the original advances with them.
905                assert (Q.empty());
906                for (size_t j = 0; j != m; ++j) {
907                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
908                }
909                while (Q.size() > 1) {
910                    PabloAST * const a = Q.front(); Q.pop_front();
911                    PabloAST * const b = Q.front(); Q.pop_front();
912                    Q.push_back(builder.createAnd(a, b));
913                }
914                PabloAST * const demuxed =  Q.front(); Q.clear();
915                input[i]->replaceWith(demuxed, true, true);
916            }
917        }
918    }
919
920    flat_set<PabloAST *> encountered;
921    for (Statement * stmt = block->front(); stmt; ) {
922
923        assert (stmt->getParent() == block);
924        Statement * const next = stmt->getNextNode();
925
926        bool unmoved = true;
927        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
928            PabloAST * const op = stmt->getOperand(i);
929            if (isa<Statement>(op)) {
930                Statement * ip = cast<Statement>(op);
931                if (ip->getParent() != block) {
932                    // If we haven't already encountered the Assign or Next node, it must come from a If or
933                    // While node that we haven't processed yet. Scan ahead and try to locate it.
934                    if (isa<Assign>(op)) {
935                        for (PabloAST * user : cast<Assign>(op)->users()) {
936                            if (isa<If>(user) && cast<If>(user)->getParent() == block) {
937                                const auto & defs = cast<If>(user)->getDefined();
938                                if (LLVM_LIKELY(std::find(defs.begin(), defs.end(), op) != defs.end())) {
939                                    ip = cast<If>(user);
940                                    break;
941                                }
942                            }
943                        }
944                    } else if (isa<Next>(op)) {
945                        for (PabloAST * user : cast<Next>(op)->users()) {
946                            if (isa<While>(user) && cast<While>(user)->getParent() == block) {
947                                const auto & vars = cast<While>(user)->getVariants();
948                                if (LLVM_LIKELY(std::find(vars.begin(), vars.end(), op) != vars.end())) {
949                                    ip = cast<While>(user);
950                                    break;
951                                }
952                            }
953                        }
954                    }
955                }
956                if (encountered.count(ip) == 0) {
957                    if (dominates(ip, stmt)) {
958                        encountered.insert(ip);
959                    } else {
960                        assert (ip->getParent() == block);
961                        stmt->insertAfter(ip);
962                        unmoved = false;
963                        break;
964                    }
965                }
966            }
967        }
968        if (unmoved) {
969            encountered.insert(stmt);
970        }
971        stmt = next;
972    }
973
974//    Z3_solver_dec_ref(ctx, solver);
975//    Z3_del_context(ctx);
976
977
978}
979
980/** ------------------------------------------------------------------------------------------------------------- *
981 * @brief doTransitiveReductionOfSubsetGraph
982 ** ------------------------------------------------------------------------------------------------------------- */
983void MultiplexingPass::doTransitiveReductionOfSubsetGraph() {
984    std::vector<SubsetGraph::vertex_descriptor> Q;
985    for (auto u : make_iterator_range(vertices(mSubsetGraph))) {
986        if (in_degree(u, mSubsetGraph) == 0 && out_degree(u, mSubsetGraph) != 0) {
987            Q.push_back(u);
988        }
989    }
990    flat_set<SubsetGraph::vertex_descriptor> targets;
991    flat_set<SubsetGraph::vertex_descriptor> visited;
992    do {
993        const auto u = Q.back(); Q.pop_back();
994        for (auto ei : make_iterator_range(out_edges(u, mSubsetGraph))) {
995            for (auto ej : make_iterator_range(out_edges(target(ei, mSubsetGraph), mSubsetGraph))) {
996                targets.insert(target(ej, mSubsetGraph));
997            }
998        }
999        for (auto v : targets) {
1000            remove_edge(u, v, mSubsetGraph);
1001        }
1002        for (auto e : make_iterator_range(out_edges(u, mSubsetGraph))) {
1003            const auto v = target(e, mSubsetGraph);
1004            if (visited.insert(v).second) {
1005                Q.push_back(v);
1006            }
1007        }
1008    } while (!Q.empty());
1009}
1010
1011/** ------------------------------------------------------------------------------------------------------------- *
1012 * @brief get
1013 ** ------------------------------------------------------------------------------------------------------------- */
1014inline Z3_ast & MultiplexingPass::get(const PabloAST * const expr, const bool deref) {
1015    assert (expr);
1016    auto f = mCharacterization.find(expr);
1017    assert (f != mCharacterization.end());
1018    auto & val = f->second;
1019    if (deref) {
1020        unsigned & refs = std::get<1>(val);
1021        assert (refs > 0);
1022        --refs;
1023    }
1024    return std::get<0>(val);
1025}
1026
1027/** ------------------------------------------------------------------------------------------------------------- *
1028 * @brief make
1029 ** ------------------------------------------------------------------------------------------------------------- */
1030inline Z3_ast MultiplexingPass::make(const PabloAST * const expr) {
1031    assert (expr);
1032    Z3_sort ty = Z3_mk_bool_sort(mContext);
1033    Z3_symbol s = Z3_mk_string_symbol(mContext, nullptr); // expr->getName()->to_string().c_str()
1034    Z3_ast node = Z3_mk_const(mContext, s, ty);
1035    Z3_inc_ref(mContext, node);
1036    return add(expr, node);
1037}
1038
1039/** ------------------------------------------------------------------------------------------------------------- *
1040 * @brief add
1041 ** ------------------------------------------------------------------------------------------------------------- */
1042inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node) {   
1043    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, expr->getNumUses())));
1044    return node;
1045}
1046
1047/** ------------------------------------------------------------------------------------------------------------- *
1048 * @brief constructor
1049 ** ------------------------------------------------------------------------------------------------------------- */
1050inline MultiplexingPass::MultiplexingPass(PabloFunction & f, const RNG::result_type seed, Z3_context context, Z3_solver solver)
1051: mContext(context)
1052, mSolver(solver)
1053, mFunction(f)
1054, mRNG(seed)
1055, mConstraintGraph(0)
1056{
1057
1058}
1059
1060
1061} // end of namespace pablo
Note: See TracBrowser for help on using the repository browser.