Ignore:
Timestamp:
Sep 14, 2016, 2:56:54 PM (3 years ago)
Author:
nmedfort
Message:

Work on multiplexing and distribution passes + a few AST modification bug fixes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/pablo/optimizers/pablo_automultiplexing.cpp

    r5119 r5156  
    44#include <pablo/printer_pablos.h>
    55#include <boost/container/flat_set.hpp>
     6#include <boost/container/flat_map.hpp>
    67#include <boost/numeric/ublas/matrix.hpp>
    78#include <boost/circular_buffer.hpp>
     
    1617#include <functional>
    1718#include <llvm/Support/CommandLine.h>
     19#include "maxsat.hpp"
    1820
    1921using namespace llvm;
     
    2426static cl::OptionCategory MultiplexingOptions("Multiplexing Optimization Options", "These options control the Pablo Multiplexing optimization pass.");
    2527
    26 #ifdef NDEBUG
    27 #define INITIAL_SEED_VALUE (std::random_device()())
    28 #else
    29 #define INITIAL_SEED_VALUE (83234827342)
    30 #endif
    31 
    32 static cl::opt<std::mt19937::result_type> Seed("multiplexing-seed", cl::init(INITIAL_SEED_VALUE),
    33                                         cl::desc("randomization seed used when performing any non-deterministic operations."),
    34                                         cl::cat(MultiplexingOptions));
    35 
    36 #undef INITIAL_SEED_VALUE
    37 
    3828static cl::opt<unsigned> WindowSize("multiplexing-window-size", cl::init(100),
    3929                                        cl::desc("maximum sequence distance to consider for candidate set."),
     
    4333namespace pablo {
    4434
    45 Z3_bool maxsat(Z3_context ctx, Z3_solver solver, std::vector<Z3_ast> & soft);
    46 
    4735using TypeId = PabloAST::ClassTypeId;
    4836
     
    5341bool MultiplexingPass::optimize(PabloFunction & function) {
    5442
    55     PabloVerifier::verify(function, "pre-multiplexing");
    56 
    57     errs() << "PRE-MULTIPLEXING\n==============================================\n";
    58     PabloPrinter::print(function, errs());
    5943
    6044    Z3_config cfg = Z3_mk_config();
     
    6448    Z3_solver_inc_ref(ctx, solver);
    6549
    66     MultiplexingPass mp(function, Seed, ctx, solver);
     50    MultiplexingPass mp(function, ctx, solver);
    6751
    6852    mp.optimize();
     
    7155    Z3_del_context(ctx);
    7256
     57    #ifndef NDEBUG
    7358    PabloVerifier::verify(function, "post-multiplexing");
     59    #endif
    7460
    7561    Simplifier::optimize(function);
    76 
    77     errs() << "POST-MULTIPLEXING\n==============================================\n";
    78     PabloPrinter::print(function, errs());
    7962
    8063    return true;
     
    8770void MultiplexingPass::optimize() {
    8871    // Map the constants and input variables
    89 
    90     add(PabloBlock::createZeroes(), Z3_mk_false(mContext));
    91     add(PabloBlock::createOnes(), Z3_mk_true(mContext));
     72    add(PabloBlock::createZeroes(), Z3_mk_false(mContext), -1);
     73    add(PabloBlock::createOnes(), Z3_mk_true(mContext), -1);
    9274    for (unsigned i = 0; i < mFunction.getNumOfParameters(); ++i) {
    93         make(mFunction.getParameter(i));
    94     }
    95 
     75        add(mFunction.getParameter(i), makeVar(), -1);
     76    }
    9677    optimize(mFunction.getEntryBlock());
    9778}
     
    142123    if (generateCandidateSets(begin, end)) {
    143124        selectMultiplexSetsGreedy();
    144         eliminateSubsetConstraints();
     125        // eliminateSubsetConstraints();
    145126        multiplexSelectedSets(block, begin, end);
    146127    }
     
    152133inline bool MultiplexingPass::equals(Z3_ast a, Z3_ast b) {
    153134    Z3_solver_push(mContext, mSolver);
    154     Z3_ast test = Z3_mk_eq(mContext, a, b); // try using check assumption instead?
     135    Z3_ast test = Z3_mk_eq(mContext, a, b);
    155136    Z3_inc_ref(mContext, test);
    156137    Z3_solver_assert(mContext, mSolver, test);
     
    164145 * @brief handle_unexpected_statement
    165146 ** ------------------------------------------------------------------------------------------------------------- */
    166 static void handle_unexpected_statement(Statement * const stmt) {
     147static void handle_unexpected_statement(const Statement * const stmt) {
    167148    std::string tmp;
    168149    raw_string_ostream err(tmp);
     
    175156 * @brief characterize
    176157 ** ------------------------------------------------------------------------------------------------------------- */
    177 inline Z3_ast MultiplexingPass::characterize(Statement * const stmt) {
     158Z3_ast MultiplexingPass::characterize(const Statement * const stmt, const bool deref) {
    178159
    179160    const size_t n = stmt->getNumOperands(); assert (n > 0);
    180     Z3_ast operands[n] = {};
     161    Z3_ast operands[n];
    181162    for (size_t i = 0; i < n; ++i) {
    182163        PabloAST * op = stmt->getOperand(i);
    183         if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
    184             operands[i] = get(op, true);
    185         }
     164        if (LLVM_UNLIKELY(isa<Integer>(op) || isa<String>(op))) {
     165            continue;
     166        }
     167        operands[i] = get(op, deref);
    186168    }
    187169
     
    206188                node = temp;
    207189            }
    208             return add(stmt, node);
     190            return add(stmt, node, stmt->getNumUses());
    209191        case TypeId::Not:
    210192            node = Z3_mk_not(mContext, node);
     
    220202        case TypeId::MatchStar:
    221203        case TypeId::Count:
    222             return make(stmt);
     204            node = makeVar();
     205            break;
    223206        default:
    224207            handle_unexpected_statement(stmt);
    225208    }
    226209    Z3_inc_ref(mContext, node);
    227     return add(stmt, node);
     210    return add(stmt, node, stmt->getNumUses());
    228211}
    229212
     
    232215 * @brief characterize
    233216 ** ------------------------------------------------------------------------------------------------------------- */
    234 inline Z3_ast MultiplexingPass::characterize(Advance * const adv, Z3_ast Ik) {
     217inline Z3_ast MultiplexingPass::characterize(const Advance * const adv, Z3_ast Ik) {
    235218    const auto k = mNegatedAdvance.size();
    236219
    237220    assert (adv);
    238221    assert (mConstraintGraph[k] == adv);
    239 
    240     bool unconstrained[k] = {};
     222    std::vector<bool> unconstrained(k);
    241223
    242224    Z3_solver_push(mContext, mSolver);
     
    265247 and Aj ⊂ Ai, Aj ∩ Ak = âˆ
    266248.
    267                 for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
    268                     unconstrained[source(e, mSubsetGraph)] = true;
    269                 }
     249//                for (auto e : make_iterator_range(in_edges(i, mSubsetGraph))) {
     250//                    unconstrained[source(e, mSubsetGraph)] = true;
     251//                }
    270252                unconstrained[i] = true;
    271             } else if (equals(Ii, IiIk)) {
     253            }/* else if (equals(Ii, IiIk)) {
    272254                // If Ii = Ii ∩ Ik then Ii ⊆ Ik. Record this in the subset graph with the arc (i, k).
    273255                // Note: the AST will be modified to make these mutually exclusive if Ai and Ak end up in
     
    291273                }
    292274                unconstrained[i] = true;
    293             }
     275            }*/
    294276            Z3_dec_ref(mContext, IiIk);
    295277            Z3_solver_pop(mContext, mSolver, 1);
     
    299281    Z3_solver_pop(mContext, mSolver, 1);
    300282
    301     Z3_ast Ak0 = make(adv);
     283    Z3_ast Ak0 = makeVar();
    302284    Z3_inc_ref(mContext, Ak0);
    303285    Z3_ast Nk = Z3_mk_not(mContext, Ak0);
     
    331313            continue; // note: if these Advances are transitively dependent, an edge will still exist.
    332314        }
    333         add_edge(i, k, mConstraintGraph);
     315        const auto ei = add_edge(i, k, mConstraintGraph);
     316        // if this is not a new edge, it must have a dependency constraint.
     317        if (ei.second) {
     318            mConstraintGraph[ei.first] = ConstraintType::Inclusive;
     319        }
    334320    }
    335321    // To minimize the number of BDD computations, we store the negated variable instead of negating it each time.
     
    340326        Z3_dec_ref(mContext, Ak0);
    341327    }
    342     return add(adv, Ak);
     328    return add(adv, Ak, -1);
    343329}
    344330
     
    350336    // clean up any unneeded refs / characterizations.
    351337    for (auto i = mCharacterization.begin(); i != mCharacterization.end(); ) {
    352         const CharacterizationRef & r = std::get<1>(*i);
    353         const auto e = i++;
    354         if (LLVM_UNLIKELY(std::get<1>(r) == 0)) {
    355             Z3_dec_ref(mContext, std::get<0>(r));
    356             mCharacterization.erase(e);
    357         }
    358     }
    359 
     338        const auto ref = i->second;
     339        auto next = i; ++next;
     340        if (LLVM_UNLIKELY(ref.second == 0)) {
     341            assert (isa<Statement>(i->first));
     342            Z3_dec_ref(mContext, ref.first);
     343            mCharacterization.erase(i);
     344        }
     345        i = next;
     346    }
    360347    for (Z3_ast var : mNegatedAdvance) {
    361348        Z3_dec_ref(mContext, var);
     
    416403        for (unsigned j = 0; j < i; ++j) {
    417404            if (G(i, j)) {
    418                 add_edge(j, i, mConstraintGraph);
     405                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
    419406            }
    420407        }
    421408        for (unsigned j = i + 1; j < advances; ++j) {
    422409            if (G(i, j)) {
    423                 add_edge(j, i, mConstraintGraph);
     410                mConstraintGraph[add_edge(j, i, mConstraintGraph).first] = ConstraintType::Dependency;
    424411            }
    425412        }
     
    431418    return last;
    432419}
    433 
    434420
    435421/** ------------------------------------------------------------------------------------------------------------- *
     
    443429    }
    444430    assert (num_vertices(mConstraintGraph) == n);
    445 
    446     // The naive way to handle this would be to compute a DNF formula consisting of the
    447     // powerset of all independent (candidate) sets of G, assign a weight to each, and
    448     // try to maximally satisfy the clauses. However, this would be extremely costly to
    449     // compute let alone solve as we could easily generate O(2^100) clauses for a complex
    450     // problem. Further the vast majority of clauses would be false in the end.
    451 
    452     // Moreover, for every set that can Advance is contained in would need a unique
    453     // variable and selector. In other words:
    454 
    455     // Suppose Advance A has a selector variable I. If I is true, then A must be in ONE set.
    456     // Assume A could be in m sets. To enforce this, there are m(m - 1)/2 clauses:
    457 
    458     //   (¬A_1 √ ¬A_2 √ ¬I), (¬A_1 √ ¬A_3 √ ¬I), ..., (¬A_m-1 √ ¬A_m √ ¬I)
    459 
    460     // m here is be equivalent to number of independent sets in the constraint graph G
    461     // that contains A.
    462 
    463     // If two sets have a DEPENDENCY constraint between them, it will introduce a cyclic
    464     // relationship even if those sets are legal on their own. Thus we'd also need need
    465     // hard constraints between all constrained variables related to the pair of Advances.
    466 
    467     // Instead, we only try to solve for one set at a time. This eliminate the need for
    468     // the above constraints and computing m but this process to be closer to a simple
    469     // greedy search.
    470 
    471     // We do want to weight whether to include or exclude an item in a set but what should
    472     // this be? The weight must be related to the elements already in the set. If our goal
    473     // is to balance the perturbation of the AST with the reduction in # of Advances, the
    474     // cost of inclusion / exclusion could be proportional to the # of instructions that
    475     // it increases / decreases the span by --- but how many statements is an Advance worth?
    476 
    477     // What if instead we maintain a queue of advances and discard any that are outside of
    478     // the current window?
    479431
    480432    mCandidateGraph = CandidateGraph(n);
     
    486438    Z3_solver solver = Z3_mk_solver(ctx);
    487439    Z3_solver_inc_ref(ctx, solver);
    488     std::vector<Z3_ast> N(n);
     440
     441    std::vector<Z3_ast> V(n);
    489442    for (unsigned i = 0; i < n; ++i) {
    490         N[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (N[i]);
     443        V[i] = Z3_mk_fresh_const(ctx, nullptr, Z3_mk_bool_sort(ctx)); assert (V[i]);
    491444    }
    492445    std::vector<std::pair<unsigned, unsigned>> S;
     
    502455                // try to compute a maximal set for this given set of Advances
    503456                if (S.size() > 2) {
    504                     generateCandidateSets(ctx, solver, S, N);
     457                    generateCandidateSets(ctx, solver, S, V);
    505458                }
    506459                // erase any that preceed our window
    507                 for (auto i = S.begin();;) {
    508                     if (++i == S.end() || (line - std::get<0>(*i)) <= WindowSize) {
    509                         S.erase(S.begin(), i);
    510                         break;
    511                     }
    512                 }
     460                auto end = S.begin();
     461                while (++end != S.end() && ((line - std::get<0>(*end)) > WindowSize));
     462                S.erase(S.begin(), end);
    513463            }
    514464            for (unsigned j : make_iterator_range(adjacent_vertices(i, mConstraintGraph))) {
    515                 Z3_ast disj[2] = { Z3_mk_not(ctx, N[j]), Z3_mk_not(ctx, N[i]) };
     465                Z3_ast disj[2] = { Z3_mk_not(ctx, V[j]), Z3_mk_not(ctx, V[i]) };
    516466                Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, disj));
    517467            }
     
    521471    }
    522472    if (S.size() > 2) {
    523         generateCandidateSets(ctx, solver, S, N);
     473        generateCandidateSets(ctx, solver, S, V);
    524474    }
    525475
     
    533483 * @brief generateCandidateSets
    534484 ** ------------------------------------------------------------------------------------------------------------- */
    535 void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & N) {
     485void MultiplexingPass::generateCandidateSets(Z3_context ctx, Z3_solver solver, const std::vector<std::pair<unsigned, unsigned>> & S, const std::vector<Z3_ast> & V) {
    536486    assert (S.size() > 2);
    537487    assert (std::get<0>(S.front()) < std::get<0>(S.back()));
    538488    assert ((std::get<0>(S.back()) - std::get<0>(S.front())) <= WindowSize);
     489
    539490    Z3_solver_push(ctx, solver);
    540     const auto n = N.size();
    541     std::vector<Z3_ast> assumptions(S.size());
    542     for (unsigned i = 0, j = 0; i < n; ++i) {
    543         if (LLVM_UNLIKELY(j < S.size() && std::get<1>(S[j]) == i)) { // in our window range
    544             assumptions[j++] = N[i];
    545         } else {
    546             Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, N[i]));
    547         }
    548     }
    549     if (maxsat(ctx, solver, assumptions) != Z3_L_FALSE) {
    550         Z3_model m = Z3_solver_get_model(ctx, solver);
    551         Z3_model_inc_ref(ctx, m);
    552         const auto k = add_vertex(mCandidateGraph); assert(k >= N.size());
    553         Z3_ast TRUE = Z3_mk_true(ctx);
    554         Z3_ast FALSE = Z3_mk_false(ctx);
    555         for (const auto i : S) {
    556             Z3_ast value;
    557             if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, N[std::get<1>(i)], 1, &value) != Z3_TRUE)) {
    558                 throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
    559             }
    560             if (value == TRUE) {
    561                 add_edge(std::get<1>(i), k, mCandidateGraph);
    562             } else if (LLVM_UNLIKELY(value != FALSE)) {
    563                 throw std::runtime_error("Unexpected Z3 error constraint model value is a non-terminal!");
    564             }
    565         }
    566         Z3_model_dec_ref(ctx, m);
    567     }
     491
     492    const auto n = V.size();
     493    std::vector<unsigned> M(S.size());
     494    for (unsigned i = 0; i < S.size(); ++i) {
     495        M[i] = std::get<1>(S[i]);
     496    }
     497
     498    for (;;) {
     499
     500        std::vector<Z3_ast> assumptions(M.size());
     501        unsigned j = 0;
     502        for (unsigned i = 0; i < n; ++i) {
     503            if (LLVM_UNLIKELY((j < M.size()) && (M[j] == i))) { // in our window range
     504                assumptions[j++] = V[i]; assert (V[i]);
     505            } else {
     506                Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[i]));
     507            }
     508        }
     509        assert (j == M.size());
     510
     511        if (maxsat(ctx, solver, assumptions) >= 0) {
     512            Z3_model m = Z3_solver_get_model(ctx, solver);
     513            Z3_model_inc_ref(ctx, m);
     514            const auto k = add_vertex(mCandidateGraph); assert(k >= V.size());
     515            Z3_ast TRUE = Z3_mk_true(ctx);
     516            for (auto i = M.begin(); i != M.end(); ) {
     517                Z3_ast value;
     518                if (LLVM_UNLIKELY(Z3_model_eval(ctx, m, V[*i], 1, &value) != Z3_TRUE)) {
     519                    throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from constraint model!");
     520                }
     521                if (value == TRUE) {
     522                    add_edge(*i, k, mCandidateGraph);
     523                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, V[*i]));
     524                    i = M.erase(i);
     525                } else {
     526                    ++i;
     527                }
     528            }
     529            Z3_model_dec_ref(ctx, m);
     530            if (M.size() > 2) {
     531                continue;
     532            }
     533        }
     534        break;
     535    }
     536
    568537    Z3_solver_pop(ctx, solver, 1);
    569538}
     
    602571    const size_t n = num_vertices(mCandidateGraph) - m;
    603572
    604     bool chosen[n] = {};
     573    std::vector<bool> chosen(n);
    605574
    606575    for (;;) {
     
    711680}
    712681
    713 ///** ------------------------------------------------------------------------------------------------------------- *
    714 // * Topologically sort the sequence of instructions whilst trying to adhere as best as possible to the original
    715 // * program sequence.
    716 // ** ------------------------------------------------------------------------------------------------------------- */
    717 //inline bool topologicalSort(Z3_context ctx, Z3_solver solver, const std::vector<Z3_ast> & nodes, const int limit) {
    718 //    const auto n = nodes.size();
    719 //    if (LLVM_UNLIKELY(n == 0)) {
    720 //        return true;
    721 //    }
    722 //    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
    723 //        return false;
    724 //    }
    725 
    726 //    Z3_ast aux_vars[n];
    727 //    Z3_ast assumptions[n];
    728 //    Z3_ast ordering[n];
    729 //    int increments[n];
    730 
    731 //    Z3_sort boolTy = Z3_mk_bool_sort(ctx);
    732 //    Z3_sort intTy = Z3_mk_int_sort(ctx);
    733 //    Z3_ast one = Z3_mk_int(ctx, 1, intTy);
    734 
    735 //    for (unsigned i = 0; i < n; ++i) {
    736 //        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, boolTy);
    737 //        assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
    738 //        Z3_ast num = one;
    739 //        if (i > 0) {
    740 //            Z3_ast prior_plus_one[2] = { nodes[i - 1], one };
    741 //            num = Z3_mk_add(ctx, 2, prior_plus_one);
    742 //        }
    743 //        ordering[i] = Z3_mk_eq(ctx, nodes[i], num);
    744 //        increments[i] = 1;
    745 //    }
    746 
    747 //    unsigned unsat = 0;
    748 
    749 //    for (;;) {
    750 //        Z3_solver_push(ctx, solver);
    751 //        for (unsigned i = 0; i < n; ++i) {
    752 //            Z3_ast constraint[2] = {ordering[i], aux_vars[i]};
    753 //            Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, constraint));
    754 //        }
    755 //        if (LLVM_UNLIKELY(Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE)) {
    756 //            errs() << " SATISFIABLE!  (" << unsat << " of " << n << ")\n";
    757 //            return true; // done
    758 //        }
    759 //        Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver); assert (core);
    760 //        unsigned m = Z3_ast_vector_size(ctx, core); assert (m > 0);
    761 
    762 //        errs() << " UNSATISFIABLE " << m << "  (" << unsat << " of " << n <<")\n";
    763 
    764 //        for (unsigned j = 0; j < m; j++) {
    765 //            // check whether assumption[i] is in the core or not
    766 //            bool not_found = true;
    767 //            for (unsigned i = 0; i < n; i++) {
    768 //                if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
    769 
    770 //                    const auto k = increments[i];
    771 
    772 //                    errs() << " -- " << i << " @k=" << k << "\n";
    773 
    774 //                    if (k < limit) {
    775 //                        Z3_ast gap = Z3_mk_int(ctx, 1UL << k, intTy);
    776 //                        Z3_ast num = gap;
    777 //                        if (LLVM_LIKELY(i > 0)) {
    778 //                            Z3_ast prior_plus_gap[2] = { nodes[i - 1], gap };
    779 //                            num = Z3_mk_add(ctx, 2, prior_plus_gap);
    780 //                        }
    781 //                        Z3_dec_ref(ctx, ordering[i]);
    782 //                        ordering[i] = Z3_mk_le(ctx, num, nodes[i]);
    783 //                    } else if (k == limit && i > 0) {
    784 //                        ordering[i] = Z3_mk_le(ctx, nodes[i - 1], nodes[i]);
    785 //                    } else {
    786 //                        assumptions[i] = aux_vars[i]; // <- trivially satisfiable
    787 //                        ++unsat;
    788 //                    }
    789 //                    increments[i] = k + 1;
    790 //                    not_found = false;
    791 //                    break;
    792 //                }
    793 //            }
    794 //            if (LLVM_UNLIKELY(not_found)) {
    795 //                throw std::runtime_error("Unexpected Z3 failure when attempting to locate unsatisfiable ordering constraint!");
    796 //            }
    797 //        }
    798 //        Z3_solver_pop(ctx, solver, 1);
    799 //    }
    800 //    llvm_unreachable("maxsat wrongly reported this being unsatisfiable despite being able to satisfy the hard constraints!");
    801 //    return false;
    802 //}
    803 
    804 ///** ------------------------------------------------------------------------------------------------------------- *
    805 // * Topologically sort the sequence of instructions whilst trying to adhere as best as possible to the original
    806 // * program sequence.
    807 // ** ------------------------------------------------------------------------------------------------------------- */
    808 //inline bool topologicalSort(Z3_context ctx, Z3_solver solver, const std::vector<Z3_ast> & nodes, const int limit) {
    809 //    const auto n = nodes.size();
    810 //    if (LLVM_UNLIKELY(n == 0)) {
    811 //        return true;
    812 //    }
    813 //    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
    814 //        return false;
    815 //    }
    816 
    817 //    Z3_ast aux_vars[n];
    818 //    Z3_ast assumptions[n];
    819 
    820 //    Z3_sort boolTy = Z3_mk_bool_sort(ctx);
    821 //    Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
    822 
    823 //    for (unsigned i = 0; i < n; ++i) {
    824 //        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, boolTy);
    825 //        assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
    826 //        Z3_ast num = one;
    827 //        if (i > 0) {
    828 //            Z3_ast prior_plus_one[2] = { nodes[i - 1], one };
    829 //            num = Z3_mk_add(ctx, 2, prior_plus_one);
    830 //        }
    831 //        Z3_ast ordering = Z3_mk_eq(ctx, nodes[i], num);
    832 //        Z3_ast constraint[2] = {ordering, aux_vars[i]};
    833 //        Z3_solver_assert(ctx, solver, Z3_mk_or(ctx, 2, constraint));
    834 //    }
    835 
    836 //    for (unsigned k = 0; k < n; ) {
    837 //        if (LLVM_UNLIKELY(Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE)) {
    838 //            errs() << " SATISFIABLE!\n";
    839 //            return true; // done
    840 //        }
    841 //        Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver); assert (core);
    842 //        unsigned m = Z3_ast_vector_size(ctx, core); assert (m > 0);
    843 
    844 //        k += m;
    845 
    846 //        errs() << " UNSATISFIABLE " << m << " (" << k << ")\n";
    847 
    848 //        for (unsigned j = 0; j < m; j++) {
    849 //            // check whether assumption[i] is in the core or not
    850 //            bool not_found = true;
    851 //            for (unsigned i = 0; i < n; i++) {
    852 //                if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
    853 //                    assumptions[i] = aux_vars[i];
    854 //                    not_found = false;
    855 //                    break;
    856 //                }
    857 //            }
    858 //            if (LLVM_UNLIKELY(not_found)) {
    859 //                throw std::runtime_error("Unexpected Z3 failure when attempting to locate unsatisfiable ordering constraint!");
    860 //            }
    861 //        }
    862 //    }
    863 //    llvm_unreachable("maxsat wrongly reported this being unsatisfiable despite being able to satisfy the hard constraints!");
    864 //    return false;
    865 //}
    866 
    867 
    868682/** ------------------------------------------------------------------------------------------------------------- *
    869683 * @brief addWithHardConstraints
     
    942756    assert (!end || isa<If>(end) || isa<While>(end));
    943757
     758    // TODO: should we test whether sets overlap and merge the computations together?
     759
    944760    Z3_config cfg = Z3_mk_config();
    945761    Z3_set_param_value(cfg, "MODEL", "true");
     
    959775            PabloAST * muxed[m];
    960776            PabloAST * muxed_n[m];
     777            PabloAST * demuxed[n];
    961778
    962779            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
     
    966783                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
    967784                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
    968                 assert ("Inputs are not in sequential order!" && (i == 0 || (i > 0 && dominates(input[i - 1], input[i]))));
    969785                ++i;
    970786            }
    971787
    972             Statement * const A1 = input[0];
    973             Statement * const An = input[n - 1]->getNextNode();
    974 
    975             Statement * const ip = A1->getPrevNode(); // save our insertion point prior to modifying the AST
     788            // We can't trust the AST will be in the original order as we can multiplex a region of the program
     789            // more than once.
     790            Statement * initial = nullptr, * sentinal = nullptr;
     791            for (Statement * stmt : *block) {
     792                if (isa<Advance>(stmt)) {
     793                    for (unsigned i = 0; i < n; ++i) {
     794                        if (stmt == input[i]) {
     795                            initial = initial ? initial : stmt;
     796                            sentinal = stmt;
     797                            break;
     798                        }
     799                    }
     800                }
     801            }
     802            assert (initial);
     803
     804            Statement * const ip = initial->getPrevNode(); // save our insertion point prior to modifying the AST
     805            sentinal = sentinal->getNextNode();
    976806
    977807            Z3_solver_push(ctx, solver);
     
    983813            Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
    984814            std::vector<Z3_ast> ordering;
    985 //            std::vector<Z3_ast> nodes;
    986 
    987             for (Statement * stmt = A1; stmt != An; stmt = stmt->getNextNode()) { assert (stmt != ip);
     815
     816            for (Statement * stmt = initial; stmt != sentinal; stmt = stmt->getNextNode()) { assert (stmt != ip);
    988817                Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
    989818                // compute the soft ordering constraints
     
    997826                    ordering.push_back(Z3_mk_lt(ctx, prior, node));
    998827                }
    999 
    1000 
    1001 //                for (Z3_ast prior : nodes) {
    1002 //                    Z3_solver_assert(ctx, solver, Z3_mk_not(ctx, Z3_mk_eq(ctx, prior, node)));
    1003 //                }
    1004  //               nodes.push_back(node);
    1005 
    1006 
    1007828                prior = node;
    1008829            }
    1009 
    1010             // assert (nodes.size() <= WindowSize);
    1011830
    1012831            block->setInsertPoint(block->back()); // <- necessary for domination check!
     
    1054873                }
    1055874                assert (replacement);
    1056                 PabloAST * const demuxed = Q.front(); Q.clear();
     875                demuxed[i] = Q.front(); Q.clear();
    1057876
    1058877                const auto f = M.find(input[i]);
     
    1060879                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
    1061880                M.erase(f);
    1062 
    1063                 input[i]->replaceWith(demuxed);
    1064                 assert (M.count(input[i]) == 0);
    1065881            }
    1066882
    1067883            assert (M.count(ip) == 0);
    1068884
    1069             if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) != Z3_L_TRUE)) {
    1070                 throw std::runtime_error("Unexpected Z3 failure when attempting to topologically sort the AST!");
    1071             }
    1072 
    1073             Z3_model model = Z3_solver_get_model(ctx, solver);
    1074             Z3_model_inc_ref(ctx, model);
    1075 
    1076             std::vector<std::pair<long long int, Statement *>> I;
    1077 
    1078             for (const auto i : M) {
    1079                 Z3_ast value;
    1080                 if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
    1081                     throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
    1082                 }
    1083                 long long int line;
    1084                 if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
    1085                     throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
    1086                 }
    1087                 I.emplace_back(line, std::get<0>(i));
    1088             }
    1089 
    1090             Z3_model_dec_ref(ctx, model);
    1091 
    1092             std::sort(I.begin(), I.end());
    1093 
    1094             block->setInsertPoint(ip);
    1095             for (auto i : I) {
    1096                 block->insert(std::get<1>(i));
     885            const auto satisfied = maxsat(ctx, solver, ordering);
     886
     887            if (LLVM_UNLIKELY(satisfied >= 0)) {
     888
     889                Z3_model model = Z3_solver_get_model(ctx, solver);
     890                Z3_model_inc_ref(ctx, model);
     891
     892                std::vector<std::pair<long long int, Statement *>> I;
     893
     894                for (const auto i : M) {
     895                    Z3_ast value;
     896                    if (LLVM_UNLIKELY(Z3_model_eval(ctx, model, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE)) {
     897                        throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
     898                    }
     899                    long long int line;
     900                    if (LLVM_UNLIKELY(Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE)) {
     901                        throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
     902                    }
     903                    I.emplace_back(line, std::get<0>(i));
     904                }
     905
     906                Z3_model_dec_ref(ctx, model);
     907
     908                std::sort(I.begin(), I.end());
     909
     910                block->setInsertPoint(ip);
     911                for (auto i : I) {
     912                    block->insert(std::get<1>(i));
     913                }
     914
     915                for (unsigned i = 0; i < n; ++i) {
     916                    input[i]->replaceWith(demuxed[i], true, true);
     917                    auto ref = mCharacterization.find(input[i]);
     918                    assert (ref != mCharacterization.end());
     919                    add(demuxed[i], std::get<0>(ref->second), -1);
     920                }
     921
     922            } else { // fatal error; delete any statements we created.
     923
     924                for (unsigned i = 0; i < n; ++i) {
     925                    if (LLVM_LIKELY(isa<Statement>(demuxed[i]))) {
     926                        cast<Statement>(demuxed[i])->eraseFromParent(true);
     927                    }
     928                }
     929
    1097930            }
    1098931
     
    1103936    Z3_solver_dec_ref(ctx, solver);
    1104937    Z3_del_context(ctx);
    1105 
    1106 }
    1107 
    1108 ///** ------------------------------------------------------------------------------------------------------------- *
    1109 // * @brief multiplexSelectedSets
    1110 // ** ------------------------------------------------------------------------------------------------------------- */
    1111 //inline void MultiplexingPass::multiplexSelectedSets(PabloBlock * const block, Statement * const begin, Statement * const end) {
    1112 
    1113 //    assert ("begin cannot be null!" && begin);
    1114 //    assert (begin->getParent() == block);
    1115 //    assert (!end || end->getParent() == block);
    1116 //    assert (!end || isa<If>(end) || isa<While>(end));
    1117 
    1118 //    Statement * const ip = begin->getPrevNode(); // save our insertion point prior to modifying the AST
    1119 
    1120 //    Z3_config cfg = Z3_mk_config();
    1121 //    Z3_set_param_value(cfg, "MODEL", "true");
    1122 //    Z3_context ctx = Z3_mk_context(cfg);
    1123 //    Z3_del_config(cfg);
    1124 //    Z3_solver solver = Z3_mk_solver(ctx);
    1125 //    Z3_solver_inc_ref(ctx, solver);
    1126 
    1127 //    const auto first_set = num_vertices(mConstraintGraph);
    1128 //    const auto last_set = num_vertices(mCandidateGraph);
    1129 
    1130 //    // Compute the hard and soft constraints for any part of the AST that we are not intending to modify.
    1131 //    flat_map<Statement *, Z3_ast> M;
    1132 
    1133 //    Z3_ast prior = nullptr;
    1134 //    Z3_ast one = Z3_mk_int(ctx, 1, Z3_mk_int_sort(ctx));
    1135 //    std::vector<Z3_ast> ordering;
    1136 
    1137 //    for (Statement * stmt = begin; stmt != end; stmt = stmt->getNextNode()) { assert (stmt != ip);
    1138 //        Z3_ast node = addWithHardConstraints(ctx, solver, block, stmt, M);
    1139 //        // compute the soft ordering constraints
    1140 //        Z3_ast num = one;
    1141 //        if (prior) {
    1142 //            Z3_ast prior_plus_one[2] = { prior, one };
    1143 //            num = Z3_mk_add(ctx, 2, prior_plus_one);
    1144 //        }
    1145 //        ordering.push_back(Z3_mk_eq(ctx, node, num));
    1146 //        prior = node;
    1147 //    }
    1148 
    1149 //    block->setInsertPoint(block->back()); // <- necessary for domination check!
    1150 
    1151 //    errs() << "---------------------------------------------\n";
    1152 
    1153 //    for (auto idx = first_set; idx != last_set; ++idx) {
    1154 //        const size_t n = degree(idx, mCandidateGraph);
    1155 //        if (n) {
    1156 //            const size_t m = log2_plus_one(n); assert (n > 2 && m < n);
    1157 //            Advance * input[n];
    1158 //            PabloAST * muxed[m];
    1159 //            PabloAST * muxed_n[m];
    1160 
    1161 //            errs() << n << " -> " << m << "\n";
    1162 
    1163 //            // The multiplex set graph is a DAG with edges denoting the set relationships of our independent sets.
    1164 //            unsigned i = 0;
    1165 //            for (const auto u : make_iterator_range(adjacent_vertices(idx, mCandidateGraph))) {
    1166 //                input[i] = mConstraintGraph[u];
    1167 //                assert ("Not all inputs are in the same block!" && (input[i]->getParent() == block));
    1168 //                assert ("Not all inputs advance by the same amount!" && (input[i]->getOperand(1) == input[0]->getOperand(1)));
    1169 //                ++i;
    1170 //            }
    1171 
    1172 //            circular_buffer<PabloAST *> Q(n);
    1173 
    1174 //            /// Perform n-to-m Multiplexing
    1175 //            for (size_t j = 0; j != m; ++j) {
    1176 //                std::ostringstream prefix;
    1177 //                prefix << "mux" << n << "to" << m << '.' << (j);
    1178 //                assert (Q.empty());
    1179 //                for (size_t i = 0; i != n; ++i) {
    1180 //                    if (((i + 1) & (1UL << j)) != 0) {
    1181 //                        Q.push_back(input[i]->getOperand(0));
    1182 //                    }
    1183 //                }
    1184 //                while (Q.size() > 1) {
    1185 //                    PabloAST * a = Q.front(); Q.pop_front();
    1186 //                    PabloAST * b = Q.front(); Q.pop_front();
    1187 //                    PabloAST * expr = block->createOr(a, b);
    1188 //                    addWithHardConstraints(ctx, solver, block, expr, M, ip);
    1189 //                    Q.push_back(expr);
    1190 //                }
    1191 //                PabloAST * const muxing = Q.front(); Q.clear();
    1192 //                muxed[j] = block->createAdvance(muxing, input[0]->getOperand(1), prefix.str());
    1193 //                addWithHardConstraints(ctx, solver, block, muxed[j], M, ip);
    1194 //                muxed_n[j] = block->createNot(muxed[j]);
    1195 //                addWithHardConstraints(ctx, solver, block, muxed_n[j], M, ip);
    1196 //            }
    1197 
    1198 //            /// Perform m-to-n Demultiplexing
    1199 //            for (size_t i = 0; i != n; ++i) {
    1200 //                // Construct the demuxed values and replaces all the users of the original advances with them.
    1201 //                assert (Q.empty());
    1202 //                for (size_t j = 0; j != m; ++j) {
    1203 //                    Q.push_back((((i + 1) & (1UL << j)) != 0) ? muxed[j] : muxed_n[j]);
    1204 //                }
    1205 //                Z3_ast replacement = nullptr;
    1206 //                while (Q.size() > 1) {
    1207 //                    PabloAST * const a = Q.front(); Q.pop_front();
    1208 //                    PabloAST * const b = Q.front(); Q.pop_front();
    1209 //                    PabloAST * expr = block->createAnd(a, b);
    1210 //                    replacement = addWithHardConstraints(ctx, solver, block, expr, M, ip);
    1211 //                    Q.push_back(expr);
    1212 //                }
    1213 //                assert (replacement);
    1214 //                PabloAST * const demuxed = Q.front(); Q.clear();
    1215 
    1216 //                const auto f = M.find(input[i]);
    1217 //                assert (f != M.end());
    1218 //                Z3_solver_assert(ctx, solver, Z3_mk_eq(ctx, f->second, replacement));
    1219 //                M.erase(f);
    1220 
    1221 //                input[i]->replaceWith(demuxed);
    1222 //                assert (M.count(input[i]) == 0);
    1223 //            }
    1224 //        }
    1225 //    }
    1226 
    1227 //    assert (M.count(ip) == 0);
    1228 
    1229 //    // if (LLVM_UNLIKELY(maxsat(ctx, solver, ordering) == Z3_L_FALSE)) {
    1230 //    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) != Z3_L_TRUE)) {
    1231 //        throw std::runtime_error("Unexpected Z3 failure when attempting to topologically sort the AST!");
    1232 //    }
    1233 
    1234 //    Z3_model m = Z3_solver_get_model(ctx, solver);
    1235 //    Z3_model_inc_ref(ctx, m);
    1236 
    1237 //    std::vector<std::pair<long long int, Statement *>> Q;
    1238 
    1239 //    errs() << "-----------------------------------------------------------\n";
    1240 
    1241 //    for (const auto i : M) {
    1242 //        Z3_ast value;
    1243 //        if (Z3_model_eval(ctx, m, std::get<1>(i), Z3_L_TRUE, &value) != Z3_L_TRUE) {
    1244 //            throw std::runtime_error("Unexpected Z3 error when attempting to obtain value from model!");
    1245 //        }
    1246 //        long long int line;
    1247 //        if (Z3_get_numeral_int64(ctx, value, &line) != Z3_L_TRUE) {
    1248 //            throw std::runtime_error("Unexpected Z3 error when attempting to convert model value to integer!");
    1249 //        }
    1250 //        Q.emplace_back(line, std::get<0>(i));
    1251 //    }
    1252 
    1253 //    Z3_model_dec_ref(ctx, m);
    1254 //    Z3_solver_dec_ref(ctx, solver);
    1255 //    Z3_del_context(ctx);
    1256 
    1257 //    std::sort(Q.begin(), Q.end());
    1258 
    1259 //    block->setInsertPoint(ip);
    1260 //    for (auto i : Q) {
    1261 //        block->insert(std::get<1>(i));
    1262 //    }
    1263 //}
     938}
    1264939
    1265940/** ------------------------------------------------------------------------------------------------------------- *
     
    1300975    assert (expr);
    1301976    auto f = mCharacterization.find(expr);
    1302     assert (f != mCharacterization.end());
    1303     auto & val = f->second;
     977    if (LLVM_UNLIKELY(f == mCharacterization.end())) {
     978        characterize(cast<Statement>(expr), false);
     979        f = mCharacterization.find(expr);
     980        assert (f != mCharacterization.end());
     981    }
     982    CharacterizationRef & ref = f->second;
    1304983    if (deref) {
    1305         unsigned & refs = std::get<1>(val);
    1306         assert (refs > 0);
    1307         --refs;
    1308     }
    1309     return std::get<0>(val);
     984        if (LLVM_LIKELY(std::get<1>(ref)) > 0) {
     985            std::get<1>(ref) -= 1;
     986        }
     987    }
     988    return std::get<0>(ref);
    1310989}
    1311990
     
    1313992 * @brief make
    1314993 ** ------------------------------------------------------------------------------------------------------------- */
    1315 inline Z3_ast MultiplexingPass::make(const PabloAST * const expr) {
    1316     assert (expr);
     994inline Z3_ast MultiplexingPass::makeVar() {
    1317995    Z3_ast node = Z3_mk_fresh_const(mContext, nullptr, Z3_mk_bool_sort(mContext));
    1318996    Z3_inc_ref(mContext, node);
    1319     return add(expr, node);
     997    return node;
    1320998}
    1321999
     
    13231001 * @brief add
    13241002 ** ------------------------------------------------------------------------------------------------------------- */
    1325 inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node) {   
    1326     mCharacterization.insert(std::make_pair(expr, std::make_pair(node, expr->getNumUses())));
     1003inline Z3_ast MultiplexingPass::add(const PabloAST * const expr, Z3_ast node, const size_t refs) {
     1004    mCharacterization.insert(std::make_pair(expr, std::make_pair(node, refs)));
    13271005    return node;
    13281006}
     
    13311009 * @brief constructor
    13321010 ** ------------------------------------------------------------------------------------------------------------- */
    1333 inline MultiplexingPass::MultiplexingPass(PabloFunction & f, const RNG::result_type seed, Z3_context context, Z3_solver solver)
     1011inline MultiplexingPass::MultiplexingPass(PabloFunction & f, Z3_context context, Z3_solver solver)
    13341012: mContext(context)
    13351013, mSolver(solver)
    13361014, mFunction(f)
    1337 , mRNG(seed)
    13381015, mConstraintGraph(0)
    13391016{
     
    13411018}
    13421019
    1343 
    1344 inline Z3_ast mk_binary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
    1345     Z3_ast args[2] = { in_1, in_2 };
    1346     return Z3_mk_or(ctx, 2, args);
    1347 }
    1348 
    1349 inline Z3_ast mk_ternary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2, Z3_ast in_3) {
    1350     Z3_ast args[3] = { in_1, in_2, in_3 };
    1351     return Z3_mk_or(ctx, 3, args);
    1352 }
    1353 
    1354 inline Z3_ast mk_binary_and(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
    1355     Z3_ast args[2] = { in_1, in_2 };
    1356     return Z3_mk_and(ctx, 2, args);
    1357 }
    1358 
    1359 ///**
    1360 //   \brief Create a full adder with inputs \c in_1, \c in_2 and \c cin.
    1361 //   The output of the full adder is stored in \c out, and the carry in \c c_out.
    1362 //*/
    1363 //inline std::pair<Z3_ast, Z3_ast> mk_full_adder(Z3_context ctx, Z3_ast in_1, Z3_ast in_2, Z3_ast cin) {
    1364 //    Z3_ast out = Z3_mk_xor(ctx, Z3_mk_xor(ctx, in_1, in_2), cin);
    1365 //    Z3_ast cout = mk_ternary_or(ctx, mk_binary_and(ctx, in_1, in_2), mk_binary_and(ctx, in_1, cin), mk_binary_and(ctx, in_2, cin));
    1366 //    return std::make_pair(out, cout);
    1367 //}
    1368 
    1369 /**
    1370    \brief Create an adder for inputs of size \c num_bits.
    1371    The arguments \c in1 and \c in2 are arrays of bits of size \c num_bits.
    1372 
    1373    \remark \c result must be an array of size \c num_bits + 1.
    1374 */
    1375 void mk_adder(Z3_context ctx, const unsigned num_bits, Z3_ast * in_1, Z3_ast * in_2, Z3_ast * result) {
    1376     Z3_ast cin = Z3_mk_false(ctx);
    1377     for (unsigned i = 0; i < num_bits; i++) {
    1378         result[i] = Z3_mk_xor(ctx, Z3_mk_xor(ctx, in_1[i], in_2[i]), cin);
    1379         cin = mk_ternary_or(ctx, mk_binary_and(ctx, in_1[i], in_2[i]), mk_binary_and(ctx, in_1[i], cin), mk_binary_and(ctx, in_2[i], cin));
    1380     }
    1381     result[num_bits] = cin;
    1382 }
    1383 
    1384 /**
    1385    \brief Given \c num_ins "numbers" of size \c num_bits stored in \c in.
    1386    Create floor(num_ins/2) adder circuits. Each circuit is adding two consecutive "numbers".
    1387    The numbers are stored one after the next in the array \c in.
    1388    That is, the array \c in has size num_bits * num_ins.
    1389    Return an array of bits containing \c ceil(num_ins/2) numbers of size \c (num_bits + 1).
    1390    If num_ins/2 is not an integer, then the last "number" in the output, is the last "number" in \c in with an appended "zero".
    1391 */
    1392 unsigned mk_adder_pairs(Z3_context ctx, const unsigned num_bits, const unsigned num_ins, Z3_ast * in, Z3_ast * out) {
    1393     unsigned out_num_bits = num_bits + 1;
    1394     Z3_ast * _in          = in;
    1395     Z3_ast * _out         = out;
    1396     unsigned out_num_ins  = (num_ins % 2 == 0) ? (num_ins / 2) : (num_ins / 2) + 1;
    1397     for (unsigned i = 0; i < num_ins / 2; i++) {
    1398         mk_adder(ctx, num_bits, _in, _in + num_bits, _out);
    1399         _in  += num_bits;
    1400         _in  += num_bits;
    1401         _out += out_num_bits;
    1402     }
    1403     if (num_ins % 2 != 0) {
    1404         for (unsigned i = 0; i < num_bits; i++) {
    1405             _out[i] = _in[i];
    1406         }
    1407         _out[num_bits] = Z3_mk_false(ctx);
    1408     }
    1409     return out_num_ins;
    1410 }
    1411 
    1412 /**
    1413    \brief Return the \c idx bit of \c val.
    1414 */
    1415 inline bool get_bit(unsigned val, unsigned idx) {
    1416     return (val & (1U << (idx & 31))) != 0;
    1417 }
    1418 
    1419 /**
    1420    \brief Given an integer val encoded in n bits (boolean variables), assert the constraint that val <= k.
    1421 */
    1422 void assert_le_one(Z3_context ctx, Z3_solver s, unsigned n, Z3_ast * val)
    1423 {
    1424     Z3_ast i1, i2;
    1425     Z3_ast not_val = Z3_mk_not(ctx, val[0]);
    1426     assert (get_bit(1, 0));
    1427     Z3_ast out = Z3_mk_true(ctx);
    1428     for (unsigned i = 1; i < n; i++) {
    1429         not_val = Z3_mk_not(ctx, val[i]);
    1430         if (get_bit(1, i)) {
    1431             i1 = not_val;
    1432             i2 = out;
    1433         }
    1434         else {
    1435             i1 = Z3_mk_false(ctx);
    1436             i2 = Z3_mk_false(ctx);
    1437         }
    1438         out = mk_ternary_or(ctx, i1, i2, mk_binary_and(ctx, not_val, out));
    1439     }
    1440     Z3_solver_assert(ctx, s, out);
    1441 }
    1442 
    1443 /**
    1444    \brief Create a counter circuit to count the number of "ones" in lits.
    1445    The function returns an array of bits (i.e. boolean expressions) containing the output of the circuit.
    1446    The size of the array is stored in out_sz.
    1447 */
    1448 void mk_counter_circuit(Z3_context ctx, Z3_solver solver, unsigned n, Z3_ast * lits) {
    1449     unsigned k = 1;
    1450     assert (n != 0);
    1451     Z3_ast aux_array_1[n + 1];
    1452     Z3_ast aux_array_2[n + 1];
    1453     Z3_ast * aux_1 = aux_array_1;
    1454     Z3_ast * aux_2 = aux_array_2;
    1455     std::memcpy(aux_1, lits, sizeof(Z3_ast) * n);
    1456     while (n > 1) {
    1457         assert (aux_1 != aux_2);
    1458         n = mk_adder_pairs(ctx, k++, n, aux_1, aux_2);
    1459         std::swap(aux_1, aux_2);
    1460     }
    1461     assert_le_one(ctx, solver, k, aux_1);
    1462 }
    1463 
    1464 /** ------------------------------------------------------------------------------------------------------------- *
    1465  * Fu & Malik procedure for MaxSAT. This procedure is based on unsat core extraction and the at-most-one constraint.
    1466  ** ------------------------------------------------------------------------------------------------------------- */
    1467 Z3_bool maxsat(Z3_context ctx, Z3_solver solver, std::vector<Z3_ast> & soft) {
    1468     if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
    1469         return Z3_L_FALSE;
    1470     }
    1471     if (LLVM_UNLIKELY(soft.empty())) {
    1472         return true;
    1473     }
    1474 
    1475     const auto n = soft.size();
    1476     const auto ty = Z3_mk_bool_sort(ctx);
    1477     Z3_ast aux_vars[n];
    1478     Z3_ast assumptions[n];
    1479 
    1480     for (unsigned i = 0; i < n; ++i) {
    1481         aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, ty);
    1482         Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], aux_vars[i]));
    1483     }
    1484 
    1485     for (;;) {
    1486         // create assumptions
    1487         for (unsigned i = 0; i < n; i++) {
    1488             // Recall that we asserted (soft_cnstrs[i] \/ aux_vars[i])
    1489             // So using (NOT aux_vars[i]) as an assumption we are actually forcing the soft_cnstrs[i] to be considered.
    1490             assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
    1491         }
    1492         if (Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE) {
    1493             return Z3_L_TRUE; // done
    1494         } else {
    1495             Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver);
    1496             unsigned m = Z3_ast_vector_size(ctx, core);
    1497             Z3_ast block_vars[m];
    1498             unsigned k = 0;
    1499             // update soft-constraints and aux_vars
    1500             for (unsigned i = 0; i < n; i++) {
    1501                 // check whether assumption[i] is in the core or not
    1502                 for (unsigned j = 0; j < m; j++) {
    1503                     if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
    1504                         // assumption[i] is in the unsat core... so soft_cnstrs[i] is in the unsat core
    1505                         Z3_ast block_var = Z3_mk_fresh_const(ctx, nullptr, ty);
    1506                         Z3_ast new_aux_var = Z3_mk_fresh_const(ctx, nullptr, ty);
    1507                         soft[i] = mk_binary_or(ctx, soft[i], block_var);
    1508                         aux_vars[i] = new_aux_var;
    1509                         block_vars[k] = block_var;
    1510                         ++k;
    1511                         // Add new constraint containing the block variable.
    1512                         // Note that we are using the new auxiliary variable to be able to use it as an assumption.
    1513                         Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], new_aux_var) );
    1514                         break;
    1515                     }
    1516                 }
    1517 
    1518             }
    1519             if (k > 1) {
    1520                 mk_counter_circuit(ctx, solver, k, block_vars);
    1521             }
    1522         }
    1523     }
    1524     llvm_unreachable("unreachable");
    1525     return Z3_L_FALSE;
    1526 }
    1527 
    15281020} // end of namespace pablo
Note: See TracChangeset for help on using the changeset viewer.