source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 5493

Last change on this file since 5493 was 5493, checked in by cameron, 2 years ago

Restore check-ins from the last several days

File size: 33.5 KB
RevLine 
[4878]1#include "distributivepass.h"
2
[5464]3#include <pablo/pablo_kernel.h>
[4878]4#include <pablo/codegenstate.h>
[5464]5#include <pablo/branch.h>
6#include <pablo/pe_string.h>
7#include <pablo/pe_integer.h>
8#include <pablo/pe_zeroes.h>
9#include <pablo/pe_ones.h>
10#include <pablo/boolean.h>
[5486]11#include <pablo/pe_var.h>
[4880]12#include <boost/container/flat_set.hpp>
[4878]13#include <boost/container/flat_map.hpp>
[5486]14#include <boost/range/adaptor/reversed.hpp>
[4878]15#include <boost/graph/adjacency_list.hpp>
[5464]16#include <boost/graph/topological_sort.hpp>
17#include <boost/function_output_iterator.hpp>
[5486]18#include <set>
[4878]19
[5464]20#include <boost/graph/strong_components.hpp>
21#include <llvm/Support/raw_ostream.h>
[4922]22
[4878]23using namespace boost;
24using namespace boost::container;
[5464]25using namespace llvm;
[4878]26
[5464]27using TypeId = pablo::PabloAST::ClassTypeId;
28using VertexData = std::pair<pablo::PabloAST *, TypeId>;
[4878]29
[5486]30using Graph = adjacency_list<vecS, vecS, bidirectionalS, VertexData, pablo::PabloAST *>;
[4922]31using Vertex = Graph::vertex_descriptor;
[5486]32using in_edge_iterator = graph_traits<Graph>::in_edge_iterator;
33using out_edge_iterator = graph_traits<Graph>::out_edge_iterator;
[5464]34
[4880]35using VertexSet = std::vector<Vertex>;
[4878]36using Biclique = std::pair<VertexSet, VertexSet>;
37using BicliqueSet = std::vector<Biclique>;
38using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
39using DistributionSets = std::vector<DistributionSet>;
40
[5464]41using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
[4878]42
[5464]43namespace pablo {
44
45
[4878]46/** ------------------------------------------------------------------------------------------------------------- *
[5464]47 * @brief printGraph
[4878]48 ** ------------------------------------------------------------------------------------------------------------- */
[5464]49static void printGraph(const Graph & G, const std::string & name, llvm::raw_ostream & out) {
[4878]50
[5464]51    std::vector<unsigned> c(num_vertices(G));
52    strong_components(G, make_iterator_property_map(c.begin(), get(vertex_index, G), c[0]));
[4878]53
[5464]54    out << "digraph " << name << " {\n";
55    for (auto u : make_iterator_range(vertices(G))) {
56        if (in_degree(u, G) == 0 && out_degree(u, G) == 0) {
57            continue;
[4878]58        }
[5464]59        out << "v" << u << " [label=\"" << u << ": ";
60        TypeId typeId;
61        PabloAST * expr;
62        std::tie(expr, typeId) = G[u];
63        bool temporary = false;
64        bool error = false;
65        if (expr == nullptr || (typeId != expr->getClassTypeId() && typeId != TypeId::Var)) {
66            temporary = true;
67            switch (typeId) {
68                case TypeId::And:
69                    out << "And";
70                    break;
71                case TypeId::Or:
72                    out << "Or";
73                    break;
74                case TypeId::Xor:
75                    out << "Xor";
76                    break;
77                case TypeId::Not:
78                    out << "Not";
79                    break;
80                default:
81                    out << "???";
82                    error = true;
83                    break;
[4922]84            }
[5464]85            if (expr) {
86                out << " ("; expr->print(out); out << ")";
87            }
88        } else {
89            expr->print(out);
[4922]90        }
[5464]91        out << "\"";
92        if (typeId == TypeId::Var) {
93            out << " style=dashed";
94        }
95        if (error) {
96            out << " color=red";
97        } else if (temporary) {
98            out << " color=blue";
99        }
100        out << "];\n";
[4878]101    }
[5464]102    for (auto e : make_iterator_range(edges(G))) {
103        const auto s = source(e, G);
104        const auto t = target(e, G);
105        out << "v" << s << " -> v" << t;
106        bool cyclic = (c[s] == c[t]);
107        if (G[e] || cyclic) {
108            out << " [";
109            PabloAST * const expr = G[e];
110            if (expr) {
111                out << "label=\"";
112                expr->print(out);
113                out << "\" ";
114             }
115             if (cyclic) {
116                out << "color=red ";
117             }
118             out << "]";
119        }
120        out << ";\n";
121    }
[4878]122
[5464]123    if (num_vertices(G) > 0) {
[4919]124
[5464]125        out << "{ rank=same;";
126        for (auto u : make_iterator_range(vertices(G))) {
127            if (in_degree(u, G) == 0 && out_degree(u, G) != 0) {
128                out << " v" << u;
[4878]129            }
130        }
[5464]131        out << "}\n";
132
133        out << "{ rank=same;";
134        for (auto u : make_iterator_range(vertices(G))) {
135            if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
136                out << " v" << u;
137            }
[4878]138        }
[5464]139        out << "}\n";
[4878]140
141    }
142
[5464]143    out << "}\n\n";
144    out.flush();
[4878]145}
146
[5464]147struct PassContainer {
[4880]148
[5486]149    enum Modification {
150        None, Trivial, Structural
151    };
152
[5464]153    /** ------------------------------------------------------------------------------------------------------------- *
154     * @brief run
155     *
156     * Based on the knowledge that:
157     *
158     *  (ASSOCIATIVITY)    A ∧ (B ∧ C) ⇔ (A ∧ B) ∧ C ⇔ A ∧ B ∧ C   and   A √ (B √ C) ⇔ (A √ B) √ C ⇔ A √ B √ C
159     *
160     *  (IDENTITY)         A √ 0 ⇔ A   and   A ∧ 1 = A
161     *
162     *  (ANNULMENT)        A ∧ 0 ⇔ 0   and   A √ 1 = 1
163     *
164     *  (IDEMPOTENT)       A √ (A √ B) ⇔ A √ B   and   A ∧ (A ∧ B) ⇔ A ∧ B
165     *
166     *  (ABSORBTION)       A √ (A ∧ B) ⇔ A ∧ (A √ B) ⇔ A
167     *
168     *  (COMPLEMENT)       A √ ¬A ⇔ 1   and   A ∧ ¬A = 0
169     *
170     *  (DISTRIBUTIVITY)   (A ∧ B) √ (A ∧ C) ⇔ A ∧ (B √ C)   and   (A √ B) ∧ (A √ C) ⇔ A √ (B ∧ C)
171     *
172     * Try to eliminate some of the unnecessary operations from the current Variadic expressions
173     ** ------------------------------------------------------------------------------------------------------------- */
[5486]174    void run(PabloKernel * const kernel) {
175        run(kernel->getEntryBlock());
[5464]176
[5486]177        printGraph(G, "G", errs());
178        if (simplifyGraph() == Structural) {
179            // rewriteAST(first, stmt);
180            printGraph(G, "O", errs());
181        }
[5464]182
[5486]183    }
[5464]184
[5486]185    void run(PabloBlock * const block) {
186        for (Statement * stmt : *block) {           
187            if (isa<Branch>(stmt)) {
188                addBranch(stmt);
[5464]189                run(cast<Branch>(stmt)->getBody());
190            } else {
191                addStatement(stmt);
[4880]192            }
193        }
194
[5486]195//        G.clear();
196//        M.clear();
197//        for (Statement * stmt : *block) {
198//            addStatement(stmt);
199//        }
[4880]200
[5486]201//        printGraph(G, "G", errs());
202//        if (simplifyGraph() == Structural) {
203//            // rewriteAST(first, stmt);
204//            printGraph(G, "O", errs());
[5464]205//        }
206
[5486]207    }
208
[5464]209    /** ------------------------------------------------------------------------------------------------------------- *
[5486]210     * @brief simplifyGraph
[5464]211     ** ------------------------------------------------------------------------------------------------------------- */
[5486]212    Modification simplifyGraph() {
213        Modification modified = None;
214        for (;;) {
215            const auto p1 = applyAssociativeIdentityAnnulmentLaws();
216            const auto p2 = applyAbsorbtionComplementIdempotentLaws();
217            const auto p3 = applyDistributivityLaw();
218            if (std::max(std::max(p1, p2), p3) != Structural) {
219                break;
220            }
221            modified = Structural;
222        }
223        return modified;
224    }
[5464]225
[5486]226protected:
[5464]227
[5486]228    /** ------------------------------------------------------------------------------------------------------------- *
229     * @brief applyAssociativeIdentityAnnulmentLaws
230     ** ------------------------------------------------------------------------------------------------------------- */
231    Modification applyAssociativeIdentityAnnulmentLaws() {
[5464]232
[5486]233        auto identityComparator = [this](const Vertex u, const Vertex v) -> bool {
234            const auto typeA = getType(u);
235            assert (typeA != TypeId::Var);
236            const auto typeB = getType(v);
237            assert (typeB != TypeId::Var);
238            if (LLVM_LIKELY(typeA != typeB)) {
239                using value_of = std::underlying_type<TypeId>::type;
240                return static_cast<value_of>(typeA) < static_cast<value_of>(typeB);
241            } else {
242                const auto degA = in_degree(u, G);
243                const auto degB = in_degree(v, G);
244                if (degA != degB) {
245                    return degA < degB;
246                } else {
247                    Vertex adjA[degA];
248                    Vertex adjB[degA];
249                    in_edge_iterator ei, ej;
250                    std::tie(ei, std::ignore) = in_edges(u, G);
251                    std::tie(ej, std::ignore) = in_edges(v, G);
252                    for (size_t i = 0; i < degA; ++i, ++ei, ++ej) {
253                        adjA[i] = source(*ei, G);
254                        adjB[i] = source(*ej, G);
255                    }
256                    std::sort(adjA, adjA + degA);
257                    std::sort(adjB, adjB + degA);
258                    for (size_t i = 0; i < degA; ++i) {
259                        if (adjA[i] < adjB[i]) {
260                            return true;
[5464]261                        }
262                    }
[5486]263                    return false;
[4880]264                }
[5486]265            }
266        };
[4880]267
[5486]268        flat_set<Vertex, decltype(identityComparator)> V(identityComparator);
269        V.reserve(num_vertices(G));
270
271        VertexSet ordering;
272        ordering.reserve(num_vertices(G));
273
274        topological_sort(G, std::back_inserter(ordering)); // note: ordering is in reverse topological order
275
276        Modification modified = None;
277
278        for (const auto u : boost::adaptors::reverse(ordering)) {
279            const TypeId typeId = getType(u);
280            if (isImmutable(typeId)) {
281                continue;
282            } else if (LLVM_UNLIKELY(typeId == TypeId::Zeroes || typeId == TypeId::Ones)) {
283                for(;;) {
284                    bool done = true;
[5464]285                    for (auto e : make_iterator_range(out_edges(u, G))) {
[5486]286                        const auto v = target(e, G);
287                        const auto targetTypeId = getType(v);
288                        if (LLVM_UNLIKELY(isAssociative(targetTypeId))) {
289
290                            errs() << " -- identity relationship\n";
291
292                            if (isIdentityRelation(typeId, targetTypeId)) {
293                                remove_edge(e, G);
294                            } else {
295                                setType(v, typeId == TypeId::And ? TypeId::Zeroes : TypeId::Ones);
296                                clear_in_edges(v, G);
297                            }
298                            done = false;
299                            modified = Structural;
[5464]300                            break;
301                        }
[4880]302                    }
[5486]303                    if (done) {
304                        break;
305                    }
[4880]306                }
[5486]307            } else if (isAssociative(typeId)) {
308                if (LLVM_UNLIKELY(in_degree(u, G) == 0)) {
309                    setType(u, TypeId::Zeroes);
310                } else {
311                    // An associative operation with only one element is always equivalent to the element
312                    bool contractable = true;
313                    if (LLVM_LIKELY(in_degree(u, G) > 1)) {
314                        for (auto e : make_iterator_range(out_edges(u, G))) {
315                            if (LLVM_LIKELY(typeId != getType(target(e, G)))) {
316                                contractable = false;
317                                break;
318                            }
[5464]319                        }
320                    }
[5486]321                    if (LLVM_UNLIKELY(contractable)) {
322                        for (auto ei : make_iterator_range(in_edges(u, G))) {
323                            for (auto ej : make_iterator_range(out_edges(u, G))) {
324                                addEdge(source(ei, G), target(ej, G), G[ei]);
325                            }
326                        }
327                        removeVertex(u);
328                        modified = std::max(modified, Trivial);
329                        continue;
330                    }
331
332                    if (LLVM_UNLIKELY(typeId == TypeId::Xor)) {
333                        // TODO:: (A ⊕ ¬B) = (A ⊕ (B ⊕ 1)) = ¬(A ⊕ B)
334
335                    }
336
337
338
[5464]339                }
[4880]340            }
[5464]341
[5486]342            assert (getType(u) != TypeId::Var);
[5464]343
[5486]344            const auto f = V.insert(u);
345            if (LLVM_UNLIKELY(!f.second)) {
346                const auto v = *f.first;
347
348                errs() << " -- replacing " << u << " with " << v << "\n";
349
350                for (auto e : make_iterator_range(out_edges(u, G))) {
351                    addEdge(v, target(e, G), G[e]);
352                }
353                removeVertex(u);
354                modified = Structural;
355            }
356        }
[5464]357        return modified;
[4880]358    }
359
[5464]360    /** ------------------------------------------------------------------------------------------------------------- *
361     * @brief applyAbsorbtionComplementIdempotentLaws
362     ** ------------------------------------------------------------------------------------------------------------- */
[5486]363    Modification applyAbsorbtionComplementIdempotentLaws() {
364        Modification modified = None;
[5464]365        for (const Vertex u : make_iterator_range(vertices(G))) {
366            const TypeId typeId = getType(u);
367            if (isDistributive(typeId)) {
368restart_loop:   in_edge_iterator ei_begin, ei_end;
369                std::tie(ei_begin, ei_end) = in_edges(u, G);
370                for (auto ei = ei_begin; ei != ei_end; ++ei) {
371                    const auto v = source(*ei, G);
372                    const auto innerTypeId = getType(v);
373                    if (isDistributive(innerTypeId) || innerTypeId == TypeId::Not) {
374                        in_edge_iterator ek_begin, ek_end;
375                        std::tie(ek_begin, ek_end) = in_edges(v, G);
376                        for (auto ej = ei_begin; ej != ei_end; ++ej) {
377                            for (auto ek = ek_begin; ek != ek_end; ++ek) {
378                                if (LLVM_UNLIKELY(source(*ej, G) == source(*ek, G))) {
[5486]379                                    modified = Structural;
[5464]380                                    if (LLVM_UNLIKELY(innerTypeId == TypeId::Not)) {
381                                        // complement
382                                        setType(u, typeId == TypeId::And ? TypeId::Zeroes : TypeId::Ones);
383                                        clear_in_edges(u, G);
384                                        goto abort_loop;
385                                    } else {
386                                        if (LLVM_LIKELY(innerTypeId != typeId)) {
387                                            // idempotent
388                                            remove_edge(*ei, G);
389                                        } else {
390                                            // absorbtion
391                                            remove_edge(*ej, G);
[5486]392                                        }                                       
[5464]393                                        // this seldom occurs so if it does, just restart the process rather than
394                                        // trying to keep the iterators valid.
395                                        goto restart_loop;
396                                    }
397                                }
398                            }
399                        }
400                    }
401                }
[4880]402            }
[5464]403            abort_loop:;
[4880]404        }
[5464]405        return modified;
[4880]406    }
407
[5464]408    /** ------------------------------------------------------------------------------------------------------------- *
[5486]409     * @brief identifyDistributableVertices
[5464]410     *
411     * Let (u) ∈ V(G) be a conjunction ∧ or disjunction √ and (v) be a ∧ or √ and the opposite type of (u). If (u,v) ∈
412     * E(G) and all outgoing edges of (v) lead to a vertex of the same type, add (u), (v) and any vertex (w) in which
413     * (w,v) ∈ E(G) to the distribution graph H as well as the edges indicating their relationships within G.
414     *
415     *                  (?) (?) (?) <-- w1, w2, ...
416     *                     \ | /
417     *                      (v)   <-- v
418     *                     /   \
419     *            u --> (∧)     (∧)
420     *
421     ** ------------------------------------------------------------------------------------------------------------- */
[5486]422    void identifyDistributableVertices() {
[4878]423
[5486]424        assert (D.empty() && L.empty());
[4878]425
[5464]426        for (const Vertex u : make_iterator_range(vertices(G))) {
427            const TypeId outerTypeId = getType(u);
428            if (isDistributive(outerTypeId)) {
[5486]429                bool beneficial = true;
[5464]430                const TypeId innerTypeId = oppositeTypeId(outerTypeId);
[5486]431                for (auto e : make_iterator_range(out_edges(u, G))) {
432                    const Vertex v = target(e, G);
433                    if (LLVM_UNLIKELY(getType(v) != innerTypeId)) {
434                        beneficial = false;
435                        break;
436                    }
437                }
438                if (beneficial) {
439                    for (auto e : make_iterator_range(out_edges(u, G))) {
440                        const auto v = target(e, G);
441                        const auto f = std::lower_bound(D.begin(), D.end(), v);
442                        if (f == D.end() || *f != v) {
443                            D.insert(f, v);
444                            assert (std::is_sorted(D.begin(), D.end()));
[5464]445                        }
[5486]446                    }
447                    for (auto e : make_iterator_range(in_edges(u, G))) {
448                        const auto v = source(e, G);
449                        const auto f = std::lower_bound(L.begin(), L.end(), v);
450                        if (f == L.end() || *f != v) {
451                            L.insert(f, v);
452                            assert (std::is_sorted(L.begin(), L.end()));
[5464]453                        }
[4922]454                    }
455                }
[5486]456            }
457        }
458
459        // D = D - L
460
461        if (!L.empty()) {
462            if (!D.empty()) {
463                auto di = D.begin(), li = L.begin(), li_end = L.end();
464                for (;;) {
465                    if (*li < *di) {
466                        if (++li == li_end) {
467                            break;
[5464]468                        }
[5486]469                    } else {
470                        if (*di < *li) {
471                            ++di;
472                        } else {
473                            di = D.erase(di);
[5464]474                        }
[5486]475                        if (di == D.end()) {
476                            break;
477                        }
[5464]478                    }
[4922]479                }
480            }
[5486]481            L.clear();
[4887]482        }
[4922]483    }
[4878]484
[5464]485    /** ------------------------------------------------------------------------------------------------------------- *
[5486]486     * @brief applyDistributivityLaw
[5464]487     ** ------------------------------------------------------------------------------------------------------------- */
[5486]488    Modification applyDistributivityLaw() {
[4922]489
[5486]490        identifyDistributableVertices();
[4927]491
[5486]492        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
493        if (D.empty()) {
494            return None;
495        }
[4922]496
[5486]497        Modification modified = None;
[4922]498
[5486]499        const auto lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(D)), 1);
[4927]500
[5486]501        for (auto & lower : lowerSet) {
502            const auto upperSet = independentCliqueSets<0>(enumerateBicliques(std::get<1>(lower)), 2);
503            for (const auto & upper : upperSet) {
[4878]504
[5486]505                const auto & sources = std::get<1>(upper);
506                const auto & intermediary = std::get<0>(upper);
507                const auto & sinks = std::get<0>(lower);
[4878]508
[4887]509
[4878]510
[5486]511                const auto outerTypeId = getType(sinks.front());
512                const auto innerTypeId = oppositeTypeId(outerTypeId);
[4922]513
[5486]514                errs() << " -- distributing\n";
[4922]515
[5486]516                // Update G to match the desired change
517                const auto x = makeVertex(outerTypeId);
518                const auto y = makeVertex(innerTypeId);
[4922]519
[5486]520                for (const auto i : intermediary) {
521                    assert (getType(i) == innerTypeId);
522                    for (const Vertex t : sinks) {
523                        assert (getType(t) == outerTypeId);
524                        remove_edge(i, t, G);
[4922]525                    }
[5486]526                    addEdge(i, x);
527                }
[5464]528
[5486]529                for (const Vertex s : sources) {
530                    for (const Vertex i : intermediary) {
531                        remove_edge(s, i, G);
[5464]532                    }
[5486]533                    addEdge(s, y);
534                }
535                addEdge(x, y);
[5464]536
[5486]537                for (const Vertex t : sinks) {
538                    addEdge(y, t, std::get<0>(G[t]));
[4887]539                }
[5464]540
[5486]541                modified = Structural;
[5464]542            }
543        }
544
[5486]545        D.clear();
[5464]546
547        return modified;
548    }
549
550
551    /** ------------------------------------------------------------------------------------------------------------- *
552     * @brief enumerateBicliques
553     *
554     * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
555     * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
556     * bipartition A and their adjacencies to be in B.
557      ** ------------------------------------------------------------------------------------------------------------- */
558
559    BicliqueSet enumerateBicliques(const VertexSet & A) {
560        using IntersectionSets = std::set<VertexSet>;
561
562        IntersectionSets B1;
563        for (auto u : A) {
[5486]564            if (in_degree(u, G) > 0) {
[5464]565                VertexSet incomingAdjacencies;
[5486]566                incomingAdjacencies.reserve(in_degree(u, G));
567                for (auto e : make_iterator_range(in_edges(u, G))) {
568                    incomingAdjacencies.push_back(source(e, G));
[5464]569                }
570                std::sort(incomingAdjacencies.begin(), incomingAdjacencies.end());
571                B1.insert(std::move(incomingAdjacencies));
572            }
573        }
574
575        IntersectionSets B(B1);
576
577        IntersectionSets Bi;
578
579        VertexSet clique;
580        for (auto i = B1.begin(); i != B1.end(); ++i) {
581            for (auto j = i; ++j != B1.end(); ) {
582                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
583                if (clique.size() > 0) {
584                    if (B.count(clique) == 0) {
585                        Bi.insert(clique);
[4922]586                    }
[5464]587                    clique.clear();
[4922]588                }
[5464]589            }
590        }
[4887]591
[5464]592        for (;;) {
593            if (Bi.empty()) {
594                break;
595            }
596            B.insert(Bi.begin(), Bi.end());
597            IntersectionSets Bk;
598            for (auto i = B1.begin(); i != B1.end(); ++i) {
599                for (auto j = Bi.begin(); j != Bi.end(); ++j) {
600                    std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
601                    if (clique.size() > 0) {
602                        if (B.count(clique) == 0) {
603                            Bk.insert(clique);
604                        }
605                        clique.clear();
606                    }
[4922]607                }
[5464]608            }
609            Bi.swap(Bk);
610        }
611
612        BicliqueSet cliques;
613        cliques.reserve(B.size());
614        for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
615            VertexSet Ai(A);
616            for (const Vertex u : *Bi) {
617                VertexSet Aj;
[5486]618                Aj.reserve(out_degree(u, G));
619                for (auto e : make_iterator_range(out_edges(u, G))) {
620                    Aj.push_back(target(e, G));
[4887]621                }
[5464]622                std::sort(Aj.begin(), Aj.end());
623                VertexSet Ak;
624                Ak.reserve(std::min(Ai.size(), Aj.size()));
625                std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
626                Ai.swap(Ak);
627            }
628            assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
629            cliques.emplace_back(std::move(Ai), std::move(*Bi));
630        }
[5493]631        return cliques;
[5464]632    }
633
634
635    /** ------------------------------------------------------------------------------------------------------------- *
636     * @brief independentCliqueSets
637     ** ------------------------------------------------------------------------------------------------------------- */
638    template <unsigned side>
639    BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
640
641
642        const auto l = cliques.size();
643        IndependentSetGraph I(l);
644
645        // Initialize our weights
646        for (unsigned i = 0; i != l; ++i) {
647            I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
648        }
649
650        // Determine our constraints
651        for (unsigned i = 0; i != l; ++i) {
652            for (unsigned j = i + 1; j != l; ++j) {
653                if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
654                    boost::add_edge(i, j, I);
[4922]655                }
[5464]656            }
657        }
658
659        // Use the greedy algorithm to choose our independent set
660        VertexSet selected;
661        for (;;) {
662            unsigned w = 0;
663            Vertex u = 0;
664            for (unsigned i = 0; i != l; ++i) {
665                if (I[i] > w) {
666                    w = I[i];
667                    u = i;
[4922]668                }
[5464]669            }
670            if (w < minimum) break;
671            selected.push_back(u);
672            I[u] = 0;
673            for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
674                I[v] = 0;
675            }
676        }
[4927]677
[5464]678        // Sort the selected list and then remove the unselected cliques
679        std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
680        auto end = cliques.end();
681        for (const unsigned offset : selected) {
682            end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
683        }
684        cliques.erase(cliques.begin(), end);
685
686        return std::move(cliques);
687    }
688
689    /** ------------------------------------------------------------------------------------------------------------- *
690     * @brief removeUnhelpfulBicliques
691     *
692     * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
693     * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
694     * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
695     ** ------------------------------------------------------------------------------------------------------------- */
696    BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques) {
697        for (auto ci = cliques.begin(); ci != cliques.end(); ) {
698            const auto cardinalityA = std::get<0>(*ci).size();
699            VertexSet & B = std::get<1>(*ci);
700            for (auto bi = B.begin(); bi != B.end(); ) {
[5486]701                if (out_degree(*bi, G) == cardinalityA) {
[5464]702                    ++bi;
703                } else {
704                    bi = B.erase(bi);
705                }
[4887]706            }
[5464]707            if (B.size() > 1) {
708                ++ci;
709            } else {
710                ci = cliques.erase(ci);
711            }
712        }
713        return std::move(cliques);
714    }
[4887]715
[5464]716    /** ------------------------------------------------------------------------------------------------------------- *
[5486]717     * @brief makeVertex
[5464]718     ** ------------------------------------------------------------------------------------------------------------- */
719    Vertex makeVertex(const TypeId typeId, PabloAST * const expr = nullptr) {
720        return add_vertex(std::make_pair(expr, typeId), G);
721    }
722
723    /** ------------------------------------------------------------------------------------------------------------- *
724     * @brief addExpression
725     ** ------------------------------------------------------------------------------------------------------------- */
726    Vertex addExpression(PabloAST * const expr) {
727        const auto f = M.find(expr);
728        if (LLVM_LIKELY(f != M.end())) {
729            return f->second;
[4880]730        }
[5464]731        TypeId typeId = TypeId::Var;
732        if (isa<Zeroes>(expr)) {
733            typeId = TypeId::Zeroes;
734        } else if (isa<Ones>(expr)) {
735            typeId = TypeId::Ones;
736        }
737        const auto u = makeVertex(typeId, expr);
738        M.emplace(expr, u);
[5486]739        if (LLVM_UNLIKELY(isa<Not>(expr))) {
740            PabloAST * const negated = cast<Not>(expr)->getExpr();
741            addEdge(addExpression(negated), u, negated);
742        }
[5464]743        return u;
[4880]744    }
[4927]745
[5464]746    /** ------------------------------------------------------------------------------------------------------------- *
747     * @brief addStatement
748     ** ------------------------------------------------------------------------------------------------------------- */
[5486]749    Vertex addStatement(Statement * const stmt) {
[5464]750        assert (M.count(stmt) == 0);
751        const auto typeId = stmt->getClassTypeId();
[5486]752        if (LLVM_UNLIKELY(typeId == TypeId::Sel)) {
753
754            // expand Sel (C,T,F) statements into (C ∧ T) √ (C ∧ F)
755
756            const auto c = addExpression(cast<Sel>(stmt)->getCondition());
757            const auto t = addExpression(cast<Sel>(stmt)->getTrueExpr());
758            const auto l = makeVertex(TypeId::And);
759            addEdge(c, l);
760            addEdge(t, l);
761            const auto n = makeVertex(TypeId::Not);
762            addEdge(c, n);
763            const auto r = makeVertex(TypeId::And);
764            const auto f = addExpression(cast<Sel>(stmt)->getFalseExpr());
765            addEdge(n, r);
766            addEdge(f, r);
767            const auto u = makeVertex(TypeId::Or, stmt);
768            M.emplace(stmt, u);
769            addEdge(l, u);
770            addEdge(r, u);
771
772            return u;
773
774        } else {
775
776            const auto u = makeVertex(typeId, stmt);
777            M.emplace(stmt, u);
778            for (unsigned i = 0; i != stmt->getNumOperands(); ++i) {
779                PabloAST * const op = stmt->getOperand(i);
780                if (LLVM_UNLIKELY(isa<String>(op))) {
781                    continue;
782                }
783                addEdge(addExpression(op), u, op);
[5464]784            }
[5486]785
786            return u;
787        }
788
789    }
790
791    /** ------------------------------------------------------------------------------------------------------------- *
792     * @brief addBranch
793     ** ------------------------------------------------------------------------------------------------------------- */
794    void addBranch(Statement * const br) {
795        const auto u = addStatement(br);
796        for (auto escaped : cast<Branch>(br)->getEscaped()) {
797            addEdge(u, addExpression(escaped), escaped);
798        }
799    }
800
801
802    /** ------------------------------------------------------------------------------------------------------------- *
803     * @brief addEdge
804     ** ------------------------------------------------------------------------------------------------------------- */
805    void addEdge(const Vertex u, const Vertex v, PabloAST * const value = nullptr) {
806        const auto typeId = getType(v);
807        if (isAssociative(typeId)) {
808            for (auto e : make_iterator_range(in_edges(u, G))) {
809                if (LLVM_UNLIKELY(source(e, G) == u)) {
810                    if (LLVM_LIKELY(isDistributive(typeId))) {
811                        G[e] = std::max(G[e], value);
812                    } else {
813                        remove_edge(e, G);
814                    }
815                    return;
816                }
[5464]817            }
818        }
[5486]819        boost::add_edge(u, v, value, G);
[5464]820    }
[4878]821
[5464]822    /** ------------------------------------------------------------------------------------------------------------- *
[5486]823     * @brief removeVertex
824     ** ------------------------------------------------------------------------------------------------------------- */
825    void removeVertex(const Vertex u) {
826        clear_vertex(u, G);
827        setType(u, TypeId::Var);
828    }
829
830    /** ------------------------------------------------------------------------------------------------------------- *
[5464]831     * @brief intersects
832     ** ------------------------------------------------------------------------------------------------------------- */
833    template <class Type>
834    inline bool intersects(Type & A, Type & B) {
835        auto first1 = A.begin(), last1 = A.end();
836        auto first2 = B.begin(), last2 = B.end();
837        while (first1 != last1 && first2 != last2) {
838            if (*first1 < *first2) {
839                ++first1;
840            } else if (*first2 < *first1) {
841                ++first2;
842            } else {
843                return true;
844            }
[4878]845        }
[5464]846        return false;
[4878]847    }
848
[5464]849    TypeId getType(const Vertex u) {
850        return std::get<1>(G[u]);
851    }
852
853    void setType(const Vertex u, const TypeId typeId) {
854        std::get<1>(G[u]) = typeId;
855    }
856
857    static bool isIdentityRelation(const TypeId a, const TypeId b) {
858        return !((a == TypeId::Zeroes) ^ (b == TypeId::Or));
859    }
860
861    static bool isAssociative(const TypeId typeId) {
862        return (isDistributive(typeId) || typeId == TypeId::Xor);
863    }
864
865    static bool isDistributive(const TypeId typeId) {
866        return (typeId == TypeId::And || typeId == TypeId::Or);
867    }
868
[5486]869    static bool isImmutable(const TypeId typeId) {
870        return (typeId == TypeId::Var || typeId == TypeId::Assign || typeId == TypeId::Extract);
871    }
872
[5464]873    static TypeId oppositeTypeId(const TypeId typeId) {
874        assert (isDistributive(typeId));
875        return (typeId == TypeId::And) ? TypeId::Or : TypeId::And;
876    }
877
878private:
879
880    Graph G;
881    flat_map<pablo::PabloAST *, Vertex> M;
[5486]882    VertexSet D;
883    VertexSet L;
[5464]884
885};
886
[4880]887/** ------------------------------------------------------------------------------------------------------------- *
[4887]888 * @brief optimize
[4880]889 ** ------------------------------------------------------------------------------------------------------------- */
[5464]890bool DistributivePass::optimize(PabloKernel * const kernel) {
[5486]891    #ifdef NDEBUG
892    report_fatal_error("DistributivePass is unsupported");
893    #endif
[5464]894    PassContainer C;
[5486]895    C.run(kernel);
[5464]896    return true;
[4880]897}
[4878]898
899}
Note: See TracBrowser for help on using the repository browser.