source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp @ 5233

Last change on this file since 5233 was 5230, checked in by nmedfort, 3 years ago

Multi-threading support for PabloAST / PabloCompiler?. Requires unique LLVM Context / Module for each thread.

File size: 41.5 KB
Line 
1#include "pablo_automultiplexing.hpp"
2#include <pablo/builder.hpp>
3#include <pablo/printer_pablos.h>
4#include <boost/container/flat_set.hpp>
5#include <boost/container/flat_map.hpp>
6#include <boost/numeric/ublas/matrix.hpp>
7#include <boost/circular_buffer.hpp>
8#include <boost/graph/topological_sort.hpp>
9#include <boost/range/iterator_range.hpp>
10#include <pablo/analysis/pabloverifier.hpp>
11#include <pablo/optimizers/pablo_simplifier.hpp>
12#include <pablo/builder.hpp>
13#include <stack>
14#include <queue>
15#include <unordered_set>
16#include <functional>
17#include <llvm/Support/CommandLine.h>
18#include "maxsat.hpp"
19
20using namespace llvm;
21using namespace boost;
22using namespace boost::container;
23using namespace boost::numeric::ublas;
24
25static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
26
27static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(100),
28                                        cl::desc("maximum sequence distance to consider for candidate set."),
29                                        cl::cat(MultiplexingOptions));
30
31
32namespace pablo {
33
34using TypeId = PabloAST::ClassTypeId;
35
36/** ------------------------------------------------------------------------------------------------------------- *
37 * @brief optimize
38 * @param function the function to optimize
39 ** ------------------------------------------------------------------------------------------------------------- */
40bool MultiplexingPass::optimize(PabloFunction & function) {
41
42
43    Z3_config cfg = Z3_mk_config();
44    Z3_context ctx = Z3_mk_context_rc(cfg);
45    Z3_del_config(cfg);
46    Z3_solver solver = Z3_mk_solver(ctx);
47    Z3_solver_inc_ref(ctx, solver);
48
49    MultiplexingPass mp(function, ctx, solver);
50
51    mp.optimize();
52
53    Z3_solver_dec_ref(ctx, solver);
54    Z3_del_context(ctx);
55
56    #ifndef NDEBUG
57    PabloVerifier::verify(function, "post-multiplexing");
58    #endif
59
60    Simplifier::optimize(function);
61
62    return true;
63}
64
65/** ------------------------------------------------------------------------------------------------------------- *
66 * @brief characterize
67 * @param function the function to optimize
68 ** ------------------------------------------------------------------------------------------------------------- */
69void MultiplexingPass::optimize() {
70    // Map the constants and input variables
71    PabloBlock * entryBlock = mFunction.getEntryBlock();
72    add(entryBlock->createZeroes(), Z3_mk_false(mContext), -1);
73    add(entryBlock->createOnes(), Z3_mk_true(mContext), -1);
74    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
75        add(mFunction.getParameter(i), makeVar(), -1);
76    }
77    optimize(mFunction.getEntryBlock());
78}
79
80/** ------------------------------------------------------------------------------------------------------------- *
81 * @brief characterize
82 ** ------------------------------------------------------------------------------------------------------------- */
83void MultiplexingPass::optimize(PabloBlock * const block) {
84    Statement * begin = block->front();
85    Statement * end = initialize(begin);
86    for (Statement * stmt : *block) {
87        if (LLVM_UNLIKELY(stmt == end)) {
88            Statement * const next = stmt->getNextNode();
89            multiplex(block, begin, stmt);
90            if (isa<If>(stmt)) {
91                optimize(cast<If>(stmt)->getBody());
92            } else if (isa<While>(stmt)) {
93                for (const Var * var : cast<While>(stmt)->getEscaped()) {
94                    Z3_inc_ref(mContext, get(var));
95                }
96                optimize(cast<While>(stmt)->getBody());
97                // since we cannot be certain that we'll always execute at least one iteration of a loop, we must
98                // assume that the variants could either be their initial or resulting value.
99                for (const Var * var : cast<While>(stmt)->getEscaped()) {
100                    Z3_ast v0 = get(var);
101                    Z3_ast & v1 = get(var);
102                    Z3_ast merge[2] = { v0, v1 };
103                    Z3_ast r = Z3_mk_or(mContext, 2, merge);
104                    Z3_inc_ref(mContext, r);
105                    Z3_dec_ref(mContext, v0);
106                    Z3_dec_ref(mContext, v1);
107                    v1 = r;
108                    assert (get(var) == r);
109                }
110            }
111            end = initialize(begin = next);
112        } else {
113            characterize(stmt);
114        }
115    }
116    multiplex(block, begin, nullptr);
117}
118
119/** ------------------------------------------------------------------------------------------------------------- *
120 * @brief multiplex
121 ** ------------------------------------------------------------------------------------------------------------- */
122void MultiplexingPass::multiplex(PabloBlock * const block, Statement * const begin, Statement * const end) {
123    if (generateCandidateSets(begin, end)) {
124        selectMultiplexSetsGreedy();
125        // eliminateSubsetConstraints();
126        multiplexSelectedSets(block, begin, end);
127    }
128}
129
130/** ------------------------------------------------------------------------------------------------------------- *
131 * @brief equals
132 ** ------------------------------------------------------------------------------------------------------------- */
133inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
134    Z3_solver_push(mContext, mSolver);
135    Z3_ast test = Z3_mk_eq(mContext, a, b);
136    Z3_inc_ref(mContext, test);
137    Z3_solver_assert(mContext, mSolver, test);
138    const auto r = Z3_solver_check(mContext, mSolver);
139    Z3_dec_ref(mContext, test);
140    Z3_solver_pop(mContext, mSolver, 1);
141    return (r == Z3_L_TRUE);
142}
143
144/** ------------------------------------------------------------------------------------------------------------- *
145 * @brief handle_unexpected_statement
146 ** ------------------------------------------------------------------------------------------------------------- */
147static void handle_unexpected_statement(const Statement * const stmt) {
148    std::string tmp;
149    raw_string_ostream err(tmp);
150    err << "Unexpected statement type: ";
151    PabloPrinter::print(stmt, err);
152    throw std::runtime_error(err.str());
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief characterize
157 ** ------------------------------------------------------------------------------------------------------------- */
158Z3_ast MultiplexingPass::characterize(const Statement * const stmt, const bool deref) {
159
160    const size_t n = stmt->getNumOperands(); assert (n > 0);
161    Z3_ast operands[n];
162    for (size_t i = 0; i < n; ++i) {
163        PabloAST * op = stmt->getOperand(i);
164        if (LLVM_UNLIKELY(isa<Integer>(op) || isa<String>(op))) {
165            continue;
166        }
167        operands[i] = get(op, deref);
168    }
169
170    Z3_ast node = operands[0];
171    switch (stmt->getClassTypeId()) {
172        case TypeId::Assign:
173        case TypeId::AtEOF:
174        case TypeId::InFile:
175            node = operands[0]; break;
176        case TypeId::And:
177            node = Z3_mk_and(mContext, n, operands); break;
178        case TypeId::Or:
179            node = Z3_mk_or(mContext, n, operands); break;
180        case TypeId::Xor:
181            node = Z3_mk_xor(mContext, operands[0], operands[1]);
182            Z3_inc_ref(mContext, node);
183            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
184                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
185                Z3_inc_ref(mContext, temp);
186                Z3_dec_ref(mContext, node);
187                node = temp;
188            }
189            return add(stmt, node, stmt->getNumUses());
190        case TypeId::Not:
191            node = Z3_mk_not(mContext, node);
192            break;
193        case TypeId::Sel:
194            node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
195            break;
196        case TypeId::Advance:
197            return characterize(cast<Advance>(stmt), operands[0]);
198        case TypeId::ScanThru:
199            // ScanThru(c, m) := (c + m) ∧ ¬m. Thus we can conservatively represent this statement using the BDD
200            // for ¬m --- provided no derivative of this statement is negated in any fashion.
201        case TypeId::MatchStar:
202        case TypeId::Count:
203            node = makeVar();
204            break;
205        default:
206            handle_unexpected_statement(stmt);
207    }
208    Z3_inc_ref(mContext, node);
209    return add(stmt, node, stmt->getNumUses());
210}
211
212
213/** ------------------------------------------------------------------------------------------------------------- *
214 * @brief characterize
215 ** ------------------------------------------------------------------------------------------------------------- */
216inline Z3_ast MultiplexingPass::characterize(const Advance * const adv, Z3_ast Ik) {
217    const auto k = mNegatedAdvance.size();
218
219    assert (adv);
220    assert (mConstraintGraph[k] == adv);
221    std::vector<bool> unconstrained(k);
222
223    Z3_solver_push(mContext, mSolver);
224
225    for (size_t i = 0; i < k; ++i) {
226
227        // Have we already proven that they are unconstrained by their subset relationship?
228        if (unconstrained[i]) continue;
229
230        // If these Advances are mutually exclusive, in the same scope, transitively independent, and shift their
231        // values by the same amount, we can safely multiplex them. Otherwise mark the constraint in the graph.
232        const Advance * const ithAdv = mConstraintGraph[i];
233        if (ithAdv->getOperand(1) == adv->getOperand(1)) {
234
235            Z3_ast Ii = get(ithAdv->getOperand(0));
236
237            // Is there any satisfying truth assignment? If not, these streams are mutually exclusive.
238
239            Z3_solver_push(mContext, mSolver);
240            Z3_ast conj[2] = { Ii, Ik };
241            Z3_ast IiIk = Z3_mk_and(mContext, 2, conj);
242            Z3_inc_ref(mContext, IiIk);
243            Z3_solver_assert(mContext, mSolver, IiIk);
244            if (Z3_solver_check(mContext, mSolver) == Z3_L_FALSE) {
245                // If Ai ∩ Ak = ∅ and Aj ⊂ Ai, Aj ∩ Ak = ∅.
246//                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
247//                    unconstrained[source(e, mSubsetGraph)] = true;
248//                }
249                unconstrained[i] = true;
250            }/* else if (equals(Ii, IiIk)) {
251                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
252                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
253                // the same multiplexing set.
254                add_edge(i, k, mSubsetGraph);
255                // If Ai ⊂ Ak and Aj ⊂ Ai, Aj ⊂ Ak.
256                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
257                    const auto j = source(e, mSubsetGraph);
258                    add_edge(j, k, mSubsetGraph);
259                    unconstrained[j] = true;
260                }
261                unconstrained[i] = true;
262            } else if (equals(Ik, IiIk)) {
263                // If Ik = Ii ∩ Ik then Ik ⊆ Ii. Record this in the subset graph with the arc (k, i).
264                add_edge(k, i, mSubsetGraph);
265                // If Ak ⊂ Ai and Ai ⊂ Aj, Ak ⊂ Aj.
266                for (auto e : make_iterator_range(out_edges(i, mSubsetGraph))) {
267                    const auto j = target(e, mSubsetGraph);
268                    add_edge(k, j, mSubsetGraph);
269                    unconstrained[j] = true;
270                }
271                unconstrained[i] = true;
272            }*/
273            Z3_dec_ref(mContext, IiIk);
274            Z3_solver_pop(mContext, mSolver, 1);
275        }
276    }
277
278    Z3_solver_pop(mContext, mSolver, 1);
279
280    Z3_ast Ak0 = makeVar();
281    Z3_inc_ref(mContext, Ak0);
282    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
283    Z3_inc_ref(mContext, Nk);
284
285    Z3_ast vars[k + 1];
286    vars[0] = Ak0;
287
288    unsigned m = 1;
289    for (unsigned i = 0; i < k; ++i) {
290        if (unconstrained[i]) {
291            // This algorithm deems two streams mutually exclusive if and only if their conjuntion is a contradiction.
292            // To generate a contradiction when comparing Advances, the BDD of each Advance is represented by the conjunction of
293            // variables representing the k-th Advance and the negation of all variables for the Advances whose inputs are mutually
294            // exclusive with the k-th input.
295
296            // For example, if the input of the i-th Advance is mutually exclusive with the input of the j-th and k-th Advance, the
297            // BDD of the i-th Advance is Ai ∧ ¬Aj ∧ ¬Ak. Similarly, the j- and k-th Advance is Aj ∧ ¬Ai and Ak ∧ ¬Ai, respectively
298            // (assuming that the j-th and k-th Advance are not mutually exclusive.)
299
300            Z3_ast & Ai0 = get(mConstraintGraph[i]);
301            Z3_ast conj[2] = { Ai0, Nk };
302            Z3_ast Ai = Z3_mk_and(mContext, 2, conj);
303            Z3_inc_ref(mContext, Ai);
304            Z3_dec_ref(mContext, Ai0);
305            Ai0 = Ai;
306            assert (get(mConstraintGraph[i]) == Ai);
307
308            vars[m++] = mNegatedAdvance[i];
309
310            continue; // note: if these Advances are transitively dependent, an edge will still exist.
311        }
312        const auto ei = add_edge(i, k, mConstraintGraph);
313        // if this is not a new edge, it must have a dependency constraint.
314        if (ei.second) {
315            mConstraintGraph[ei.first] = ConstraintType::Inclusive;
316        }
317    }
318    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
319    mNegatedAdvance.emplace_back(Nk);
320    Z3_ast Ak = Z3_mk_and(mContext, m, vars);
321    if (LLVM_UNLIKELY(Ak != Ak0)) {
322        Z3_inc_ref(mContext, Ak);
323        Z3_dec_ref(mContext, Ak0);
324    }
325    return add(adv, Ak, -1);
326}
327
328/** ------------------------------------------------------------------------------------------------------------- *
329 * @brief initialize
330 ** ------------------------------------------------------------------------------------------------------------- */
331Statement * MultiplexingPass::initialize(Statement * const initial) {
332
333    // clean up any unneeded refs / characterizations.
334    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
335        const auto ref = i->second;
336        auto next = i; ++next;
337        if (LLVM_UNLIKELY(ref.second == 0)) {
338            assert (isa<Statement>(i->first));
339            Z3_dec_ref(mContext, ref.first);
340            mCharacterization.erase(i);
341        }
342        i = next;
343    }
344    for (Z3_ast var : mNegatedAdvance) {
345        Z3_dec_ref(mContext, var);
346    }
347    mNegatedAdvance.clear();
348
349    // Scan through and count all the advances and statements ...
350    unsigned statements = 0, advances = 0;
351    Statement * last = nullptr;
352    for (Statement * stmt = initial; stmt; stmt = stmt->getNextNode()) {
353        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
354            last = stmt;
355            break;
356        } else if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
357            ++advances;
358        }
359        ++statements;
360    }
361
362    flat_map<const PabloAST *, unsigned> M;
363    M.reserve(statements);
364    matrix<bool> G(statements, advances, false);
365    for (unsigned i = 0; i != advances; ++i) {
366        G(i, i) = true;
367    }
368
369    mConstraintGraph = ConstraintGraph(advances);
370    unsigned n = advances;
371    unsigned k = 0;
372    for (Statement * stmt = initial; stmt != last; stmt = stmt->getNextNode()) {
373        assert (!isa<If>(stmt) && !isa<While>(stmt));
374        unsigned u = 0;
375        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
376            mConstraintGraph[k] = cast<Advance>(stmt);
377            u = k++;
378        } else {
379            u = n++;
380        }
381        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
382            const PabloAST * const op = stmt->getOperand(i);
383            if (LLVM_LIKELY(isa<Statement>(op))) {
384                auto f = M.find(op);
385                if (f != M.end()) {
386                    const unsigned v = std::get<1>(*f);
387                    for (unsigned w = 0; w != k; ++w) {
388                        G(u, w) |= G(v, w);
389                    }
390                }
391            }
392        }
393        M.emplace(stmt, u);
394    }
395
396    assert (k == advances);
397
398    // Initialize the base constraint graph by transposing G and removing reflective loops
399    for (unsigned i = 0; i != advances; ++i) {
400        for (unsigned j = 0; j < i; ++j) {
401            if (G(i, j)) {
402                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
403            }
404        }
405        for (unsigned j = i + 1; j < advances; ++j) {
406            if (G(i, j)) {
407                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
408            }
409        }
410    }
411
412    mSubsetGraph = SubsetGraph(advances);
413    mNegatedAdvance.reserve(advances);
414
415    return last;
416}
417
418/** ------------------------------------------------------------------------------------------------------------- *
419 * @brief generateCandidateSets
420 ** ------------------------------------------------------------------------------------------------------------- */
421bool MultiplexingPass::generateCandidateSets(Statement * const begin, Statement * const end) {
422
423    const auto n = mNegatedAdvance.size();
424    if (LLVM_UNLIKELY(n < 3)) {
425        return false;
426    }
427    assert (num_vertices(mConstraintGraph) == n);
428
429    mCandidateGraph = CandidateGraph(n);
430
431    Z3_config cfg = Z3_mk_config();
432    Z3_set_param_value(cfg, "MODEL", "true");
433    Z3_context ctx = Z3_mk_context(cfg);
434    Z3_del_config(cfg);
435    Z3_solver solver = Z3_mk_solver(ctx);
436    Z3_solver_inc_ref(ctx, solver);
437
438    std::vector<Z3_ast> V(n);
439    for (unsigned i = 0; i < n; ++i) {
440        V[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (V[i]);
441    }
442    std::vector<std::pair<unsigned, unsigned>> S;
443    S.reserve(n);
444
445    unsigned line = 0;
446    unsigned i = 0;
447    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) {
448        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
449            assert (S.empty() || line > std::get<0>(S.back()));
450            assert (cast<Advance>(stmt) == mConstraintGraph[i]);
451            if (S.size() > 0 && (line - std::get<0>(S.front())) > WindowSize) {
452                // try to compute a maximal set for this given set of Advances
453                if (S.size() > 2) {
454                    generateCandidateSets(ctx, solver, S, V);
455                }
456                // erase any that preceed our window
457                auto end = S.begin();
458                while (++end != S.end() && ((line - std::get<0>(*end)) > WindowSize));
459                S.erase(S.begin(), end);
460            }
461            for (unsigned j : make_iterator_range(adjacent_vertices(i, mConstraintGraph))) {
462                Z3_ast disj[2] = { Z3_mk_not(ctx, V[j]), Z3_mk_not(ctx, V[i]) };
463                Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, disj));
464            }
465            S.emplace_back(line, i++);
466        }
467        ++line;
468    }
469    if (S.size() > 2) {
470        generateCandidateSets(ctx, solver, S, V);
471    }
472
473    Z3_solver_dec_ref(ctx, solver);
474    Z3_del_context(ctx);
475
476    return num_vertices(mCandidateGraph) > n;
477}
478
479/** ------------------------------------------------------------------------------------------------------------- *
480 * @brief generateCandidateSets
481 ** ------------------------------------------------------------------------------------------------------------- */
482void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & V) {
483    assert (S.size() > 2);
484    assert (std::get<0>(S.front()) < std::get<0>(S.back()));
485    assert ((std::get<0>(S.back()) - std::get<0>(S.front())) <= WindowSize);
486
487    Z3_solver_push(ctx, solver);
488
489    const auto n = V.size();
490    std::vector<unsigned> M(S.size());
491    for (unsigned i = 0; i < S.size(); ++i) {
492        M[i] = std::get<1>(S[i]);
493    }
494
495    for (;;) {
496
497        std::vector<Z3_ast> assumptions(M.size());
498        unsigned j = 0;
499        for (unsigned i = 0; i < n; ++i) {
500            if (LLVM_UNLIKELY((j < M.size()) && (M[j] == i))) { // in our window range
501                assumptions[j++] = V[i]; assert (V[i]);
502            } else {
503                Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[i]));
504            }
505        }
506        assert (j == M.size());
507
508        if (maxsat(ctx, solver, assumptions) >= 0) {
509            Z3_model m = Z3_solver_get_model(ctx, solver);
510            Z3_model_inc_ref(ctx, m);
511            const auto k = add_vertex(mCandidateGraph); assert(k >= V.size());
512            Z3_ast TRUE = Z3_mk_true(ctx);
513            for (auto i = M.begin(); i != M.end(); ) {
514                Z3_ast value;
515                if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, V[*i], 1, &value) != Z3_TRUE)) {
516                    throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
517                }
518                if (value == TRUE) {
519                    add_edge(*i, k, mCandidateGraph);
520                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[*i]));
521                    i = M.erase(i);
522                } else {
523                    ++i;
524                }
525            }
526            Z3_model_dec_ref(ctx, m);
527            if (M.size() > 2) {
528                continue;
529            }
530        }
531        break;
532    }
533
534    Z3_solver_pop(ctx, solver, 1);
535}
536
537/** ------------------------------------------------------------------------------------------------------------- *
538 * @brief is_power_of_2
539 * @param n an integer
540 ** ------------------------------------------------------------------------------------------------------------- */
541static inline bool is_power_of_2(const size_t n) {
542    return ((n & (n - 1)) == 0);
543}
544
545/** ------------------------------------------------------------------------------------------------------------- *
546 * @brief log2_plus_one
547 ** ------------------------------------------------------------------------------------------------------------- */
548static inline size_t log2_plus_one(const size_t n) {
549    return std::log2<size_t>(n) + 1;
550}
551
552/** ------------------------------------------------------------------------------------------------------------- *
553 * @brief selectMultiplexSetsGreedy
554 *
555 * This algorithm is simply computes a greedy set cover. We want an exact max-weight set cover but can generate new
556 * sets by taking a subset of any existing set. With a few modifications, the greedy approach seems to work well
557 * enough but can be shown to produce a suboptimal solution if there are three candidate sets labelled A, B and C,
558 * in which A ∩ B = ∅, |A| ≀ |B| < |C|, and C ⊂ (A ∪ B).
559 ** ------------------------------------------------------------------------------------------------------------- */
560void MultiplexingPass::selectMultiplexSetsGreedy() {
561
562    using AdjIterator = graph_traits<CandidateGraph>::adjacency_iterator;
563    using degree_t = CandidateGraph::degree_size_type;
564    using vertex_t = CandidateGraph::vertex_descriptor;
565
566    const size_t m = num_vertices(mConstraintGraph);
567    const size_t n = num_vertices(mCandidateGraph) - m;
568
569    std::vector<bool> chosen(n);
570
571    for (;;) {
572
573        // Choose the set with the greatest number of vertices not already included in some other set.
574        vertex_t u = 0;
575        degree_t w = 0;
576        for (vertex_t i = 0; i != n; ++i) {
577            if (chosen[i]) continue;
578            const auto t = i + m;
579            degree_t r = degree(t, mCandidateGraph);
580            if (LLVM_LIKELY(r >= 3)) { // if this set has at least 3 elements.
581                if (w < r) {
582                    u = t;
583                    w = r;
584                }
585            } else if (r) {
586                clear_vertex(t, mCandidateGraph);
587            }
588        }
589
590        // Multiplexing requires 3 or more elements; if no set contains at least 3, abort.
591        if (LLVM_UNLIKELY(w == 0)) {
592            break;
593        }
594
595        chosen[u - m] = true;
596
597        // If this contains 2^n elements for any n, discard the member that is most likely to be added
598        // to some future set.
599        if (LLVM_UNLIKELY(is_power_of_2(degree(u, mCandidateGraph)))) {
600            vertex_t x = 0;
601            degree_t w = 0;
602            for (const auto v : make_iterator_range(adjacent_vertices(u, mCandidateGraph))) {
603                if (degree(v, mCandidateGraph) > w) {
604                    x = v;
605                    w = degree(v, mCandidateGraph);
606                }
607            }
608            remove_edge(u, x, mCandidateGraph);
609        }
610
611        AdjIterator begin, end;
612        std::tie(begin, end) = adjacent_vertices(u, mCandidateGraph);
613        for (auto vi = begin; vi != end; ) {
614            const auto v = *vi++;
615            clear_vertex(v, mCandidateGraph);
616            add_edge(v, u, mCandidateGraph);
617        }
618
619    }
620
621    #ifndef NDEBUG
622    for (unsigned i = 0; i != m; ++i) {
623        assert (degree(i, mCandidateGraph) <= 1);
624    }
625    for (unsigned i = m; i != (m + n); ++i) {
626        assert (degree(i, mCandidateGraph) == 0 || degree(i, mCandidateGraph) >= 3);
627    }
628    #endif
629}
630
631/** ------------------------------------------------------------------------------------------------------------- *
632 * @brief eliminateSubsetConstraints
633 ** ------------------------------------------------------------------------------------------------------------- */
634void MultiplexingPass::eliminateSubsetConstraints() {
635    using SubsetEdgeIterator = graph_traits<SubsetGraph>::edge_iterator;
636    // If Ai ⊂ Aj then the subset graph will contain the arc (i, j). Remove all arcs corresponding to vertices
637    // that are not elements of the same multiplexing set.
638    SubsetEdgeIterator ei, ei_end, ei_next;
639    std::tie(ei, ei_end) = edges(mSubsetGraph);
640    for (ei_next = ei; ei != ei_end; ei = ei_next) {
641        ++ei_next;
642        const auto u = source(*ei, mSubsetGraph);
643        const auto v = target(*ei, mSubsetGraph);
644        if (degree(u, mCandidateGraph) != 0 && degree(v, mCandidateGraph) != 0) {
645            assert (degree(u, mCandidateGraph) == 1);
646            assert (degree(v, mCandidateGraph) == 1);
647            const auto su = *(adjacent_vertices(u, mCandidateGraph).first);
648            const auto sv = *(adjacent_vertices(v, mCandidateGraph).first);
649            if (su == sv) {
650                continue;
651            }
652        }
653        remove_edge(*ei, mSubsetGraph);
654    }
655
656    if (num_edges(mSubsetGraph) != 0) {
657
658        // At least one subset constraint exists; perform a transitive reduction on the graph to ensure that
659        // we perform the minimum number of AST modifications for the selected multiplexing sets.
660
661        doTransitiveReductionOfSubsetGraph();
662
663        // Afterwards modify the AST to ensure that multiplexing algorithm can ignore any subset constraints
664        for (auto e : make_iterator_range(edges(mSubsetGraph))) {
665            Advance * const adv1 = mConstraintGraph[source(e, mSubsetGraph)];
666            Advance * const adv2 = mConstraintGraph[target(e, mSubsetGraph)];
667            assert (adv1->getParent() == adv2->getParent());
668            PabloBlock * const pb = adv1->getParent();
669            pb->setInsertPoint(adv2->getPrevNode());
670            adv2->setOperand(0, pb->createAnd(adv2->getOperand(0), pb->createNot(adv1->getOperand(0)), "subset"));
671            pb->setInsertPoint(adv2);
672            adv2->replaceAllUsesWith(pb->createOr(adv1, adv2, "merge"));
673        }
674
675    }
676}
677
678/** ------------------------------------------------------------------------------------------------------------- *
679 * @brief addWithHardConstraints
680 ** ------------------------------------------------------------------------------------------------------------- */
681Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, Statement * const stmt, flat_map<Statement *, Z3_ast> & M) {
682    assert (M.count(stmt) == 0 && stmt->getParent() == block);
683    // compute the hard dependency constraints
684    Z3_ast node = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_int_sort(ctx)); assert (node);
685    // we want all numbers to be positive so that the soft assertion that the first statement ought to stay at the first location
686    // whenever possible isn't satisfied by making preceeding numbers negative.
687    Z3_solver_assert(ctx, solver, Z3_mk_gt(ctx, node, Z3_mk_int(ctx, 0, Z3_mk_int_sort(ctx))));
688    for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
689        PabloAST * const op = stmt->getOperand(i);
690        if (isa<Statement>(op) && cast<Statement>(op)->getParent() == block) {
691            const auto f = M.find(cast<Statement>(op));
692            if (f != M.end()) {
693                Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, f->second, node));
694            }
695        }
696    }
697    M.emplace(stmt, node);
698    return node;
699}
700
701/** ------------------------------------------------------------------------------------------------------------- *
702 * @brief dominates
703 *
704 * does Statement a dominate Statement b?
705 ** ------------------------------------------------------------------------------------------------------------- */
706bool dominates(const Statement * const a, const Statement * const b) {
707    assert (a);
708    if (LLVM_UNLIKELY(b == nullptr)) {
709        return false;
710    }
711    assert (a->getParent() == b->getParent());
712    for (const Statement * t : *a->getParent()) {
713        if (t == a) {
714            return true;
715        } else if (t == b) {
716            return false;
717        }
718    }
719    llvm_unreachable("Neither a nor b are in their reported block!");
720    return false;
721}
722
723/** ------------------------------------------------------------------------------------------------------------- *
724 * @brief addWithHardConstraints
725 ** ------------------------------------------------------------------------------------------------------------- */
726Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, PabloAST * expr, flat_map<Statement *, Z3_ast> & M, Statement * const ip) {
727    if (isa<Statement>(expr)) {
728        Statement * const stmt = cast<Statement>(expr);
729        if (stmt->getParent() == block) {
730            const auto f = M.find(stmt);
731            if (LLVM_UNLIKELY(f != M.end())) {
732                return f->second;
733            } else if (!dominates(stmt, ip)) {
734                for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
735                    addWithHardConstraints(ctx, solver, block, stmt->getOperand(i), M, ip);
736                }
737                return addWithHardConstraints(ctx, solver, block, stmt, M);
738            }
739        }
740    }
741    return nullptr;
742}
743
744/** ------------------------------------------------------------------------------------------------------------- *
745 * @brief multiplexSelectedSets
746 ** ------------------------------------------------------------------------------------------------------------- */
747inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
748
749    assert ("begin cannot be null!" && begin);
750    assert (begin->getParent() == block);
751    assert (!end || end->getParent() == block);
752    assert (!end || isa<If>(end) || isa<While>(end));
753
754    // TODO: should we test whether sets overlap and merge the computations together?
755
756    Z3_config cfg = Z3_mk_config();
757    Z3_set_param_value(cfg, "MODEL", "true");
758    Z3_context ctx = Z3_mk_context(cfg);
759    Z3_del_config(cfg);
760    Z3_solver solver = Z3_mk_solver(ctx);
761    Z3_solver_inc_ref(ctx, solver);
762
763    const auto first_set = num_vertices(mConstraintGraph);
764    const auto last_set = num_vertices(mCandidateGraph);
765
766    for (auto idx = first_set; idx != last_set; ++idx) {
767        const size_t n = degree(idx, mCandidateGraph);
768        if (n) {
769            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
770            Advance * input[n];
771            PabloAST * muxed[m];
772            PabloAST * muxed_n[m];
773            PabloAST * demuxed[n];
774
775            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
776            unsigned i = 0;
777            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
778                input[i] = mConstraintGraph[u];
779                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
780                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
781                ++i;
782            }
783
784            // We can't trust the AST will be in the original order as we can multiplex a region of the program
785            // more than once.
786            Statement * initial = nullptr, * sentinal = nullptr;
787            for (Statement * stmt : *block) {
788                if (isa<Advance>(stmt)) {
789                    for (unsigned i = 0; i < n; ++i) {
790                        if (stmt == input[i]) {
791                            initial = initial ? initial : stmt;
792                            sentinal = stmt;
793                            break;
794                        }
795                    }
796                }
797            }
798            assert (initial);
799
800            Statement * const ip = initial->getPrevNode(); // save our insertion point prior to modifying the AST
801            sentinal = sentinal->getNextNode();
802
803            Z3_solver_push(ctx, solver);
804
805            // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
806            flat_map<Statement *, Z3_ast> M;
807
808            Z3_ast prior = nullptr;
809            Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
810            std::vector<Z3_ast> ordering;
811
812            for (Statement * stmt = initial; stmt != sentinal; stmt = stmt->getNextNode()) { assert (stmt != ip);
813                Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
814                // compute the soft ordering constraints
815                Z3_ast num = one;
816                if (prior) {
817                    Z3_ast prior_plus_one[2] = { prior, one };
818                    num = Z3_mk_add(ctx, 2, prior_plus_one);
819                }
820                ordering.push_back(Z3_mk_eq(ctx, node, num));
821                if (prior) {
822                    ordering.push_back(Z3_mk_lt(ctx, prior, node));
823                }
824                prior = node;
825            }
826
827            block->setInsertPoint(block->back()); // <- necessary for domination check!
828
829            circular_buffer<PabloAST *> Q(n);
830
831            /// Perform n-to-m Multiplexing
832            for (size_t j = 0; j != m; ++j) {
833                std::ostringstream prefix;
834                prefix << "mux" << n << "to" << m << '.' << (j);
835                assert (Q.empty());
836                for (size_t i = 0; i != n; ++i) {
837                    if (((i + 1) & (1UL << j)) != 0) {
838                        Q.push_back(input[i]->getOperand(0));
839                    }
840                }
841                while (Q.size() > 1) {
842                    PabloAST * a = Q.front(); Q.pop_front();
843                    PabloAST * b = Q.front(); Q.pop_front();
844                    PabloAST * expr = block->createOr(a, b);
845                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
846                    Q.push_back(expr);
847                }
848                PabloAST * const muxing = Q.front(); Q.clear();
849                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
850                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
851                muxed_n[j] = block->createNot(muxed[j]);
852                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
853            }
854
855            /// Perform m-to-n Demultiplexing
856            for (size_t i = 0; i != n; ++i) {
857                // Construct the demuxed values and replaces all the users of the original advances with them.
858                assert (Q.empty());
859                for (size_t j = 0; j != m; ++j) {
860                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
861                }
862                Z3_ast replacement = nullptr;
863                while (Q.size() > 1) {
864                    PabloAST * const a = Q.front(); Q.pop_front();
865                    PabloAST * const b = Q.front(); Q.pop_front();
866                    PabloAST * expr = block->createAnd(a, b);
867                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
868                    Q.push_back(expr);
869                }
870                assert (replacement);
871                demuxed[i] = Q.front(); Q.clear();
872
873                const auto f = M.find(input[i]);
874                assert (f != M.end());
875                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
876                M.erase(f);
877            }
878
879            assert (M.count(ip) == 0);
880
881            const auto satisfied = maxsat(ctx, solver, ordering);
882
883            if (LLVM_UNLIKELY(satisfied >= 0)) {
884
885                Z3_model model = Z3_solver_get_model(ctx, solver);
886                Z3_model_inc_ref(ctx, model);
887
888                std::vector<std::pair<long long int, Statement *>> I;
889
890                for (const auto i : M) {
891                    Z3_ast value;
892                    if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
893                        throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
894                    }
895                    long long int line;
896                    if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
897                        throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
898                    }
899                    I.emplace_back(line, std::get<0>(i));
900                }
901
902                Z3_model_dec_ref(ctx, model);
903
904                std::sort(I.begin(), I.end());
905
906                block->setInsertPoint(ip);
907                for (auto i : I) {
908                    block->insert(std::get<1>(i));
909                }
910
911                for (unsigned i = 0; i < n; ++i) {
912                    input[i]->replaceWith(demuxed[i], true, true);
913                    auto ref = mCharacterization.find(input[i]);
914                    assert (ref != mCharacterization.end());
915                    add(demuxed[i], std::get<0>(ref->second), -1);
916                }
917
918            } else { // fatal error; delete any statements we created.
919
920                for (unsigned i = 0; i < n; ++i) {
921                    if (LLVM_LIKELY(isa<Statement>(demuxed[i]))) {
922                        cast<Statement>(demuxed[i])->eraseFromParent(true);
923                    }
924                }
925
926            }
927
928            Z3_solver_pop(ctx, solver, 1);
929        }
930    }
931
932    Z3_solver_dec_ref(ctx, solver);
933    Z3_del_context(ctx);
934}
935
936/** ------------------------------------------------------------------------------------------------------------- *
937 * @brief doTransitiveReductionOfSubsetGraph
938 ** ------------------------------------------------------------------------------------------------------------- */
939void MultiplexingPass::doTransitiveReductionOfSubsetGraph() {
940    std::vector<SubsetGraph::vertex_descriptor> Q;
941    for (auto u : make_iterator_range(vertices(mSubsetGraph))) {
942        if (in_degree(u, mSubsetGraph) == 0 && out_degree(u, mSubsetGraph) != 0) {
943            Q.push_back(u);
944        }
945    }
946    flat_set<SubsetGraph::vertex_descriptor> targets;
947    flat_set<SubsetGraph::vertex_descriptor> visited;
948    do {
949        const auto u = Q.back(); Q.pop_back();
950        for (auto ei : make_iterator_range(out_edges(u, mSubsetGraph))) {
951            for (auto ej : make_iterator_range(out_edges(target(ei, mSubsetGraph), mSubsetGraph))) {
952                targets.insert(target(ej, mSubsetGraph));
953            }
954        }
955        for (auto v : targets) {
956            remove_edge(u, v, mSubsetGraph);
957        }
958        for (auto e : make_iterator_range(out_edges(u, mSubsetGraph))) {
959            const auto v = target(e, mSubsetGraph);
960            if (visited.insert(v).second) {
961                Q.push_back(v);
962            }
963        }
964    } while (!Q.empty());
965}
966
967/** ------------------------------------------------------------------------------------------------------------- *
968 * @brief get
969 ** ------------------------------------------------------------------------------------------------------------- */
970inline Z3_ast & MultiplexingPass::get(const PabloAST * const expr, const bool deref) {
971    assert (expr);
972    auto f = mCharacterization.find(expr);
973    if (LLVM_UNLIKELY(f == mCharacterization.end())) {
974        characterize(cast<Statement>(expr), false);
975        f = mCharacterization.find(expr);
976        assert (f != mCharacterization.end());
977    }
978    CharacterizationRef & ref = f->second;
979    if (deref) {
980        if (LLVM_LIKELY(std::get<1>(ref)) > 0) {
981            std::get<1>(ref) -= 1;
982        }
983    }
984    return std::get<0>(ref);
985}
986
987/** ------------------------------------------------------------------------------------------------------------- *
988 * @brief make
989 ** ------------------------------------------------------------------------------------------------------------- */
990inline Z3_ast MultiplexingPass::makeVar() {
991    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
992    Z3_inc_ref(mContext, node);
993    return node;
994}
995
996/** ------------------------------------------------------------------------------------------------------------- *
997 * @brief add
998 ** ------------------------------------------------------------------------------------------------------------- */
999inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node, const size_t refs) {
1000    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, refs)));
1001    return node;
1002}
1003
1004/** ------------------------------------------------------------------------------------------------------------- *
1005 * @brief constructor
1006 ** ------------------------------------------------------------------------------------------------------------- */
1007inline MultiplexingPass::MultiplexingPass(PabloFunction & f, Z3_context context, Z3_solver solver)
1008: mContext(context)
1009, mSolver(solver)
1010, mFunction(f)
1011, mConstraintGraph(0)
1012{
1013
1014}
1015
1016} // end of namespace pablo
Note: See TracBrowser for help on using the repository browser.