source: icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp @ 5350

Last change on this file since 5350 was 5267, checked in by nmedfort, 3 years ago

Code clean-up. Removed Pablo Call, SetIthBit? and Prototype.

File size: 41.5 KB
Line 
1#include "pablo_automultiplexing.hpp"
2#include <pablo/builder.hpp>
3#include <pablo/printer_pablos.h>
4#include <boost/container/flat_set.hpp>
5#include <boost/container/flat_map.hpp>
6#include <boost/numeric/ublas/matrix.hpp>
7#include <boost/circular_buffer.hpp>
8#include <boost/graph/topological_sort.hpp>
9#include <boost/range/iterator_range.hpp>
10#ifndef NDEBUG
11#include <pablo/analysis/pabloverifier.hpp>
12#endif
13#include <pablo/optimizers/pablo_simplifier.hpp>
14#include <pablo/builder.hpp>
15#include <stack>
16#include <queue>
17#include <unordered_set>
18#include <functional>
19#include <llvm/Support/CommandLine.h>
20#include "maxsat.hpp"
21
22using namespace llvm;
23using namespace boost;
24using namespace boost::container;
25using namespace boost::numeric::ublas;
26
27static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
28
29static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(100),
30                                        cl::desc("maximum sequence distance to consider for candidate set."),
31                                        cl::cat(MultiplexingOptions));
32
33
34namespace pablo {
35
36using TypeId = PabloAST::ClassTypeId;
37
38/** ------------------------------------------------------------------------------------------------------------- *
39 * @brief optimize
40 * @param function the function to optimize
41 ** ------------------------------------------------------------------------------------------------------------- */
42bool MultiplexingPass::optimize(PabloFunction & function) {
43
44
45    Z3_config cfg = Z3_mk_config();
46    Z3_context ctx = Z3_mk_context_rc(cfg);
47    Z3_del_config(cfg);
48    Z3_solver solver = Z3_mk_solver(ctx);
49    Z3_solver_inc_ref(ctx, solver);
50
51    MultiplexingPass mp(function, ctx, solver);
52
53    mp.optimize();
54
55    Z3_solver_dec_ref(ctx, solver);
56    Z3_del_context(ctx);
57
58    #ifndef NDEBUG
59    PabloVerifier::verify(function, "post-multiplexing");
60    #endif
61
62    Simplifier::optimize(function);
63
64    return true;
65}
66
67/** ------------------------------------------------------------------------------------------------------------- *
68 * @brief characterize
69 * @param function the function to optimize
70 ** ------------------------------------------------------------------------------------------------------------- */
71void MultiplexingPass::optimize() {
72    // Map the constants and input variables
73    PabloBlock * entryBlock = mFunction.getEntryBlock();
74    add(entryBlock->createZeroes(), Z3_mk_false(mContext), -1);
75    add(entryBlock->createOnes(), Z3_mk_true(mContext), -1);
76    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
77        add(mFunction.getParameter(i), makeVar(), -1);
78    }
79    optimize(mFunction.getEntryBlock());
80}
81
82/** ------------------------------------------------------------------------------------------------------------- *
83 * @brief characterize
84 ** ------------------------------------------------------------------------------------------------------------- */
85void MultiplexingPass::optimize(PabloBlock * const block) {
86    Statement * begin = block->front();
87    Statement * end = initialize(begin);
88    for (Statement * stmt : *block) {
89        if (LLVM_UNLIKELY(stmt == end)) {
90            Statement * const next = stmt->getNextNode();
91            multiplex(block, begin, stmt);
92            if (isa<If>(stmt)) {
93                optimize(cast<If>(stmt)->getBody());
94            } else if (isa<While>(stmt)) {
95                for (const Var * var : cast<While>(stmt)->getEscaped()) {
96                    Z3_inc_ref(mContext, get(var));
97                }
98                optimize(cast<While>(stmt)->getBody());
99                // since we cannot be certain that we'll always execute at least one iteration of a loop, we must
100                // assume that the variants could either be their initial or resulting value.
101                for (const Var * var : cast<While>(stmt)->getEscaped()) {
102                    Z3_ast v0 = get(var);
103                    Z3_ast & v1 = get(var);
104                    Z3_ast merge[2] = { v0, v1 };
105                    Z3_ast r = Z3_mk_or(mContext, 2, merge);
106                    Z3_inc_ref(mContext, r);
107                    Z3_dec_ref(mContext, v0);
108                    Z3_dec_ref(mContext, v1);
109                    v1 = r;
110                    assert (get(var) == r);
111                }
112            }
113            end = initialize(begin = next);
114        } else {
115            characterize(stmt);
116        }
117    }
118    multiplex(block, begin, nullptr);
119}
120
121/** ------------------------------------------------------------------------------------------------------------- *
122 * @brief multiplex
123 ** ------------------------------------------------------------------------------------------------------------- */
124void MultiplexingPass::multiplex(PabloBlock * const block, Statement * const begin, Statement * const end) {
125    if (generateCandidateSets(begin, end)) {
126        selectMultiplexSetsGreedy();
127        // eliminateSubsetConstraints();
128        multiplexSelectedSets(block, begin, end);
129    }
130}
131
132/** ------------------------------------------------------------------------------------------------------------- *
133 * @brief equals
134 ** ------------------------------------------------------------------------------------------------------------- */
135inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
136    Z3_solver_push(mContext, mSolver);
137    Z3_ast test = Z3_mk_eq(mContext, a, b);
138    Z3_inc_ref(mContext, test);
139    Z3_solver_assert(mContext, mSolver, test);
140    const auto r = Z3_solver_check(mContext, mSolver);
141    Z3_dec_ref(mContext, test);
142    Z3_solver_pop(mContext, mSolver, 1);
143    return (r == Z3_L_TRUE);
144}
145
146/** ------------------------------------------------------------------------------------------------------------- *
147 * @brief handle_unexpected_statement
148 ** ------------------------------------------------------------------------------------------------------------- */
149static void handle_unexpected_statement(const Statement * const stmt) {
150    std::string tmp;
151    raw_string_ostream err(tmp);
152    err << "Unexpected statement type: ";
153    PabloPrinter::print(stmt, err);
154    throw std::runtime_error(err.str());
155}
156
157/** ------------------------------------------------------------------------------------------------------------- *
158 * @brief characterize
159 ** ------------------------------------------------------------------------------------------------------------- */
160Z3_ast MultiplexingPass::characterize(const Statement * const stmt, const bool deref) {
161
162    const size_t n = stmt->getNumOperands(); assert (n > 0);
163    Z3_ast operands[n];
164    for (size_t i = 0; i < n; ++i) {
165        PabloAST * op = stmt->getOperand(i);
166        if (LLVM_UNLIKELY(isa<Integer>(op) || isa<String>(op))) {
167            continue;
168        }
169        operands[i] = get(op, deref);
170    }
171
172    Z3_ast node = operands[0];
173    switch (stmt->getClassTypeId()) {
174        case TypeId::Assign:
175        case TypeId::AtEOF:
176        case TypeId::InFile:
177            node = operands[0]; break;
178        case TypeId::And:
179            node = Z3_mk_and(mContext, n, operands); break;
180        case TypeId::Or:
181            node = Z3_mk_or(mContext, n, operands); break;
182        case TypeId::Xor:
183            node = Z3_mk_xor(mContext, operands[0], operands[1]);
184            Z3_inc_ref(mContext, node);
185            for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
186                Z3_ast temp = Z3_mk_xor(mContext, node, operands[i]);
187                Z3_inc_ref(mContext, temp);
188                Z3_dec_ref(mContext, node);
189                node = temp;
190            }
191            return add(stmt, node, stmt->getNumUses());
192        case TypeId::Not:
193            node = Z3_mk_not(mContext, node);
194            break;
195        case TypeId::Sel:
196            node = Z3_mk_ite(mContext, operands[0], operands[1], operands[2]);
197            break;
198        case TypeId::Advance:
199            return characterize(cast<Advance>(stmt), operands[0]);
200        case TypeId::ScanThru:
201            // ScanThru(c, m) := (c + m) ∧ ¬m. Thus we can conservatively represent this statement using the BDD
202            // for ¬m --- provided no derivative of this statement is negated in any fashion.
203        case TypeId::MatchStar:
204        case TypeId::Count:
205            node = makeVar();
206            break;
207        default:
208            handle_unexpected_statement(stmt);
209    }
210    Z3_inc_ref(mContext, node);
211    return add(stmt, node, stmt->getNumUses());
212}
213
214
215/** ------------------------------------------------------------------------------------------------------------- *
216 * @brief characterize
217 ** ------------------------------------------------------------------------------------------------------------- */
218inline Z3_ast MultiplexingPass::characterize(const Advance * const adv, Z3_ast Ik) {
219    const auto k = mNegatedAdvance.size();
220
221    assert (adv);
222    assert (mConstraintGraph[k] == adv);
223    std::vector<bool> unconstrained(k);
224
225    Z3_solver_push(mContext, mSolver);
226
227    for (size_t i = 0; i < k; ++i) {
228
229        // Have we already proven that they are unconstrained by their subset relationship?
230        if (unconstrained[i]) continue;
231
232        // If these Advances are mutually exclusive, in the same scope, transitively independent, and shift their
233        // values by the same amount, we can safely multiplex them. Otherwise mark the constraint in the graph.
234        const Advance * const ithAdv = mConstraintGraph[i];
235        if (ithAdv->getOperand(1) == adv->getOperand(1)) {
236
237            Z3_ast Ii = get(ithAdv->getOperand(0));
238
239            // Is there any satisfying truth assignment? If not, these streams are mutually exclusive.
240
241            Z3_solver_push(mContext, mSolver);
242            Z3_ast conj[2] = { Ii, Ik };
243            Z3_ast IiIk = Z3_mk_and(mContext, 2, conj);
244            Z3_inc_ref(mContext, IiIk);
245            Z3_solver_assert(mContext, mSolver, IiIk);
246            if (Z3_solver_check(mContext, mSolver) == Z3_L_FALSE) {
247                // If Ai ∩ Ak = ∅ and Aj ⊂ Ai, Aj ∩ Ak = ∅.
248//                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
249//                    unconstrained[source(e, mSubsetGraph)] = true;
250//                }
251                unconstrained[i] = true;
252            }/* else if (equals(Ii, IiIk)) {
253                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
254                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
255                // the same multiplexing set.
256                add_edge(i, k, mSubsetGraph);
257                // If Ai ⊂ Ak and Aj ⊂ Ai, Aj ⊂ Ak.
258                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
259                    const auto j = source(e, mSubsetGraph);
260                    add_edge(j, k, mSubsetGraph);
261                    unconstrained[j] = true;
262                }
263                unconstrained[i] = true;
264            } else if (equals(Ik, IiIk)) {
265                // If Ik = Ii ∩ Ik then Ik ⊆ Ii. Record this in the subset graph with the arc (k, i).
266                add_edge(k, i, mSubsetGraph);
267                // If Ak ⊂ Ai and Ai ⊂ Aj, Ak ⊂ Aj.
268                for (auto e : make_iterator_range(out_edges(i, mSubsetGraph))) {
269                    const auto j = target(e, mSubsetGraph);
270                    add_edge(k, j, mSubsetGraph);
271                    unconstrained[j] = true;
272                }
273                unconstrained[i] = true;
274            }*/
275            Z3_dec_ref(mContext, IiIk);
276            Z3_solver_pop(mContext, mSolver, 1);
277        }
278    }
279
280    Z3_solver_pop(mContext, mSolver, 1);
281
282    Z3_ast Ak0 = makeVar();
283    Z3_inc_ref(mContext, Ak0);
284    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
285    Z3_inc_ref(mContext, Nk);
286
287    Z3_ast vars[k + 1];
288    vars[0] = Ak0;
289
290    unsigned m = 1;
291    for (unsigned i = 0; i < k; ++i) {
292        if (unconstrained[i]) {
293            // This algorithm deems two streams mutually exclusive if and only if their conjuntion is a contradiction.
294            // To generate a contradiction when comparing Advances, the BDD of each Advance is represented by the conjunction of
295            // variables representing the k-th Advance and the negation of all variables for the Advances whose inputs are mutually
296            // exclusive with the k-th input.
297
298            // For example, if the input of the i-th Advance is mutually exclusive with the input of the j-th and k-th Advance, the
299            // BDD of the i-th Advance is Ai ∧ ¬Aj ∧ ¬Ak. Similarly, the j- and k-th Advance is Aj ∧ ¬Ai and Ak ∧ ¬Ai, respectively
300            // (assuming that the j-th and k-th Advance are not mutually exclusive.)
301
302            Z3_ast & Ai0 = get(mConstraintGraph[i]);
303            Z3_ast conj[2] = { Ai0, Nk };
304            Z3_ast Ai = Z3_mk_and(mContext, 2, conj);
305            Z3_inc_ref(mContext, Ai);
306            Z3_dec_ref(mContext, Ai0);
307            Ai0 = Ai;
308            assert (get(mConstraintGraph[i]) == Ai);
309
310            vars[m++] = mNegatedAdvance[i];
311
312            continue; // note: if these Advances are transitively dependent, an edge will still exist.
313        }
314        const auto ei = add_edge(i, k, mConstraintGraph);
315        // if this is not a new edge, it must have a dependency constraint.
316        if (ei.second) {
317            mConstraintGraph[ei.first] = ConstraintType::Inclusive;
318        }
319    }
320    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
321    mNegatedAdvance.emplace_back(Nk);
322    Z3_ast Ak = Z3_mk_and(mContext, m, vars);
323    if (LLVM_UNLIKELY(Ak != Ak0)) {
324        Z3_inc_ref(mContext, Ak);
325        Z3_dec_ref(mContext, Ak0);
326    }
327    return add(adv, Ak, -1);
328}
329
330/** ------------------------------------------------------------------------------------------------------------- *
331 * @brief initialize
332 ** ------------------------------------------------------------------------------------------------------------- */
333Statement * MultiplexingPass::initialize(Statement * const initial) {
334
335    // clean up any unneeded refs / characterizations.
336    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
337        const auto ref = i->second;
338        auto next = i; ++next;
339        if (LLVM_UNLIKELY(ref.second == 0)) {
340            assert (isa<Statement>(i->first));
341            Z3_dec_ref(mContext, ref.first);
342            mCharacterization.erase(i);
343        }
344        i = next;
345    }
346    for (Z3_ast var : mNegatedAdvance) {
347        Z3_dec_ref(mContext, var);
348    }
349    mNegatedAdvance.clear();
350
351    // Scan through and count all the advances and statements ...
352    unsigned statements = 0, advances = 0;
353    Statement * last = nullptr;
354    for (Statement * stmt = initial; stmt; stmt = stmt->getNextNode()) {
355        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
356            last = stmt;
357            break;
358        } else if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
359            ++advances;
360        }
361        ++statements;
362    }
363
364    flat_map<const PabloAST *, unsigned> M;
365    M.reserve(statements);
366    matrix<bool> G(statements, advances, false);
367    for (unsigned i = 0; i != advances; ++i) {
368        G(i, i) = true;
369    }
370
371    mConstraintGraph = ConstraintGraph(advances);
372    unsigned n = advances;
373    unsigned k = 0;
374    for (Statement * stmt = initial; stmt != last; stmt = stmt->getNextNode()) {
375        assert (!isa<If>(stmt) && !isa<While>(stmt));
376        unsigned u = 0;
377        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
378            mConstraintGraph[k] = cast<Advance>(stmt);
379            u = k++;
380        } else {
381            u = n++;
382        }
383        for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
384            const PabloAST * const op = stmt->getOperand(i);
385            if (LLVM_LIKELY(isa<Statement>(op))) {
386                auto f = M.find(op);
387                if (f != M.end()) {
388                    const unsigned v = std::get<1>(*f);
389                    for (unsigned w = 0; w != k; ++w) {
390                        G(u, w) |= G(v, w);
391                    }
392                }
393            }
394        }
395        M.emplace(stmt, u);
396    }
397
398    assert (k == advances);
399
400    // Initialize the base constraint graph by transposing G and removing reflective loops
401    for (unsigned i = 0; i != advances; ++i) {
402        for (unsigned j = 0; j < i; ++j) {
403            if (G(i, j)) {
404                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
405            }
406        }
407        for (unsigned j = i + 1; j < advances; ++j) {
408            if (G(i, j)) {
409                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
410            }
411        }
412    }
413
414    mSubsetGraph = SubsetGraph(advances);
415    mNegatedAdvance.reserve(advances);
416
417    return last;
418}
419
420/** ------------------------------------------------------------------------------------------------------------- *
421 * @brief generateCandidateSets
422 ** ------------------------------------------------------------------------------------------------------------- */
423bool MultiplexingPass::generateCandidateSets(Statement * const begin, Statement * const end) {
424
425    const auto n = mNegatedAdvance.size();
426    if (LLVM_UNLIKELY(n < 3)) {
427        return false;
428    }
429    assert (num_vertices(mConstraintGraph) == n);
430
431    mCandidateGraph = CandidateGraph(n);
432
433    Z3_config cfg = Z3_mk_config();
434    Z3_set_param_value(cfg, "MODEL", "true");
435    Z3_context ctx = Z3_mk_context(cfg);
436    Z3_del_config(cfg);
437    Z3_solver solver = Z3_mk_solver(ctx);
438    Z3_solver_inc_ref(ctx, solver);
439
440    std::vector<Z3_ast> V(n);
441    for (unsigned i = 0; i < n; ++i) {
442        V[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (V[i]);
443    }
444    std::vector<std::pair<unsigned, unsigned>> S;
445    S.reserve(n);
446
447    unsigned line = 0;
448    unsigned i = 0;
449    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) {
450        if (LLVM_UNLIKELY(isa<Advance>(stmt))) {
451            assert (S.empty() || line > std::get<0>(S.back()));
452            assert (cast<Advance>(stmt) == mConstraintGraph[i]);
453            if (S.size() > 0 && (line - std::get<0>(S.front())) > WindowSize) {
454                // try to compute a maximal set for this given set of Advances
455                if (S.size() > 2) {
456                    generateCandidateSets(ctx, solver, S, V);
457                }
458                // erase any that preceed our window
459                auto end = S.begin();
460                while (++end != S.end() && ((line - std::get<0>(*end)) > WindowSize));
461                S.erase(S.begin(), end);
462            }
463            for (unsigned j : make_iterator_range(adjacent_vertices(i, mConstraintGraph))) {
464                Z3_ast disj[2] = { Z3_mk_not(ctx, V[j]), Z3_mk_not(ctx, V[i]) };
465                Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, disj));
466            }
467            S.emplace_back(line, i++);
468        }
469        ++line;
470    }
471    if (S.size() > 2) {
472        generateCandidateSets(ctx, solver, S, V);
473    }
474
475    Z3_solver_dec_ref(ctx, solver);
476    Z3_del_context(ctx);
477
478    return num_vertices(mCandidateGraph) > n;
479}
480
481/** ------------------------------------------------------------------------------------------------------------- *
482 * @brief generateCandidateSets
483 ** ------------------------------------------------------------------------------------------------------------- */
484void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & V) {
485    assert (S.size() > 2);
486    assert (std::get<0>(S.front()) < std::get<0>(S.back()));
487    assert ((std::get<0>(S.back()) - std::get<0>(S.front())) <= WindowSize);
488
489    Z3_solver_push(ctx, solver);
490
491    const auto n = V.size();
492    std::vector<unsigned> M(S.size());
493    for (unsigned i = 0; i < S.size(); ++i) {
494        M[i] = std::get<1>(S[i]);
495    }
496
497    for (;;) {
498
499        std::vector<Z3_ast> assumptions(M.size());
500        unsigned j = 0;
501        for (unsigned i = 0; i < n; ++i) {
502            if (LLVM_UNLIKELY((j < M.size()) && (M[j] == i))) { // in our window range
503                assumptions[j++] = V[i]; assert (V[i]);
504            } else {
505                Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[i]));
506            }
507        }
508        assert (j == M.size());
509
510        if (maxsat(ctx, solver, assumptions) >= 0) {
511            Z3_model m = Z3_solver_get_model(ctx, solver);
512            Z3_model_inc_ref(ctx, m);
513            const auto k = add_vertex(mCandidateGraph); assert(k >= V.size());
514            Z3_ast TRUE = Z3_mk_true(ctx);
515            for (auto i = M.begin(); i != M.end(); ) {
516                Z3_ast value;
517                if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, V[*i], 1, &value) != Z3_TRUE)) {
518                    throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
519                }
520                if (value == TRUE) {
521                    add_edge(*i, k, mCandidateGraph);
522                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[*i]));
523                    i = M.erase(i);
524                } else {
525                    ++i;
526                }
527            }
528            Z3_model_dec_ref(ctx, m);
529            if (M.size() > 2) {
530                continue;
531            }
532        }
533        break;
534    }
535
536    Z3_solver_pop(ctx, solver, 1);
537}
538
539/** ------------------------------------------------------------------------------------------------------------- *
540 * @brief is_power_of_2
541 * @param n an integer
542 ** ------------------------------------------------------------------------------------------------------------- */
543static inline bool is_power_of_2(const size_t n) {
544    return ((n & (n - 1)) == 0);
545}
546
547/** ------------------------------------------------------------------------------------------------------------- *
548 * @brief log2_plus_one
549 ** ------------------------------------------------------------------------------------------------------------- */
550static inline size_t log2_plus_one(const size_t n) {
551    return std::log2<size_t>(n) + 1;
552}
553
554/** ------------------------------------------------------------------------------------------------------------- *
555 * @brief selectMultiplexSetsGreedy
556 *
557 * This algorithm is simply computes a greedy set cover. We want an exact max-weight set cover but can generate new
558 * sets by taking a subset of any existing set. With a few modifications, the greedy approach seems to work well
559 * enough but can be shown to produce a suboptimal solution if there are three candidate sets labelled A, B and C,
560 * in which A ∩ B = ∅, |A| ≀ |B| < |C|, and C ⊂ (A ∪ B).
561 ** ------------------------------------------------------------------------------------------------------------- */
562void MultiplexingPass::selectMultiplexSetsGreedy() {
563
564    using AdjIterator = graph_traits<CandidateGraph>::adjacency_iterator;
565    using degree_t = CandidateGraph::degree_size_type;
566    using vertex_t = CandidateGraph::vertex_descriptor;
567
568    const size_t m = num_vertices(mConstraintGraph);
569    const size_t n = num_vertices(mCandidateGraph) - m;
570
571    std::vector<bool> chosen(n);
572
573    for (;;) {
574
575        // Choose the set with the greatest number of vertices not already included in some other set.
576        vertex_t u = 0;
577        degree_t w = 0;
578        for (vertex_t i = 0; i != n; ++i) {
579            if (chosen[i]) continue;
580            const auto t = i + m;
581            degree_t r = degree(t, mCandidateGraph);
582            if (LLVM_LIKELY(r >= 3)) { // if this set has at least 3 elements.
583                if (w < r) {
584                    u = t;
585                    w = r;
586                }
587            } else if (r) {
588                clear_vertex(t, mCandidateGraph);
589            }
590        }
591
592        // Multiplexing requires 3 or more elements; if no set contains at least 3, abort.
593        if (LLVM_UNLIKELY(w == 0)) {
594            break;
595        }
596
597        chosen[u - m] = true;
598
599        // If this contains 2^n elements for any n, discard the member that is most likely to be added
600        // to some future set.
601        if (LLVM_UNLIKELY(is_power_of_2(degree(u, mCandidateGraph)))) {
602            vertex_t x = 0;
603            degree_t w = 0;
604            for (const auto v : make_iterator_range(adjacent_vertices(u, mCandidateGraph))) {
605                if (degree(v, mCandidateGraph) > w) {
606                    x = v;
607                    w = degree(v, mCandidateGraph);
608                }
609            }
610            remove_edge(u, x, mCandidateGraph);
611        }
612
613        AdjIterator begin, end;
614        std::tie(begin, end) = adjacent_vertices(u, mCandidateGraph);
615        for (auto vi = begin; vi != end; ) {
616            const auto v = *vi++;
617            clear_vertex(v, mCandidateGraph);
618            add_edge(v, u, mCandidateGraph);
619        }
620
621    }
622
623    #ifndef NDEBUG
624    for (unsigned i = 0; i != m; ++i) {
625        assert (degree(i, mCandidateGraph) <= 1);
626    }
627    for (unsigned i = m; i != (m + n); ++i) {
628        assert (degree(i, mCandidateGraph) == 0 || degree(i, mCandidateGraph) >= 3);
629    }
630    #endif
631}
632
633/** ------------------------------------------------------------------------------------------------------------- *
634 * @brief eliminateSubsetConstraints
635 ** ------------------------------------------------------------------------------------------------------------- */
636void MultiplexingPass::eliminateSubsetConstraints() {
637    using SubsetEdgeIterator = graph_traits<SubsetGraph>::edge_iterator;
638    // If Ai ⊂ Aj then the subset graph will contain the arc (i, j). Remove all arcs corresponding to vertices
639    // that are not elements of the same multiplexing set.
640    SubsetEdgeIterator ei, ei_end, ei_next;
641    std::tie(ei, ei_end) = edges(mSubsetGraph);
642    for (ei_next = ei; ei != ei_end; ei = ei_next) {
643        ++ei_next;
644        const auto u = source(*ei, mSubsetGraph);
645        const auto v = target(*ei, mSubsetGraph);
646        if (degree(u, mCandidateGraph) != 0 && degree(v, mCandidateGraph) != 0) {
647            assert (degree(u, mCandidateGraph) == 1);
648            assert (degree(v, mCandidateGraph) == 1);
649            const auto su = *(adjacent_vertices(u, mCandidateGraph).first);
650            const auto sv = *(adjacent_vertices(v, mCandidateGraph).first);
651            if (su == sv) {
652                continue;
653            }
654        }
655        remove_edge(*ei, mSubsetGraph);
656    }
657
658    if (num_edges(mSubsetGraph) != 0) {
659
660        // At least one subset constraint exists; perform a transitive reduction on the graph to ensure that
661        // we perform the minimum number of AST modifications for the selected multiplexing sets.
662
663        doTransitiveReductionOfSubsetGraph();
664
665        // Afterwards modify the AST to ensure that multiplexing algorithm can ignore any subset constraints
666        for (auto e : make_iterator_range(edges(mSubsetGraph))) {
667            Advance * const adv1 = mConstraintGraph[source(e, mSubsetGraph)];
668            Advance * const adv2 = mConstraintGraph[target(e, mSubsetGraph)];
669            assert (adv1->getParent() == adv2->getParent());
670            PabloBlock * const pb = adv1->getParent();
671            pb->setInsertPoint(adv2->getPrevNode());
672            adv2->setOperand(0, pb->createAnd(adv2->getOperand(0), pb->createNot(adv1->getOperand(0)), "subset"));
673            pb->setInsertPoint(adv2);
674            adv2->replaceAllUsesWith(pb->createOr(adv1, adv2, "merge"));
675        }
676
677    }
678}
679
680/** ------------------------------------------------------------------------------------------------------------- *
681 * @brief addWithHardConstraints
682 ** ------------------------------------------------------------------------------------------------------------- */
683Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, Statement * const stmt, flat_map<Statement *, Z3_ast> & M) {
684    assert (M.count(stmt) == 0 && stmt->getParent() == block);
685    // compute the hard dependency constraints
686    Z3_ast node = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_int_sort(ctx)); assert (node);
687    // we want all numbers to be positive so that the soft assertion that the first statement ought to stay at the first location
688    // whenever possible isn't satisfied by making preceeding numbers negative.
689    Z3_solver_assert(ctx, solver, Z3_mk_gt(ctx, node, Z3_mk_int(ctx, 0, Z3_mk_int_sort(ctx))));
690    for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
691        PabloAST * const op = stmt->getOperand(i);
692        if (isa<Statement>(op) && cast<Statement>(op)->getParent() == block) {
693            const auto f = M.find(cast<Statement>(op));
694            if (f != M.end()) {
695                Z3_solver_assert(ctx, solver, Z3_mk_lt(ctx, f->second, node));
696            }
697        }
698    }
699    M.emplace(stmt, node);
700    return node;
701}
702
703/** ------------------------------------------------------------------------------------------------------------- *
704 * @brief dominates
705 *
706 * does Statement a dominate Statement b?
707 ** ------------------------------------------------------------------------------------------------------------- */
708bool dominates(const Statement * const a, const Statement * const b) {
709    assert (a);
710    if (LLVM_UNLIKELY(b == nullptr)) {
711        return false;
712    }
713    assert (a->getParent() == b->getParent());
714    for (const Statement * t : *a->getParent()) {
715        if (t == a) {
716            return true;
717        } else if (t == b) {
718            return false;
719        }
720    }
721    llvm_unreachable("Neither a nor b are in their reported block!");
722    return false;
723}
724
725/** ------------------------------------------------------------------------------------------------------------- *
726 * @brief addWithHardConstraints
727 ** ------------------------------------------------------------------------------------------------------------- */
728Z3_ast addWithHardConstraints(Z3_context ctx, Z3_solver solver, PabloBlock * const block, PabloAST * expr, flat_map<Statement *, Z3_ast> & M, Statement * const ip) {
729    if (isa<Statement>(expr)) {
730        Statement * const stmt = cast<Statement>(expr);
731        if (stmt->getParent() == block) {
732            const auto f = M.find(stmt);
733            if (LLVM_UNLIKELY(f != M.end())) {
734                return f->second;
735            } else if (!dominates(stmt, ip)) {
736                for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
737                    addWithHardConstraints(ctx, solver, block, stmt->getOperand(i), M, ip);
738                }
739                return addWithHardConstraints(ctx, solver, block, stmt, M);
740            }
741        }
742    }
743    return nullptr;
744}
745
746/** ------------------------------------------------------------------------------------------------------------- *
747 * @brief multiplexSelectedSets
748 ** ------------------------------------------------------------------------------------------------------------- */
749inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
750
751    assert ("begin cannot be null!" && begin);
752    assert (begin->getParent() == block);
753    assert (!end || end->getParent() == block);
754    assert (!end || isa<If>(end) || isa<While>(end));
755
756    // TODO: should we test whether sets overlap and merge the computations together?
757
758    Z3_config cfg = Z3_mk_config();
759    Z3_set_param_value(cfg, "MODEL", "true");
760    Z3_context ctx = Z3_mk_context(cfg);
761    Z3_del_config(cfg);
762    Z3_solver solver = Z3_mk_solver(ctx);
763    Z3_solver_inc_ref(ctx, solver);
764
765    const auto first_set = num_vertices(mConstraintGraph);
766    const auto last_set = num_vertices(mCandidateGraph);
767
768    for (auto idx = first_set; idx != last_set; ++idx) {
769        const size_t n = degree(idx, mCandidateGraph);
770        if (n) {
771            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
772            Advance * input[n];
773            PabloAST * muxed[m];
774            PabloAST * muxed_n[m];
775            PabloAST * demuxed[n];
776
777            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
778            unsigned i = 0;
779            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
780                input[i] = mConstraintGraph[u];
781                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
782                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
783                ++i;
784            }
785
786            // We can't trust the AST will be in the original order as we can multiplex a region of the program
787            // more than once.
788            Statement * initial = nullptr, * sentinal = nullptr;
789            for (Statement * stmt : *block) {
790                if (isa<Advance>(stmt)) {
791                    for (unsigned i = 0; i < n; ++i) {
792                        if (stmt == input[i]) {
793                            initial = initial ? initial : stmt;
794                            sentinal = stmt;
795                            break;
796                        }
797                    }
798                }
799            }
800            assert (initial);
801
802            Statement * const ip = initial->getPrevNode(); // save our insertion point prior to modifying the AST
803            sentinal = sentinal->getNextNode();
804
805            Z3_solver_push(ctx, solver);
806
807            // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
808            flat_map<Statement *, Z3_ast> M;
809
810            Z3_ast prior = nullptr;
811            Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
812            std::vector<Z3_ast> ordering;
813
814            for (Statement * stmt = initial; stmt != sentinal; stmt = stmt->getNextNode()) { assert (stmt != ip);
815                Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
816                // compute the soft ordering constraints
817                Z3_ast num = one;
818                if (prior) {
819                    Z3_ast prior_plus_one[2] = { prior, one };
820                    num = Z3_mk_add(ctx, 2, prior_plus_one);
821                }
822                ordering.push_back(Z3_mk_eq(ctx, node, num));
823                if (prior) {
824                    ordering.push_back(Z3_mk_lt(ctx, prior, node));
825                }
826                prior = node;
827            }
828
829            block->setInsertPoint(block->back()); // <- necessary for domination check!
830
831            circular_buffer<PabloAST *> Q(n);
832
833            /// Perform n-to-m Multiplexing
834            for (size_t j = 0; j != m; ++j) {
835                std::ostringstream prefix;
836                prefix << "mux" << n << "to" << m << '.' << (j);
837                assert (Q.empty());
838                for (size_t i = 0; i != n; ++i) {
839                    if (((i + 1) & (1UL << j)) != 0) {
840                        Q.push_back(input[i]->getOperand(0));
841                    }
842                }
843                while (Q.size() > 1) {
844                    PabloAST * a = Q.front(); Q.pop_front();
845                    PabloAST * b = Q.front(); Q.pop_front();
846                    PabloAST * expr = block->createOr(a, b);
847                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
848                    Q.push_back(expr);
849                }
850                PabloAST * const muxing = Q.front(); Q.clear();
851                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
852                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
853                muxed_n[j] = block->createNot(muxed[j]);
854                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
855            }
856
857            /// Perform m-to-n Demultiplexing
858            for (size_t i = 0; i != n; ++i) {
859                // Construct the demuxed values and replaces all the users of the original advances with them.
860                assert (Q.empty());
861                for (size_t j = 0; j != m; ++j) {
862                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
863                }
864                Z3_ast replacement = nullptr;
865                while (Q.size() > 1) {
866                    PabloAST * const a = Q.front(); Q.pop_front();
867                    PabloAST * const b = Q.front(); Q.pop_front();
868                    PabloAST * expr = block->createAnd(a, b);
869                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
870                    Q.push_back(expr);
871                }
872                assert (replacement);
873                demuxed[i] = Q.front(); Q.clear();
874
875                const auto f = M.find(input[i]);
876                assert (f != M.end());
877                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
878                M.erase(f);
879            }
880
881            assert (M.count(ip) == 0);
882
883            const auto satisfied = maxsat(ctx, solver, ordering);
884
885            if (LLVM_UNLIKELY(satisfied >= 0)) {
886
887                Z3_model model = Z3_solver_get_model(ctx, solver);
888                Z3_model_inc_ref(ctx, model);
889
890                std::vector<std::pair<long long int, Statement *>> I;
891
892                for (const auto i : M) {
893                    Z3_ast value;
894                    if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
895                        throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
896                    }
897                    long long int line;
898                    if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
899                        throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
900                    }
901                    I.emplace_back(line, std::get<0>(i));
902                }
903
904                Z3_model_dec_ref(ctx, model);
905
906                std::sort(I.begin(), I.end());
907
908                block->setInsertPoint(ip);
909                for (auto i : I) {
910                    block->insert(std::get<1>(i));
911                }
912
913                for (unsigned i = 0; i < n; ++i) {
914                    input[i]->replaceWith(demuxed[i], true, true);
915                    auto ref = mCharacterization.find(input[i]);
916                    assert (ref != mCharacterization.end());
917                    add(demuxed[i], std::get<0>(ref->second), -1);
918                }
919
920            } else { // fatal error; delete any statements we created.
921
922                for (unsigned i = 0; i < n; ++i) {
923                    if (LLVM_LIKELY(isa<Statement>(demuxed[i]))) {
924                        cast<Statement>(demuxed[i])->eraseFromParent(true);
925                    }
926                }
927
928            }
929
930            Z3_solver_pop(ctx, solver, 1);
931        }
932    }
933
934    Z3_solver_dec_ref(ctx, solver);
935    Z3_del_context(ctx);
936}
937
938/** ------------------------------------------------------------------------------------------------------------- *
939 * @brief doTransitiveReductionOfSubsetGraph
940 ** ------------------------------------------------------------------------------------------------------------- */
941void MultiplexingPass::doTransitiveReductionOfSubsetGraph() {
942    std::vector<SubsetGraph::vertex_descriptor> Q;
943    for (auto u : make_iterator_range(vertices(mSubsetGraph))) {
944        if (in_degree(u, mSubsetGraph) == 0 && out_degree(u, mSubsetGraph) != 0) {
945            Q.push_back(u);
946        }
947    }
948    flat_set<SubsetGraph::vertex_descriptor> targets;
949    flat_set<SubsetGraph::vertex_descriptor> visited;
950    do {
951        const auto u = Q.back(); Q.pop_back();
952        for (auto ei : make_iterator_range(out_edges(u, mSubsetGraph))) {
953            for (auto ej : make_iterator_range(out_edges(target(ei, mSubsetGraph), mSubsetGraph))) {
954                targets.insert(target(ej, mSubsetGraph));
955            }
956        }
957        for (auto v : targets) {
958            remove_edge(u, v, mSubsetGraph);
959        }
960        for (auto e : make_iterator_range(out_edges(u, mSubsetGraph))) {
961            const auto v = target(e, mSubsetGraph);
962            if (visited.insert(v).second) {
963                Q.push_back(v);
964            }
965        }
966    } while (!Q.empty());
967}
968
969/** ------------------------------------------------------------------------------------------------------------- *
970 * @brief get
971 ** ------------------------------------------------------------------------------------------------------------- */
972inline Z3_ast & MultiplexingPass::get(const PabloAST * const expr, const bool deref) {
973    assert (expr);
974    auto f = mCharacterization.find(expr);
975    if (LLVM_UNLIKELY(f == mCharacterization.end())) {
976        characterize(cast<Statement>(expr), false);
977        f = mCharacterization.find(expr);
978        assert (f != mCharacterization.end());
979    }
980    CharacterizationRef & ref = f->second;
981    if (deref) {
982        if (LLVM_LIKELY(std::get<1>(ref)) > 0) {
983            std::get<1>(ref) -= 1;
984        }
985    }
986    return std::get<0>(ref);
987}
988
989/** ------------------------------------------------------------------------------------------------------------- *
990 * @brief make
991 ** ------------------------------------------------------------------------------------------------------------- */
992inline Z3_ast MultiplexingPass::makeVar() {
993    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
994    Z3_inc_ref(mContext, node);
995    return node;
996}
997
998/** ------------------------------------------------------------------------------------------------------------- *
999 * @brief add
1000 ** ------------------------------------------------------------------------------------------------------------- */
1001inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node, const size_t refs) {
1002    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, refs)));
1003    return node;
1004}
1005
1006/** ------------------------------------------------------------------------------------------------------------- *
1007 * @brief constructor
1008 ** ------------------------------------------------------------------------------------------------------------- */
1009inline MultiplexingPass::MultiplexingPass(PabloFunction & f, Z3_context context, Z3_solver solver)
1010: mContext(context)
1011, mSolver(solver)
1012, mFunction(f)
1013, mConstraintGraph(0)
1014{
1015
1016}
1017
1018} // end of namespace pablo
Note: See TracBrowser for help on using the repository browser.