source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 4922

Last change on this file since 4922 was 4922, checked in by nmedfort, 4 years ago

Incorporated a few common case boolean optimizations in the Simplifier.

File size: 19.0 KB
RevLine 
[4878]1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
[4880]4#include <pablo/analysis/pabloverifier.hpp>
5#include <pablo/optimizers/pablo_simplifier.hpp>
[4887]6#include <pablo/passes/flattenassociativedfg.h>
[4880]7#include <boost/container/flat_set.hpp>
[4878]8#include <boost/container/flat_map.hpp>
9#include <boost/graph/adjacency_list.hpp>
10
[4922]11#include <pablo/printer_pablos.h>
12#include <iostream>
13
[4878]14using namespace boost;
15using namespace boost::container;
16
17namespace pablo {
18
[4922]19using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
20using Vertex = Graph::vertex_descriptor;
21using Map = flat_map<PabloAST *, Vertex>;
[4880]22using VertexSet = std::vector<Vertex>;
[4878]23using Biclique = std::pair<VertexSet, VertexSet>;
24using BicliqueSet = std::vector<Biclique>;
25using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
26using DistributionSets = std::vector<DistributionSet>;
27using TypeId = PabloAST::ClassTypeId;
28
29/** ------------------------------------------------------------------------------------------------------------- *
30 * @brief intersects
31 ** ------------------------------------------------------------------------------------------------------------- */
32template <class Type>
33inline bool intersects(const Type & A, const Type & B) {
34    auto first1 = A.begin(), last1 = A.end();
35    auto first2 = B.begin(), last2 = B.end();
36    while (first1 != last1 && first2 != last2) {
37        if (*first1 < *first2) {
38            ++first1;
39        } else if (*first2 < *first1) {
40            ++first2;
41        } else {
42            return true;
43        }
44    }
45    return false;
46}
47
48/** ------------------------------------------------------------------------------------------------------------- *
49 * @brief independentCliqueSets
50 ** ------------------------------------------------------------------------------------------------------------- */
[4922]51static BicliqueSet && independentCliqueSets(BicliqueSet && bicliques, const bool uppsetSet) {
[4878]52    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
53
[4919]54    const auto l = bicliques.size();
[4878]55    IndependentSetGraph I(l);
56
[4922]57
[4919]58    // Initialize our weights and determine the constraints
[4922]59    if (uppsetSet) {
60        for (auto i = bicliques.begin(); i != bicliques.end(); ++i) {
61            I[std::distance(bicliques.begin(), i)] = std::pow(std::get<0>(*i).size(), 2);
62            for (auto j = i; ++j != bicliques.end(); ) {
63                if (intersects(i->first, j->first)) {
64                    add_edge(std::distance(bicliques.begin(), i), std::distance(bicliques.begin(), j), I);
65                }
[4878]66            }
67        }
[4922]68    } else {
69        for (auto i = bicliques.begin(); i != bicliques.end(); ++i) {
70            I[std::distance(bicliques.begin(), i)] = std::pow(std::get<1>(*i).size(), 2);
71            for (auto j = i; ++j != bicliques.end(); ) {
72                if (intersects(i->first, j->first) && intersects(i->second, j->second)) {
73                    add_edge(std::distance(bicliques.begin(), i), std::distance(bicliques.begin(), j), I);
74                }
75            }
76        }
[4878]77    }
78
[4919]79
[4878]80    // Use the greedy algorithm to choose our independent set
81    VertexSet selected;
82    for (;;) {
83        unsigned w = 0;
84        Vertex u = 0;
85        for (unsigned i = 0; i != l; ++i) {
86            if (I[i] > w) {
87                w = I[i];
88                u = i;
89            }
90        }
[4922]91        if (w < (uppsetSet ? 2 : 1)) break;
[4878]92        selected.push_back(u);
93        I[u] = 0;
94        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
95            I[v] = 0;
96        }
97    }
98
99    // Sort the selected list and then remove the unselected cliques
100    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
[4919]101    auto end = bicliques.end();
[4878]102    for (const unsigned offset : selected) {
[4919]103        end = bicliques.erase(bicliques.begin() + offset + 1, end) - 1;
[4878]104    }
[4919]105    bicliques.erase(bicliques.begin(), end);
[4878]106
[4919]107    return std::move(bicliques);
[4878]108}
109
110/** ------------------------------------------------------------------------------------------------------------- *
[4880]111 * @brief enumerateBicliques
112 *
113 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
114 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
115 * bipartition A and their adjacencies to be in B.
116  ** ------------------------------------------------------------------------------------------------------------- */
[4922]117static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A, const unsigned min) {
[4880]118    using IntersectionSets = std::set<VertexSet>;
119
120    IntersectionSets B1;
[4922]121    VertexSet tmp;
[4880]122    for (auto u : A) {
123        if (in_degree(u, G) > 0) {
[4922]124            tmp.reserve(in_degree(u, G));
[4880]125            for (auto e : make_iterator_range(in_edges(u, G))) {
[4922]126                tmp.push_back(source(e, G));
[4880]127            }
[4922]128            if (tmp.size() >= min) {
129                std::sort(tmp.begin(), tmp.end());
130                B1.emplace(tmp.begin(), tmp.end());
131            }
132            tmp.clear();
[4880]133        }
134    }
135
136    IntersectionSets B(B1);
137
138    IntersectionSets Bi;
139    for (auto i = B1.begin(); i != B1.end(); ++i) {
140        for (auto j = i; ++j != B1.end(); ) {
[4922]141            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(tmp));
142            if (tmp.size() >= min) {
143                if (B.count(tmp) == 0) {
144                    Bi.emplace(tmp.begin(), tmp.end());
[4880]145                }
146            }
[4922]147            tmp.clear();
[4880]148        }
149    }
150
151    for (;;) {
152        if (Bi.empty()) {
153            break;
154        }
155        B.insert(Bi.begin(), Bi.end());
156        IntersectionSets Bk;
157        for (auto i = B1.begin(); i != B1.end(); ++i) {
158            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
[4922]159                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(tmp));
160                if (tmp.size() >= min) {
161                    if (B.count(tmp) == 0) {
162                        Bk.emplace(tmp.begin(), tmp.end());
[4880]163                    }
164                }
[4922]165                tmp.clear();
[4880]166            }
167        }
168        Bi.swap(Bk);
169    }
170
171    BicliqueSet cliques;
172    cliques.reserve(B.size());
173    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
[4922]174        assert (Bi->size() >= min);
[4880]175        VertexSet Ai(A);
176        for (const Vertex u : *Bi) {
177            VertexSet Aj;
178            Aj.reserve(out_degree(u, G));
179            for (auto e : make_iterator_range(out_edges(u, G))) {
180                Aj.push_back(target(e, G));
181            }
182            std::sort(Aj.begin(), Aj.end());
183            VertexSet Ak;
184            Ak.reserve(std::min(Ai.size(), Aj.size()));
185            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
186            Ai.swap(Ak);
187        }
188        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
189        cliques.emplace_back(std::move(Ai), std::move(*Bi));
190    }
191    return std::move(cliques);
192}
193
194/** ------------------------------------------------------------------------------------------------------------- *
[4878]195 * @brief removeUnhelpfulBicliques
196 *
197 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
198 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
199 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
200 ** ------------------------------------------------------------------------------------------------------------- */
[4922]201static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G) {
[4878]202    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
203        const auto cardinalityA = std::get<0>(*ci).size();
204        VertexSet & B = std::get<1>(*ci);
205        for (auto bi = B.begin(); bi != B.end(); ) {
[4896]206            if (G[*bi]->getNumUses() == cardinalityA) {
[4878]207                ++bi;
208            } else {
209                bi = B.erase(bi);
210            }
211        }
212        if (B.size() > 1) {
213            ++ci;
214        } else {
215            ci = cliques.erase(ci);
216        }
217    }
218    return std::move(cliques);
219}
220
221/** ------------------------------------------------------------------------------------------------------------- *
222 * @brief safeDistributionSets
223 ** ------------------------------------------------------------------------------------------------------------- */
[4922]224inline static DistributionSets safeDistributionSets(const Graph & G, const VertexSet & A) {
[4878]225    DistributionSets T;
[4922]226    BicliqueSet lowerSet = independentCliqueSets(removeUnhelpfulBicliques(enumerateBicliques(G, A, 1), G), false);
[4878]227    for (Biclique & lower : lowerSet) {
[4922]228        BicliqueSet upperSet = independentCliqueSets(enumerateBicliques(G, std::get<1>(lower), 2), true);
[4878]229        for (Biclique & upper : upperSet) {
230            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
231        }
232    }
233    return std::move(T);
234}
235
236/** ------------------------------------------------------------------------------------------------------------- *
[4922]237 * @brief scopeDepthOf
[4878]238 ** ------------------------------------------------------------------------------------------------------------- */
[4922]239inline unsigned scopeDepthOf(const PabloBlock * block) {
240    unsigned depth = 0;
241    for (; block; ++depth, block = block->getParent());
242    return depth;
243}
[4878]244
[4922]245/** ------------------------------------------------------------------------------------------------------------- *
246 * @brief findInsertionPoint
247 ** ------------------------------------------------------------------------------------------------------------- */
248inline PabloBlock * findInsertionPoint(const VertexSet & users, const Graph & G) {
249    std::vector<PabloBlock *> scopes(0);
250    scopes.reserve(users.size());
251    for (Vertex u : users) {
252        PabloBlock * const scope = cast<Statement>(G[u])->getParent(); assert (scope);
253        if (std::find(scopes.begin(), scopes.end(), scope) == scopes.end()) {
254            scopes.push_back(scope);
255        }
256    }
257    while (scopes.size() > 1) {
258        // Find the LCA of both scopes then add the LCA back to the list of scopes.
259        PabloBlock * scope1 = scopes.back();
260        scopes.pop_back();
261        PabloBlock * scope2 = scopes.back();
262        scopes.pop_back();
263        assert (scope1 != scope2);
264        unsigned depth1 = scopeDepthOf(scope1);
265        unsigned depth2 = scopeDepthOf(scope2);
266        // If one of these scopes is nested deeper than the other, scan upwards through
267        // the scope tree until both scopes are at the same depth.
268        while (depth1 > depth2) {
269            scope1 = scope1->getParent();
270            --depth1;
271        }
272        while (depth1 < depth2) {
273            scope2 = scope2->getParent();
274            --depth2;
275        }
276        assert (depth1 == depth2);
277        // Then iteratively step backwards until we find a matching scopes; this must be
278        // the LCA of our original pair.
279        while (scope1 != scope2) {
280            scope1 = scope1->getParent();
281            scope2 = scope2->getParent();
282        }
283        assert (scope1 && scope2);
284        if (std::find(scopes.begin(), scopes.end(), scope1) == scopes.end()) {
285            scopes.push_back(scope1);
286        }
287    }
288    assert (scopes.size() == 1);
289    PabloBlock * const root = scopes.front();
290    // Now that we know the common scope of these users, test which statement is the first to require it.
291    flat_set<Statement *> usages;
292    usages.reserve(users.size());
293    for (Vertex u : users) {
294        Statement * user = cast<Statement>(G[u]);
295        PabloBlock * scope = user->getParent();
296        while (scope != root) {
297            assert (scope);
298            user = scope->getBranch();
299            scope = scope->getParent();
300        }
301        usages.insert(user);
302    }
303    Statement * ip = nullptr;
304    for (Statement * stmt : *root) {
305        if (usages.count(stmt)) {
306            ip = stmt->getPrevNode();
307            break;
308        }
309    }
310    assert (ip);
311    root->setInsertPoint(ip);
312    return root;
313}
[4878]314
[4922]315/** ------------------------------------------------------------------------------------------------------------- *
316 * @brief computeDistributionGraph
317 ** ------------------------------------------------------------------------------------------------------------- */
318static inline void computeDistributionGraph(Variadic * const expr, Graph & G, VertexSet & A) {
[4878]319
[4922]320    const TypeId outerTypeId = expr->getClassTypeId();
321    const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
[4878]322
[4922]323    assert (isa<And>(expr) || isa<Or>(expr));
[4878]324
[4922]325    Map M;
326    for (unsigned i = 0; i != expr->getNumOperands(); ++i) {
327        PabloAST * const op = expr->getOperand(i);
328        if (op->getClassTypeId() == innerTypeId) {
329            bool distributable = true;
330            for (PabloAST * user : op->users()) {
331                // Early check to see whether it'd be beneficial to distribute it. If this fails, we'd have
332                // to compute the operand's value anyway, so just ignore this operand.
333                if (user->getClassTypeId() != outerTypeId) {
334                    distributable = false;
335                    break;
336                }
337            }
338            if (distributable) {
339                const Vertex u = add_vertex(op, G);
340                for (PabloAST * user : op->users()) {
341                    const auto f = M.find(user);
342                    Vertex v = 0;
343                    if (LLVM_LIKELY(f != M.end())) {
344                        v = f->second;
345                    } else {
346                        v = add_vertex(user, G);
347                        M.emplace(user, v);
348                        A.push_back(v);
349                    }
350                    add_edge(u, v, G);
351                }
352                for (PabloAST * input : *cast<Variadic>(op)) {
353                    const auto f = M.find(input);
354                    Vertex v = 0;
355                    if (f != M.end()) {
356                        v = f->second;
357                    } else {
358                        v = add_vertex(input, G);
359                        M.emplace(input, v);
360                    }
361                    add_edge(v, u, G);
362                }
363            }
[4887]364        }
[4922]365    }
366}
[4878]367
[4922]368/** ------------------------------------------------------------------------------------------------------------- *
369 * @brief distribute
370 *
371 * Based on the knowledge that:
372 *
373 *   (P ∧ Q) √ (P ∧ R) ⇔ P ∧ (Q √ R) and (P √ Q) ∧ (P √ R) ⇔ P √ (Q ∧ R)
374 *
375 * Try to eliminate some of the unnecessary operations from the current Variadic expression.
376 ** ------------------------------------------------------------------------------------------------------------- */
377inline void DistributivePass::distribute(Variadic * const var) {
[4878]378
[4922]379    std::vector<Variadic *> Q;
380
381    assert (isa<And>(var) || isa<Or>(var));
382
383    Q.push_back(var);
384
385    Graph G;
386    VertexSet A;
387
388    while (Q.size() > 0) {
389
390        Variadic * expr = CanonicalizeDFG::canonicalize(Q.back());
391        Q.pop_back();
392        PabloAST * const replacement = Simplifier::fold(expr, expr->getParent());
393        if (LLVM_UNLIKELY(replacement != nullptr)) {
394            expr->replaceWith(replacement, true, true);
395            if (LLVM_UNLIKELY(isa<Variadic>(replacement))) {
396                Q.push_back(cast<Variadic>(replacement));
397            }
398            continue;
[4887]399        }
[4878]400
[4922]401        if (LLVM_LIKELY(isa<And>(expr) || isa<Or>(expr))) {
[4878]402
[4922]403            computeDistributionGraph(expr, G, A);
[4878]404
[4922]405            // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
406            if (num_vertices(G) == 0) {
407                assert (A.empty());
408                continue;
409            }
[4887]410
[4922]411            const auto S = safeDistributionSets(G, A);
412            if (S.empty()) {
413                G.clear();
414                A.clear();
415                continue;
[4878]416            }
417
[4922]418            Q.push_back(expr);
419
420            for (const DistributionSet & set : S) {
421
422                // Each distribution tuple consists of the sources, intermediary, and sink nodes.
423                const VertexSet & sources = std::get<0>(set);
424                assert (sources.size() > 0);
425                const VertexSet & intermediary = std::get<1>(set);
426                assert (intermediary.size() > 1);
427                const VertexSet & sinks = std::get<2>(set);
428                assert (sinks.size() > 0);
429
430                // Test whether we can apply the identity law to distributed set. I.e., (P ∧ Q) √ (P ∧ ¬Q) ⇔ (P √ Q) ∧ (P √ ¬Q) ⇔ P
431
432
433                for (const Vertex u : intermediary) {
434                    for (const Vertex v : sources) {
435                        cast<Variadic>(G[u])->deleteOperand(G[v]);
436                    }
[4887]437                }
[4922]438                for (const Vertex u : sinks) {
439                    for (const Vertex v : intermediary) {
440                        cast<Variadic>(G[u])->deleteOperand(G[v]);
441                    }
442                }
[4887]443
[4922]444                PabloBlock * const block = findInsertionPoint(sinks, G);
445                Variadic * innerOp = nullptr;
446                Variadic * outerOp = nullptr;
447                if (isa<And>(expr)) {
448                    outerOp = block->createAnd(intermediary.size());
449                    innerOp = block->createOr(sources.size() + 1);
450                } else {
451                    outerOp = block->createOr(intermediary.size());
452                    innerOp = block->createAnd(sources.size() + 1);
453                }
[4887]454                for (const Vertex v : intermediary) {
[4922]455                    outerOp->addOperand(G[v]);
[4887]456                }
[4922]457                for (const Vertex v : sources) {
458                    innerOp->addOperand(G[v]);
459                }
460                for (const Vertex u : sinks) {
461                    cast<Variadic>(G[u])->addOperand(innerOp);
462                }
463                innerOp->addOperand(outerOp);
464                // Push our newly constructed ops into the Q
465                Q.push_back(innerOp);
466                Q.push_back(outerOp);
[4887]467            }
468
[4922]469            G.clear();
470            A.clear();
[4880]471        }
472    }
473}
[4878]474
[4880]475/** ------------------------------------------------------------------------------------------------------------- *
[4887]476 * @brief distribute
[4880]477 ** ------------------------------------------------------------------------------------------------------------- */
[4887]478void DistributivePass::distribute(PabloBlock * const block) {
[4880]479    for (Statement * stmt : *block) {
480        if (isa<If>(stmt) || isa<While>(stmt)) {
[4887]481            distribute(isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody());
[4922]482        } else if (isa<And>(stmt) || isa<Or>(stmt)) {
483            distribute(cast<Variadic>(stmt));
[4878]484        }
485    }
486}
487
[4880]488/** ------------------------------------------------------------------------------------------------------------- *
[4887]489 * @brief optimize
[4880]490 ** ------------------------------------------------------------------------------------------------------------- */
[4887]491void DistributivePass::optimize(PabloFunction & function) {
492    DistributivePass::distribute(function.getEntryBlock());
[4880]493    #ifndef NDEBUG
494    PabloVerifier::verify(function, "post-distribution");
495    #endif
496    Simplifier::optimize(function);
497}
[4878]498
[4880]499
[4878]500}
Note: See TracBrowser for help on using the repository browser.