source: icGREP/icgrep-devel/icgrep/pablo/optimizers/maxsat.hpp @ 5350

Last change on this file since 5350 was 5156, checked in by nmedfort, 3 years ago

Work on multiplexing and distribution passes + a few AST modification bug fixes.

File size: 6.5 KB
Line 
1#ifndef MAXSAT_HPP
2#define MAXSAT_HPP
3
4#include <vector>
5#include <z3.h>
6
7inline Z3_ast mk_binary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
8    Z3_ast args[2] = { in_1, in_2 };
9    return Z3_mk_or(ctx, 2, args);
10}
11
12inline Z3_ast mk_ternary_or(Z3_context ctx, Z3_ast in_1, Z3_ast in_2, Z3_ast in_3) {
13    Z3_ast args[3] = { in_1, in_2, in_3 };
14    return Z3_mk_or(ctx, 3, args);
15}
16
17inline Z3_ast mk_binary_and(Z3_context ctx, Z3_ast in_1, Z3_ast in_2) {
18    Z3_ast args[2] = { in_1, in_2 };
19    return Z3_mk_and(ctx, 2, args);
20}
21
22/**
23   \brief Create an adder for inputs of size \c num_bits.
24   The arguments \c in1 and \c in2 are arrays of bits of size \c num_bits.
25
26   \remark \c result must be an array of size \c num_bits + 1.
27*/
28inline void mk_adder(Z3_context ctx, const unsigned num_bits, Z3_ast * in_1, Z3_ast * in_2, Z3_ast * result) {
29    Z3_ast cin = Z3_mk_false(ctx);
30    for (unsigned i = 0; i < num_bits; i++) {
31        result[i] = Z3_mk_xor(ctx, Z3_mk_xor(ctx, in_1[i], in_2[i]), cin);
32        cin = mk_ternary_or(ctx, mk_binary_and(ctx, in_1[i], in_2[i]), mk_binary_and(ctx, in_1[i], cin), mk_binary_and(ctx, in_2[i], cin));
33    }
34    result[num_bits] = cin;
35}
36
37/**
38   \brief Given \c num_ins "numbers" of size \c num_bits stored in \c in.
39   Create floor(num_ins/2) adder circuits. Each circuit is adding two consecutive "numbers".
40   The numbers are stored one after the next in the array \c in.
41   That is, the array \c in has size num_bits * num_ins.
42   Return an array of bits containing \c ceil(num_ins/2) numbers of size \c (num_bits + 1).
43   If num_ins/2 is not an integer, then the last "number" in the output, is the last "number" in \c in with an appended "zero".
44*/
45inline unsigned mk_adder_pairs(Z3_context ctx, const unsigned num_bits, const unsigned num_ins, Z3_ast * in, Z3_ast * out) {
46    unsigned out_num_bits = num_bits + 1;
47    Z3_ast * _in          = in;
48    Z3_ast * _out         = out;
49    unsigned out_num_ins  = (num_ins % 2 == 0) ? (num_ins / 2) : (num_ins / 2) + 1;
50    for (unsigned i = 0; i < num_ins / 2; i++) {
51        mk_adder(ctx, num_bits, _in, _in + num_bits, _out);
52        _in  += num_bits;
53        _in  += num_bits;
54        _out += out_num_bits;
55    }
56    if (num_ins % 2 != 0) {
57        for (unsigned i = 0; i < num_bits; i++) {
58            _out[i] = _in[i];
59        }
60        _out[num_bits] = Z3_mk_false(ctx);
61    }
62    return out_num_ins;
63}
64
65/**
66   \brief Return the \c idx bit of \c val.
67*/
68inline bool get_bit(unsigned val, unsigned idx) {
69    return (val & (1U << (idx & 31))) != 0;
70}
71
72/**
73   \brief Given an integer val encoded in n bits (boolean variables), assert the constraint that val <= k.
74*/
75inline void assert_le_one(Z3_context ctx, Z3_solver s, unsigned n, Z3_ast * val) {
76    Z3_ast i1, i2;
77    Z3_ast not_val = Z3_mk_not(ctx, val[0]);
78    assert (get_bit(1, 0));
79    Z3_ast out = Z3_mk_true(ctx);
80    for (unsigned i = 1; i < n; i++) {
81        not_val = Z3_mk_not(ctx, val[i]);
82        if (get_bit(1, i)) {
83            i1 = not_val;
84            i2 = out;
85        } else {
86            i1 = Z3_mk_false(ctx);
87            i2 = Z3_mk_false(ctx);
88        }
89        out = mk_ternary_or(ctx, i1, i2, mk_binary_and(ctx, not_val, out));
90    }
91    // Z3_mk_atmost ?
92    Z3_solver_assert(ctx, s, out);
93}
94
95/** ------------------------------------------------------------------------------------------------------------- *
96 * Fu & Malik procedure for MaxSAT. This procedure is based on unsat core extraction and the at-most-one constraint.
97 ** ------------------------------------------------------------------------------------------------------------- */
98static int maxsat(Z3_context ctx, Z3_solver solver, std::vector<Z3_ast> & soft) {
99    if (LLVM_UNLIKELY(Z3_solver_check(ctx, solver) == Z3_L_FALSE)) {
100        return -1;
101    }
102    if (LLVM_UNLIKELY(soft.empty())) {
103        return 0;
104    }
105    const auto n = soft.size();
106    const auto ty = Z3_mk_bool_sort(ctx);
107    Z3_ast aux_vars[n];
108    Z3_ast assumptions[n];
109
110    for (unsigned i = 0; i < n; ++i) {
111        aux_vars[i] = Z3_mk_fresh_const(ctx, nullptr, ty);
112        Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], aux_vars[i]));
113    }
114
115    for (unsigned c = n; c; --c) {
116        // create assumptions
117        for (unsigned i = 0; i < n; i++) {
118            // Recall that we asserted (soft_cnstrs[i] \/ aux_vars[i])
119            // So using (NOT aux_vars[i]) as an assumption we are actually forcing the soft_cnstrs[i] to be considered.
120            assumptions[i] = Z3_mk_not(ctx, aux_vars[i]);
121        }
122        if (Z3_solver_check_assumptions(ctx, solver, n, assumptions) != Z3_L_FALSE) {
123            return c; // done
124        } else {
125            Z3_ast_vector core = Z3_solver_get_unsat_core(ctx, solver);
126            unsigned m = Z3_ast_vector_size(ctx, core);
127            Z3_ast block_vars[m];
128            unsigned k = 0;
129            // update soft-constraints and aux_vars
130            for (unsigned i = 0; i < n; i++) {
131                // check whether assumption[i] is in the core or not
132                for (unsigned j = 0; j < m; j++) {
133                    if (assumptions[i] == Z3_ast_vector_get(ctx, core, j)) {
134                        // assumption[i] is in the unsat core... so soft_cnstrs[i] is in the unsat core
135                        Z3_ast block_var = Z3_mk_fresh_const(ctx, nullptr, ty);
136                        Z3_ast new_aux_var = Z3_mk_fresh_const(ctx, nullptr, ty);
137                        soft[i] = mk_binary_or(ctx, soft[i], block_var);
138                        aux_vars[i] = new_aux_var;
139                        block_vars[k] = block_var;
140                        ++k;
141                        // Add new constraint containing the block variable.
142                        // Note that we are using the new auxiliary variable to be able to use it as an assumption.
143                        Z3_solver_assert(ctx, solver, mk_binary_or(ctx, soft[i], new_aux_var) );
144                        break;
145                    }
146                }
147
148            }
149            if (k > 1) {
150                Z3_ast aux_array_1[k + 1];
151                Z3_ast aux_array_2[k + 1];
152                Z3_ast * aux_1 = aux_array_1;
153                Z3_ast * aux_2 = aux_array_2;
154                std::memcpy(aux_1, block_vars, sizeof(Z3_ast) * k);
155                unsigned i = 1;
156                for (; k > 1; ++i) {
157                    assert (aux_1 != aux_2);
158                    k = mk_adder_pairs(ctx, i, k, aux_1, aux_2);
159                    std::swap(aux_1, aux_2);
160                }
161                assert_le_one(ctx, solver, i, aux_1);
162            }
163        }
164    }
165    llvm_unreachable("unreachable");
166    return -1;
167}
168
169#endif // MAXSAT_HPP
Note: See TracBrowser for help on using the repository browser.