source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp @ 5202

Last change on this file since 5202 was 5202, checked in by nmedfort, 3 years ago

Initial work on adding types to PabloAST and mutable Var objects.

File size: 41.5 KB
Line 
1#include "pablo_automultiplexing.hpp"
2#include <pablo/builder.hpp>
3#include <pablo/function.h>
4#include <pablo/printer_pablos.h>
5#include <boost/container/flat_set.hpp>
6#include <boost/container/flat_map.hpp>
7#include <boost/numeric/ublas/matrix.hpp>
8#include <boost/circular_buffer.hpp>
9#include <boost/graph/topological_sort.hpp>
10#include <boost/range/iterator_range.hpp>
11#include <pablo/analysis/pabloverifier.hpp>
12#include <pablo/optimizers/pablo_simplifier.hpp>
13#include <pablo/builder.hpp>
14#include <stack>
15#include <queue>
16#include <unordered_set>
17#include <functional>
18#include <llvm/Support/CommandLine.h>
19#include "maxsat.hpp"
20
21using namespace llvm;
22using namespace boost;
23using namespace boost::container;
24using namespace boost::numeric::ublas;
25
26static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
27
28static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(100),
29                                        cl::desc("maximum sequence distance to consider for candidate set."),
30                                        cl::cat(MultiplexingOptions));
31
32
33namespace pablo {
34
35using TypeId = PabloAST::ClassTypeId;
36
37/** ------------------------------------------------------------------------------------------------------------- *
38 * @brief optimize
39 * @param function the function to optimize
40 ** ------------------------------------------------------------------------------------------------------------- */
41bool MultiplexingPass::optimize(PabloFunction & function) {
42
43
44    Z3_config cfg = Z3_mk_config();
45    Z3_context ctx = Z3_mk_context_rc(cfg);
46    Z3_del_config(cfg);
47    Z3_solver solver = Z3_mk_solver(ctx);
48    Z3_solver_inc_ref(ctx, solver);
49
50    MultiplexingPass mp(function, ctx, solver);
51
52    mp.optimize();
53
54    Z3_solver_dec_ref(ctx, solver);
55    Z3_del_context(ctx);
56
57    #ifndef NDEBUG
58    PabloVerifier::verify(function, "post-multiplexing");
59    #endif
60
61    Simplifier::optimize(function);
62
63    return true;
64}
65
66/** ------------------------------------------------------------------------------------------------------------- *
67 * @brief characterize
68 * @param function the function to optimize
69 ** ------------------------------------------------------------------------------------------------------------- */
70void MultiplexingPass::optimize() {
71    // Map the constants and input variables
72    PabloBlock * entryBlock = mFunction.getEntryBlock();
73    add(entryBlock->createZeroes(), Z3_mk_false(mContext), -1);
74    add(entryBlock->createOnes(), Z3_mk_true(mContext), -1);
75    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
76        add(mFunction.getParameter(i), makeVar(), -1);
77    }
78    optimize(mFunction.getEntryBlock());
79}
80
81/** ------------------------------------------------------------------------------------------------------------- *
82 * @brief characterize
83 ** ------------------------------------------------------------------------------------------------------------- */
84void MultiplexingPass::optimize(PabloBlock * const block) {
85    Statement * begin = block->front();
86    Statement * end = initialize(begin);
87    for (Statement * stmt : *block) {
88        if (LLVM_UNLIKELY(stmt == end)) {
89            Statement * const next = stmt->getNextNode();
90            multiplex(block, begin, stmt);
91            if (isa<If>(stmt)) {
92                optimize(cast<If>(stmt)->getBody());
93            } else if (isa<While>(stmt)) {
94                for (const Var * var : cast<While>(stmt)->getEscaped()) {
95                    Z3_inc_ref(mContext, get(var));
96                }
97                optimize(cast<While>(stmt)->getBody());
98                // since we cannot be certain that we'll always execute at least one iteration of a loop, we must
99                // assume that the variants could either be their initial or resulting value.
100                for (const Var * var : cast<While>(stmt)->getEscaped()) {
101                    Z3_ast v0 = get(var);
102                    Z3_ast & v1 = get(var);
103                    Z3_ast merge[2] = { v0, v1 };
104                    Z3_ast r = Z3_mk_or(mContext, 2, merge);
105                    Z3_inc_ref(mContext, r);
106                    Z3_dec_ref(mContext, v0);
107                    Z3_dec_ref(mContext, v1);
108                    v1 = r;
109                    assert (get(var) == r);
110                }
111            }
112            end = initialize(begin = next);
113        } else {
114            characterize(stmt);
115        }
116    }
117    multiplex(block, begin, nullptr);
118}
119
120/** ------------------------------------------------------------------------------------------------------------- *
121 * @brief multiplex
122 ** ------------------------------------------------------------------------------------------------------------- */
123void MultiplexingPass::multiplex(PabloBlock * const block, Statement * const begin, Statement * const end) {
124    if (generateCandidateSets(begin, end)) {
125        selectMultiplexSetsGreedy();
126        // eliminateSubsetConstraints();
127        multiplexSelectedSets(block, begin, end);
128    }
129}
130
131/** ------------------------------------------------------------------------------------------------------------- *
132 * @brief equals
133 ** ------------------------------------------------------------------------------------------------------------- */
134inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
135    Z3_solver_push(mContext, mSolver);
136    Z3_ast test = Z3_mk_eq(mContext, a, b);
137    Z3_inc_ref(mContext, test);
138    Z3_solver_assert(mContext, mSolver, test);
139    const auto r = Z3_solver_check(mContext, mSolver);
140    Z3_dec_ref(mContext, test);
141    Z3_solver_pop(mContext, mSolver, 1);
142    return (r == Z3_L_TRUE);
143}
144
145/** ------------------------------------------------------------------------------------------------------------- *
146 * @brief handle_unexpected_statement
147 ** ------------------------------------------------------------------------------------------------------------- */
148static void handle_unexpected_statement(const Statement * const stmt) {
149    std::string tmp;
150    raw_string_ostream err(tmp);
151    err << "Unexpected statement type: ";
152    PabloPrinter::print(stmt, err);
153    throw std::runtime_error(err.str());
154}
155
156/** ------------------------------------------------------------------------------------------------------------- *
157 * @brief characterize
158 ** ------------------------------------------------------------------------------------------------------------- */
159Z3_ast MultiplexingPass::characterize(const Statement * const stmt, const bool deref) {
160
161    const size_t n = stmt->getNumOperands(); assert (n > 0);
162    Z3_ast operands[n];
163    for (size_t i = 0; i < n; ++i) {
164        PabloAST * op = stmt->getOperand(i);
165        if (LLVM_UNLIKELY(isa<Integer>(op) || isa<String>(op))) {
166            continue;
167        }
168        operands[i] = get(op, deref);
169    }
170
171    Z3_ast node = operands[0];
172    switch (stmt->getClassTypeId()) {
173        case TypeId::Assign:
174        case TypeId::AtEOF:
175        case TypeId::InFile:
176            node = operands[0]; break;
177        case TypeId::And:
178            node = Z3_mk_and(mContext, n, operands); break;
179        case TypeId::Or:
180            node = Z3_mk_or(mContext, n, operands); break;
181        case TypeId::Xor:
182            node = Z3_mk_xor(mContext, operands[0], operands[1]);
183            Z3_inc_ref(mContext, node);
184            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
185                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
186                Z3_inc_ref(mContext, temp);
187                Z3_dec_ref(mContext, node);
188                node = temp;
189            }
190            return add(stmt, node, stmt->getNumUses());
191        case TypeId::Not:
192            node = Z3_mk_not(mContext, node);
193            break;
194        case TypeId::Sel:
195            node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
196            break;
197        case TypeId::Advance:
198            return characterize(cast<Advance>(stmt), operands[0]);
199        case TypeId::ScanThru:
200            // ScanThru(c, m) := (c + m) ∧ ¬m. Thus we can conservatively represent this statement using the BDD
201            // for ¬m --- provided no derivative of this statement is negated in any fashion.
202        case TypeId::MatchStar:
203        case TypeId::Count:
204            node = makeVar();
205            break;
206        default:
207            handle_unexpected_statement(stmt);
208    }
209    Z3_inc_ref(mContext, node);
210    return add(stmt, node, stmt->getNumUses());
211}
212
213
214/** ------------------------------------------------------------------------------------------------------------- *
215 * @brief characterize
216 ** ------------------------------------------------------------------------------------------------------------- */
217inline Z3_ast MultiplexingPass::characterize(const Advance * const adv, Z3_ast Ik) {
218    const auto k = mNegatedAdvance.size();
219
220    assert (adv);
221    assert (mConstraintGraph[k] == adv);
222    std::vector<bool> unconstrained(k);
223
224    Z3_solver_push(mContext, mSolver);
225
226    for (size_t i = 0; i < k; ++i) {
227
228        // Have we already proven that they are unconstrained by their subset relationship?
229        if (unconstrained[i]) continue;
230
231        // If these Advances are mutually exclusive, in the same scope, transitively independent, and shift their
232        // values by the same amount, we can safely multiplex them. Otherwise mark the constraint in the graph.
233        const Advance * const ithAdv = mConstraintGraph[i];
234        if (ithAdv->getOperand(1) == adv->getOperand(1)) {
235
236            Z3_ast Ii = get(ithAdv->getOperand(0));
237
238            // Is there any satisfying truth assignment? If not, these streams are mutually exclusive.
239
240            Z3_solver_push(mContext, mSolver);
241            Z3_ast conj[2] = { Ii, Ik };
242            Z3_ast IiIk = Z3_mk_and(mContext, 2, conj);
243            Z3_inc_ref(mContext, IiIk);
244            Z3_solver_assert(mContext, mSolver, IiIk);
245            if (Z3_solver_check(mContext, mSolver) == Z3_L_FALSE) {
246                // If Ai ∩ Ak = ∅ and Aj ⊂ Ai, Aj ∩ Ak = ∅.
247//                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
248//                    unconstrained[source(e, mSubsetGraph)] = true;
249//                }
250                unconstrained[i] = true;
251            }/* else if (equals(Ii, IiIk)) {
252                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
253                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
254                // the same multiplexing set.
255                add_edge(i, k, mSubsetGraph);
256                // If Ai ⊂ Ak and Aj ⊂ Ai, Aj ⊂ Ak.
257                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
258                    const auto j = source(e, mSubsetGraph);
259                    add_edge(j, k, mSubsetGraph);
260                    unconstrained[j] = true;
261                }
262                unconstrained[i] = true;
263            } else if (equals(Ik, IiIk)) {
264                // If Ik = Ii ∩ Ik then Ik ⊆ Ii. Record this in the subset graph with the arc (k, i).
265                add_edge(k, i, mSubsetGraph);
266                // If Ak ⊂ Ai and Ai ⊂ Aj, Ak ⊂ Aj.
267                for (auto e : make_iterator_range(out_edges(i, mSubsetGraph))) {
268                    const auto j = target(e, mSubsetGraph);
269                    add_edge(k, j, mSubsetGraph);
270                    unconstrained[j] = true;
271                }
272                unconstrained[i] = true;
273            }*/
274            Z3_dec_ref(mContext, IiIk);
275            Z3_solver_pop(mContext, mSolver, 1);
276        }
277    }
278
279    Z3_solver_pop(mContext, mSolver, 1);
280
281    Z3_ast Ak0 = makeVar();
282    Z3_inc_ref(mContext, Ak0);
283    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
284    Z3_inc_ref(mContext, Nk);
285
286    Z3_ast vars[k + 1];
287    vars[0] = Ak0;
288
289    unsigned m = 1;
290    for (unsigned i = 0; i < k; ++i) {
291        if (unconstrained[i]) {
292            // This algorithm deems two streams mutually exclusive if and only if their conjuntion is a contradiction.
293            // To generate a contradiction when comparing Advances, the BDD of each Advance is represented by the conjunction of
294            // variables representing the k-th Advance and the negation of all variables for the Advances whose inputs are mutually
295            // exclusive with the k-th input.
296
297            // For example, if the input of the i-th Advance is mutually exclusive with the input of the j-th and k-th Advance, the
298            // BDD of the i-th Advance is Ai ∧ ¬Aj ∧ ¬Ak. Similarly, the j- and k-th Advance is Aj ∧ ¬Ai and Ak ∧ ¬Ai, respectively
299            // (assuming that the j-th and k-th Advance are not mutually exclusive.)
300
301            Z3_ast & Ai0 = get(mConstraintGraph[i]);
302            Z3_ast conj[2] = { Ai0, Nk };
303            Z3_ast Ai = Z3_mk_and(mContext, 2, conj);
304            Z3_inc_ref(mContext, Ai);
305            Z3_dec_ref(mContext, Ai0);
306            Ai0 = Ai;
307            assert (get(mConstraintGraph[i]) == Ai);
308
309            vars[m++] = mNegatedAdvance[i];
310
311            continue; // note: if these Advances are transitively dependent, an edge will still exist.
312        }
313        const auto ei = add_edge(i, k, mConstraintGraph);
314        // if this is not a new edge, it must have a dependency constraint.
315        if (ei.second) {
316            mConstraintGraph[ei.first] = ConstraintType::Inclusive;
317        }
318    }
319    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
320    mNegatedAdvance.emplace_back(Nk);
321    Z3_ast Ak = Z3_mk_and(mContext, m, vars);
322    if (LLVM_UNLIKELY(Ak != Ak0)) {
323        Z3_inc_ref(mContext, Ak);
324        Z3_dec_ref(mContext, Ak0);
325    }
326    return add(adv, Ak, -1);
327}
328
329/** ------------------------------------------------------------------------------------------------------------- *
330 * @brief initialize
331 ** ------------------------------------------------------------------------------------------------------------- */
332Statement * MultiplexingPass::initialize(Statement * const initial) {
333
334    // clean up any unneeded refs / characterizations.
335    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
336        const auto ref = i->second;
337        auto next = i; ++next;
338        if (LLVM_UNLIKELY(ref.second == 0)) {
339            assert (isa<Statement>(i->first));
340            Z3_dec_ref(mContext, ref.first);
341            mCharacterization.erase(i);
342        }
343        i = next;
344    }
345    for (Z3_ast var : mNegatedAdvance) {
346        Z3_dec_ref(mContext, var);
347    }
348    mNegatedAdvance.clear();
349
350    // Scan through and count all the advances and statements ...
351    unsigned statements = 0, advances = 0;
352    Statement * last = nullptr;
353    for (Statement * stmt = initial; stmt; stmt = stmt->getNextNode()) {
354        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
355            last = stmt;
356            break;
357        } else if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
358            ++advances;
359        }
360        ++statements;
361    }
362
363    flat_map<const PabloAST *, unsigned> M;
364    M.reserve(statements);
365    matrix<bool> G(statements, advances, false);
366    for (unsigned i = 0; i != advances; ++i) {
367        G(i, i) = true;
368    }
369
370    mConstraintGraph = ConstraintGraph(advances);
371    unsigned n = advances;
372    unsigned k = 0;
373    for (Statement * stmt = initial; stmt != last; stmt = stmt->getNextNode()) {
374        assert (!isa<If>(stmt) && !isa<While>(stmt));
375        unsigned u = 0;
376        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
377            mConstraintGraph[k] = cast<Advance>(stmt);
378            u = k++;
379        } else {
380            u = n++;
381        }
382        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
383            const PabloAST * const op = stmt->getOperand(i);
384            if (LLVM_LIKELY(isa<Statement>(op))) {
385                auto f = M.find(op);
386                if (f != M.end()) {
387                    const unsigned v = std::get<1>(*f);
388                    for (unsigned w = 0; w != k; ++w) {
389                        G(u, w) |= G(v, w);
390                    }
391                }
392            }
393        }
394        M.emplace(stmt, u);
395    }
396
397    assert (k == advances);
398
399    // Initialize the base constraint graph by transposing G and removing reflective loops
400    for (unsigned i = 0; i != advances; ++i) {
401        for (unsigned j = 0; j < i; ++j) {
402            if (G(i, j)) {
403                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
404            }
405        }
406        for (unsigned j = i + 1; j < advances; ++j) {
407            if (G(i, j)) {
408                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
409            }
410        }
411    }
412
413    mSubsetGraph = SubsetGraph(advances);
414    mNegatedAdvance.reserve(advances);
415
416    return last;
417}
418
419/** ------------------------------------------------------------------------------------------------------------- *
420 * @brief generateCandidateSets
421 ** ------------------------------------------------------------------------------------------------------------- */
422bool MultiplexingPass::generateCandidateSets(Statement * const begin, Statement * const end) {
423
424    const auto n = mNegatedAdvance.size();
425    if (LLVM_UNLIKELY(n < 3)) {
426        return false;
427    }
428    assert (num_vertices(mConstraintGraph) == n);
429
430    mCandidateGraph = CandidateGraph(n);
431
432    Z3_config cfg = Z3_mk_config();
433    Z3_set_param_value(cfg, "MODEL", "true");
434    Z3_context ctx = Z3_mk_context(cfg);
435    Z3_del_config(cfg);
436    Z3_solver solver = Z3_mk_solver(ctx);
437    Z3_solver_inc_ref(ctx, solver);
438
439    std::vector<Z3_ast> V(n);
440    for (unsigned i = 0; i < n; ++i) {
441        V[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (V[i]);
442    }
443    std::vector<std::pair<unsigned, unsigned>> S;
444    S.reserve(n);
445
446    unsigned line = 0;
447    unsigned i = 0;
448    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) {
449        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
450            assert (S.empty() || line > std::get<0>(S.back()));
451            assert (cast<Advance>(stmt) == mConstraintGraph[i]);
452            if (S.size() > 0 && (line - std::get<0>(S.front())) > WindowSize) {
453                // try to compute a maximal set for this given set of Advances
454                if (S.size() > 2) {
455                    generateCandidateSets(ctx, solver, S, V);
456                }
457                // erase any that preceed our window
458                auto end = S.begin();
459                while (++end != S.end() && ((line - std::get<0>(*end)) > WindowSize));
460                S.erase(S.begin(), end);
461            }
462            for (unsigned j : make_iterator_range(adjacent_vertices(i, mConstraintGraph))) {
463                Z3_ast disj[2] = { Z3_mk_not(ctx, V[j]), Z3_mk_not(ctx, V[i]) };
464                Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, disj));
465            }
466            S.emplace_back(line, i++);
467        }
468        ++line;
469    }
470    if (S.size() > 2) {
471        generateCandidateSets(ctx, solver, S, V);
472    }
473
474    Z3_solver_dec_ref(ctx, solver);
475    Z3_del_context(ctx);
476
477    return num_vertices(mCandidateGraph) > n;
478}
479
480/** ------------------------------------------------------------------------------------------------------------- *
481 * @brief generateCandidateSets
482 ** ------------------------------------------------------------------------------------------------------------- */
483void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & V) {
484    assert (S.size() > 2);
485    assert (std::get<0>(S.front()) < std::get<0>(S.back()));
486    assert ((std::get<0>(S.back()) - std::get<0>(S.front())) <= WindowSize);
487
488    Z3_solver_push(ctx, solver);
489
490    const auto n = V.size();
491    std::vector<unsigned> M(S.size());
492    for (unsigned i = 0; i < S.size(); ++i) {
493        M[i] = std::get<1>(S[i]);
494    }
495
496    for (;;) {
497
498        std::vector<Z3_ast> assumptions(M.size());
499        unsigned j = 0;
500        for (unsigned i = 0; i < n; ++i) {
501            if (LLVM_UNLIKELY((j < M.size()) && (M[j] == i))) { // in our window range
502                assumptions[j++] = V[i]; assert (V[i]);
503            } else {
504                Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[i]));
505            }
506        }
507        assert (j == M.size());
508
509        if (maxsat(ctx, solver, assumptions) >= 0) {
510            Z3_model m = Z3_solver_get_model(ctx, solver);
511            Z3_model_inc_ref(ctx, m);
512            const auto k = add_vertex(mCandidateGraph); assert(k >= V.size());
513            Z3_ast TRUE = Z3_mk_true(ctx);
514            for (auto i = M.begin(); i != M.end(); ) {
515                Z3_ast value;
516                if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, V[*i], 1, &value) != Z3_TRUE)) {
517                    throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
518                }
519                if (value == TRUE) {
520                    add_edge(*i, k, mCandidateGraph);
521                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[*i]));
522                    i = M.erase(i);
523                } else {
524                    ++i;
525                }
526            }
527            Z3_model_dec_ref(ctx, m);
528            if (M.size() > 2) {
529                continue;
530            }
531        }
532        break;
533    }
534
535    Z3_solver_pop(ctx, solver, 1);
536}
537
538/** ------------------------------------------------------------------------------------------------------------- *
539 * @brief is_power_of_2
540 * @param n an integer
541 ** ------------------------------------------------------------------------------------------------------------- */
542static inline bool is_power_of_2(const size_t n) {
543    return ((n & (n - 1)) == 0);
544}
545
546/** ------------------------------------------------------------------------------------------------------------- *
547 * @brief log2_plus_one
548 ** ------------------------------------------------------------------------------------------------------------- */
549static inline size_t log2_plus_one(const size_t n) {
550    return std::log2<size_t>(n) + 1;
551}
552
553/** ------------------------------------------------------------------------------------------------------------- *
554 * @brief selectMultiplexSetsGreedy
555 *
556 * This algorithm is simply computes a greedy set cover. We want an exact max-weight set cover but can generate new
557 * sets by taking a subset of any existing set. With a few modifications, the greedy approach seems to work well
558 * enough but can be shown to produce a suboptimal solution if there are three candidate sets labelled A, B and C,
559 * in which A ∩ B = ∅, |A| ≀ |B| < |C|, and C ⊂ (A ∪ B).
560 ** ------------------------------------------------------------------------------------------------------------- */
561void MultiplexingPass::selectMultiplexSetsGreedy() {
562
563    using AdjIterator = graph_traits<CandidateGraph>::adjacency_iterator;
564    using degree_t = CandidateGraph::degree_size_type;
565    using vertex_t = CandidateGraph::vertex_descriptor;
566
567    const size_t m = num_vertices(mConstraintGraph);
568    const size_t n = num_vertices(mCandidateGraph) - m;
569
570    std::vector<bool> chosen(n);
571
572    for (;;) {
573
574        // Choose the set with the greatest number of vertices not already included in some other set.
575        vertex_t u = 0;
576        degree_t w = 0;
577        for (vertex_t i = 0; i != n; ++i) {
578            if (chosen[i]) continue;
579            const auto t = i + m;
580            degree_t r = degree(t, mCandidateGraph);
581            if (LLVM_LIKELY(r >= 3)) { // if this set has at least 3 elements.
582                if (w < r) {
583                    u = t;
584                    w = r;
585                }
586            } else if (r) {
587                clear_vertex(t, mCandidateGraph);
588            }
589        }
590
591        // Multiplexing requires 3 or more elements; if no set contains at least 3, abort.
592        if (LLVM_UNLIKELY(w == 0)) {
593            break;
594        }
595
596        chosen[u - m] = true;
597
598        // If this contains 2^n elements for any n, discard the member that is most likely to be added
599        // to some future set.
600        if (LLVM_UNLIKELY(is_power_of_2(degree(u, mCandidateGraph)))) {
601            vertex_t x = 0;
602            degree_t w = 0;
603            for (const auto v : make_iterator_range(adjacent_vertices(u, mCandidateGraph))) {
604                if (degree(v, mCandidateGraph) > w) {
605                    x = v;
606                    w = degree(v, mCandidateGraph);
607                }
608            }
609            remove_edge(u, x, mCandidateGraph);
610        }
611
612        AdjIterator begin, end;
613        std::tie(begin, end) = adjacent_vertices(u, mCandidateGraph);
614        for (auto vi = begin; vi != end; ) {
615            const auto v = *vi++;
616            clear_vertex(v, mCandidateGraph);
617            add_edge(v, u, mCandidateGraph);
618        }
619
620    }
621
622    #ifndef NDEBUG
623    for (unsigned i = 0; i != m; ++i) {
624        assert (degree(i, mCandidateGraph) <= 1);
625    }
626    for (unsigned i = m; i != (m + n); ++i) {
627        assert (degree(i, mCandidateGraph) == 0 || degree(i, mCandidateGraph) >= 3);
628    }
629    #endif
630}
631
632/** ------------------------------------------------------------------------------------------------------------- *
633 * @brief eliminateSubsetConstraints
634 ** ------------------------------------------------------------------------------------------------------------- */
635void MultiplexingPass::eliminateSubsetConstraints() {
636    using SubsetEdgeIterator = graph_traits<SubsetGraph>::edge_iterator;
637    // If Ai ⊂ Aj then the subset graph will contain the arc (i, j). Remove all arcs corresponding to vertices
638    // that are not elements of the same multiplexing set.
639    SubsetEdgeIterator ei, ei_end, ei_next;
640    std::tie(ei, ei_end) = edges(mSubsetGraph);
641    for (ei_next = ei; ei != ei_end; ei = ei_next) {
642        ++ei_next;
643        const auto u = source(*ei, mSubsetGraph);
644        const auto v = target(*ei, mSubsetGraph);
645        if (degree(u, mCandidateGraph) != 0 && degree(v, mCandidateGraph) != 0) {
646            assert (degree(u, mCandidateGraph) == 1);
647            assert (degree(v, mCandidateGraph) == 1);
648            const auto su = *(adjacent_vertices(u, mCandidateGraph).first);
649            const auto sv = *(adjacent_vertices(v, mCandidateGraph).first);
650            if (su == sv) {
651                continue;
652            }
653        }
654        remove_edge(*ei, mSubsetGraph);
655    }
656
657    if (num_edges(mSubsetGraph) != 0) {
658
659        // At least one subset constraint exists; perform a transitive reduction on the graph to ensure that
660        // we perform the minimum number of AST modifications for the selected multiplexing sets.
661
662        doTransitiveReductionOfSubsetGraph();
663
664        // Afterwards modify the AST to ensure that multiplexing algorithm can ignore any subset constraints
665        for (auto e : make_iterator_range(edges(mSubsetGraph))) {
666            Advance * const adv1 = mConstraintGraph[source(e, mSubsetGraph)];
667            Advance * const adv2 = mConstraintGraph[target(e, mSubsetGraph)];
668            assert (adv1->getParent() == adv2->getParent());
669            PabloBlock * const pb = adv1->getParent();
670            pb->setInsertPoint(adv2->getPrevNode());
671            adv2->setOperand(0, pb->createAnd(adv2->getOperand(0), pb->createNot(adv1->getOperand(0)), "subset"));
672            pb->setInsertPoint(adv2);
673            adv2->replaceAllUsesWith(pb->createOr(adv1, adv2, "merge"));
674        }
675
676    }
677}
678
679/** ------------------------------------------------------------------------------------------------------------- *
680 * @brief addWithHardConstraints
681 ** ------------------------------------------------------------------------------------------------------------- */
682Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, Statement * const stmt, flat_map<Statement *, Z3_ast> & M) {
683    assert (M.count(stmt) == 0 && stmt->getParent() == block);
684    // compute the hard dependency constraints
685    Z3_ast node = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_int_sort(ctx)); assert (node);
686    // we want all numbers to be positive so that the soft assertion that the first statement ought to stay at the first location
687    // whenever possible isn't satisfied by making preceeding numbers negative.
688    Z3_solver_assert(ctx, solver, Z3_mk_gt(ctx, node, Z3_mk_int(ctx, 0, Z3_mk_int_sort(ctx))));
689    for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
690        PabloAST * const op = stmt->getOperand(i);
691        if (isa<Statement>(op) && cast<Statement>(op)->getParent() == block) {
692            const auto f = M.find(cast<Statement>(op));
693            if (f != M.end()) {
694                Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, f->second, node));
695            }
696        }
697    }
698    M.emplace(stmt, node);
699    return node;
700}
701
702/** ------------------------------------------------------------------------------------------------------------- *
703 * @brief dominates
704 *
705 * does Statement a dominate Statement b?
706 ** ------------------------------------------------------------------------------------------------------------- */
707bool dominates(const Statement * const a, const Statement * const b) {
708    assert (a);
709    if (LLVM_UNLIKELY(b == nullptr)) {
710        return false;
711    }
712    assert (a->getParent() == b->getParent());
713    for (const Statement * t : *a->getParent()) {
714        if (t == a) {
715            return true;
716        } else if (t == b) {
717            return false;
718        }
719    }
720    llvm_unreachable("Neither a nor b are in their reported block!");
721    return false;
722}
723
724/** ------------------------------------------------------------------------------------------------------------- *
725 * @brief addWithHardConstraints
726 ** ------------------------------------------------------------------------------------------------------------- */
727Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, PabloAST * expr, flat_map<Statement *, Z3_ast> & M, Statement * const ip) {
728    if (isa<Statement>(expr)) {
729        Statement * const stmt = cast<Statement>(expr);
730        if (stmt->getParent() == block) {
731            const auto f = M.find(stmt);
732            if (LLVM_UNLIKELY(f != M.end())) {
733                return f->second;
734            } else if (!dominates(stmt, ip)) {
735                for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
736                    addWithHardConstraints(ctx, solver, block, stmt->getOperand(i), M, ip);
737                }
738                return addWithHardConstraints(ctx, solver, block, stmt, M);
739            }
740        }
741    }
742    return nullptr;
743}
744
745/** ------------------------------------------------------------------------------------------------------------- *
746 * @brief multiplexSelectedSets
747 ** ------------------------------------------------------------------------------------------------------------- */
748inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
749
750    assert ("begin cannot be null!" && begin);
751    assert (begin->getParent() == block);
752    assert (!end || end->getParent() == block);
753    assert (!end || isa<If>(end) || isa<While>(end));
754
755    // TODO: should we test whether sets overlap and merge the computations together?
756
757    Z3_config cfg = Z3_mk_config();
758    Z3_set_param_value(cfg, "MODEL", "true");
759    Z3_context ctx = Z3_mk_context(cfg);
760    Z3_del_config(cfg);
761    Z3_solver solver = Z3_mk_solver(ctx);
762    Z3_solver_inc_ref(ctx, solver);
763
764    const auto first_set = num_vertices(mConstraintGraph);
765    const auto last_set = num_vertices(mCandidateGraph);
766
767    for (auto idx = first_set; idx != last_set; ++idx) {
768        const size_t n = degree(idx, mCandidateGraph);
769        if (n) {
770            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
771            Advance * input[n];
772            PabloAST * muxed[m];
773            PabloAST * muxed_n[m];
774            PabloAST * demuxed[n];
775
776            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
777            unsigned i = 0;
778            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
779                input[i] = mConstraintGraph[u];
780                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
781                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
782                ++i;
783            }
784
785            // We can't trust the AST will be in the original order as we can multiplex a region of the program
786            // more than once.
787            Statement * initial = nullptr, * sentinal = nullptr;
788            for (Statement * stmt : *block) {
789                if (isa<Advance>(stmt)) {
790                    for (unsigned i = 0; i < n; ++i) {
791                        if (stmt == input[i]) {
792                            initial = initial ? initial : stmt;
793                            sentinal = stmt;
794                            break;
795                        }
796                    }
797                }
798            }
799            assert (initial);
800
801            Statement * const ip = initial->getPrevNode(); // save our insertion point prior to modifying the AST
802            sentinal = sentinal->getNextNode();
803
804            Z3_solver_push(ctx, solver);
805
806            // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
807            flat_map<Statement *, Z3_ast> M;
808
809            Z3_ast prior = nullptr;
810            Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
811            std::vector<Z3_ast> ordering;
812
813            for (Statement * stmt = initial; stmt != sentinal; stmt = stmt->getNextNode()) { assert (stmt != ip);
814                Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
815                // compute the soft ordering constraints
816                Z3_ast num = one;
817                if (prior) {
818                    Z3_ast prior_plus_one[2] = { prior, one };
819                    num = Z3_mk_add(ctx, 2, prior_plus_one);
820                }
821                ordering.push_back(Z3_mk_eq(ctx, node, num));
822                if (prior) {
823                    ordering.push_back(Z3_mk_lt(ctx, prior, node));
824                }
825                prior = node;
826            }
827
828            block->setInsertPoint(block->back()); // <- necessary for domination check!
829
830            circular_buffer<PabloAST *> Q(n);
831
832            /// Perform n-to-m Multiplexing
833            for (size_t j = 0; j != m; ++j) {
834                std::ostringstream prefix;
835                prefix << "mux" << n << "to" << m << '.' << (j);
836                assert (Q.empty());
837                for (size_t i = 0; i != n; ++i) {
838                    if (((i + 1) & (1UL << j)) != 0) {
839                        Q.push_back(input[i]->getOperand(0));
840                    }
841                }
842                while (Q.size() > 1) {
843                    PabloAST * a = Q.front(); Q.pop_front();
844                    PabloAST * b = Q.front(); Q.pop_front();
845                    PabloAST * expr = block->createOr(a, b);
846                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
847                    Q.push_back(expr);
848                }
849                PabloAST * const muxing = Q.front(); Q.clear();
850                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
851                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
852                muxed_n[j] = block->createNot(muxed[j]);
853                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
854            }
855
856            /// Perform m-to-n Demultiplexing
857            for (size_t i = 0; i != n; ++i) {
858                // Construct the demuxed values and replaces all the users of the original advances with them.
859                assert (Q.empty());
860                for (size_t j = 0; j != m; ++j) {
861                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
862                }
863                Z3_ast replacement = nullptr;
864                while (Q.size() > 1) {
865                    PabloAST * const a = Q.front(); Q.pop_front();
866                    PabloAST * const b = Q.front(); Q.pop_front();
867                    PabloAST * expr = block->createAnd(a, b);
868                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
869                    Q.push_back(expr);
870                }
871                assert (replacement);
872                demuxed[i] = Q.front(); Q.clear();
873
874                const auto f = M.find(input[i]);
875                assert (f != M.end());
876                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
877                M.erase(f);
878            }
879
880            assert (M.count(ip) == 0);
881
882            const auto satisfied = maxsat(ctx, solver, ordering);
883
884            if (LLVM_UNLIKELY(satisfied >= 0)) {
885
886                Z3_model model = Z3_solver_get_model(ctx, solver);
887                Z3_model_inc_ref(ctx, model);
888
889                std::vector<std::pair<long long int, Statement *>> I;
890
891                for (const auto i : M) {
892                    Z3_ast value;
893                    if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
894                        throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
895                    }
896                    long long int line;
897                    if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
898                        throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
899                    }
900                    I.emplace_back(line, std::get<0>(i));
901                }
902
903                Z3_model_dec_ref(ctx, model);
904
905                std::sort(I.begin(), I.end());
906
907                block->setInsertPoint(ip);
908                for (auto i : I) {
909                    block->insert(std::get<1>(i));
910                }
911
912                for (unsigned i = 0; i < n; ++i) {
913                    input[i]->replaceWith(demuxed[i], true, true);
914                    auto ref = mCharacterization.find(input[i]);
915                    assert (ref != mCharacterization.end());
916                    add(demuxed[i], std::get<0>(ref->second), -1);
917                }
918
919            } else { // fatal error; delete any statements we created.
920
921                for (unsigned i = 0; i < n; ++i) {
922                    if (LLVM_LIKELY(isa<Statement>(demuxed[i]))) {
923                        cast<Statement>(demuxed[i])->eraseFromParent(true);
924                    }
925                }
926
927            }
928
929            Z3_solver_pop(ctx, solver, 1);
930        }
931    }
932
933    Z3_solver_dec_ref(ctx, solver);
934    Z3_del_context(ctx);
935}
936
937/** ------------------------------------------------------------------------------------------------------------- *
938 * @brief doTransitiveReductionOfSubsetGraph
939 ** ------------------------------------------------------------------------------------------------------------- */
940void MultiplexingPass::doTransitiveReductionOfSubsetGraph() {
941    std::vector<SubsetGraph::vertex_descriptor> Q;
942    for (auto u : make_iterator_range(vertices(mSubsetGraph))) {
943        if (in_degree(u, mSubsetGraph) == 0 && out_degree(u, mSubsetGraph) != 0) {
944            Q.push_back(u);
945        }
946    }
947    flat_set<SubsetGraph::vertex_descriptor> targets;
948    flat_set<SubsetGraph::vertex_descriptor> visited;
949    do {
950        const auto u = Q.back(); Q.pop_back();
951        for (auto ei : make_iterator_range(out_edges(u, mSubsetGraph))) {
952            for (auto ej : make_iterator_range(out_edges(target(ei, mSubsetGraph), mSubsetGraph))) {
953                targets.insert(target(ej, mSubsetGraph));
954            }
955        }
956        for (auto v : targets) {
957            remove_edge(u, v, mSubsetGraph);
958        }
959        for (auto e : make_iterator_range(out_edges(u, mSubsetGraph))) {
960            const auto v = target(e, mSubsetGraph);
961            if (visited.insert(v).second) {
962                Q.push_back(v);
963            }
964        }
965    } while (!Q.empty());
966}
967
968/** ------------------------------------------------------------------------------------------------------------- *
969 * @brief get
970 ** ------------------------------------------------------------------------------------------------------------- */
971inline Z3_ast & MultiplexingPass::get(const PabloAST * const expr, const bool deref) {
972    assert (expr);
973    auto f = mCharacterization.find(expr);
974    if (LLVM_UNLIKELY(f == mCharacterization.end())) {
975        characterize(cast<Statement>(expr), false);
976        f = mCharacterization.find(expr);
977        assert (f != mCharacterization.end());
978    }
979    CharacterizationRef & ref = f->second;
980    if (deref) {
981        if (LLVM_LIKELY(std::get<1>(ref)) > 0) {
982            std::get<1>(ref) -= 1;
983        }
984    }
985    return std::get<0>(ref);
986}
987
988/** ------------------------------------------------------------------------------------------------------------- *
989 * @brief make
990 ** ------------------------------------------------------------------------------------------------------------- */
991inline Z3_ast MultiplexingPass::makeVar() {
992    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
993    Z3_inc_ref(mContext, node);
994    return node;
995}
996
997/** ------------------------------------------------------------------------------------------------------------- *
998 * @brief add
999 ** ------------------------------------------------------------------------------------------------------------- */
1000inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node, const size_t refs) {
1001    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, refs)));
1002    return node;
1003}
1004
1005/** ------------------------------------------------------------------------------------------------------------- *
1006 * @brief constructor
1007 ** ------------------------------------------------------------------------------------------------------------- */
1008inline MultiplexingPass::MultiplexingPass(PabloFunction & f, Z3_context context, Z3_solver solver)
1009: mContext(context)
1010, mSolver(solver)
1011, mFunction(f)
1012, mConstraintGraph(0)
1013{
1014
1015}
1016
1017} // end of namespace pablo
Note: See TracBrowser for help on using the repository browser.