source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 5267

Last change on this file since 5267 was 5267, checked in by nmedfort, 3 years ago

Code clean-up. Removed Pablo Call, SetIthBit? and Prototype.

File size: 19.2 KB
Line 
1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
4#ifndef NDEBUG
5#include <pablo/analysis/pabloverifier.hpp>
6#endif
7#include <pablo/optimizers/pablo_simplifier.hpp>
8#include <pablo/passes/flattenassociativedfg.h>
9#include <boost/container/flat_set.hpp>
10#include <boost/container/flat_map.hpp>
11#include <boost/graph/adjacency_list.hpp>
12
13#include <pablo/printer_pablos.h>
14#include <iostream>
15
16using namespace boost;
17using namespace boost::container;
18
19namespace pablo {
20
21using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
22using Vertex = Graph::vertex_descriptor;
23using Map = flat_map<PabloAST *, Vertex>;
24using VertexSet = std::vector<Vertex>;
25using Biclique = std::pair<VertexSet, VertexSet>;
26using BicliqueSet = std::vector<Biclique>;
27using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
28using DistributionSets = std::vector<DistributionSet>;
29using TypeId = PabloAST::ClassTypeId;
30
31/** ------------------------------------------------------------------------------------------------------------- *
32 * @brief intersects
33 ** ------------------------------------------------------------------------------------------------------------- */
34template <class Type>
35inline bool intersects(Type & A, Type & B) {
36    auto first1 = A.begin(), last1 = A.end();
37    auto first2 = B.begin(), last2 = B.end();
38    while (first1 != last1 && first2 != last2) {
39        if (*first1 < *first2) {
40            ++first1;
41        } else if (*first2 < *first1) {
42            ++first2;
43        } else {
44            return true;
45        }
46    }
47    return false;
48}
49
50/** ------------------------------------------------------------------------------------------------------------- *
51 * @brief independentCliqueSets
52 ** ------------------------------------------------------------------------------------------------------------- */
53static BicliqueSet && independentCliqueSets(BicliqueSet && bicliques, const bool uppsetSet) {
54    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
55
56    const auto l = bicliques.size();
57    IndependentSetGraph I(l);
58
59
60    // Initialize our weights and determine the constraints
61    if (uppsetSet) {
62        for (auto i = bicliques.begin(); i != bicliques.end(); ++i) {
63            I[std::distance(bicliques.begin(), i)] = std::pow(std::get<0>(*i).size(), 2);
64            for (auto j = i; ++j != bicliques.end(); ) {
65                if (intersects(i->first, j->first)) {
66                    add_edge(std::distance(bicliques.begin(), i), std::distance(bicliques.begin(), j), I);
67                }
68            }
69        }
70    } else {
71        for (auto i = bicliques.begin(); i != bicliques.end(); ++i) {
72            I[std::distance(bicliques.begin(), i)] = std::pow(std::get<1>(*i).size(), 2);
73            for (auto j = i; ++j != bicliques.end(); ) {
74                if (intersects(i->first, j->first) && intersects(i->second, j->second)) {
75                    add_edge(std::distance(bicliques.begin(), i), std::distance(bicliques.begin(), j), I);
76                }
77            }
78        }
79    }
80
81
82    // Use the greedy algorithm to choose our independent set
83    VertexSet selected;
84    for (;;) {
85        unsigned w = 0;
86        Vertex u = 0;
87        for (unsigned i = 0; i != l; ++i) {
88            if (I[i] > w) {
89                w = I[i];
90                u = i;
91            }
92        }
93        if (w < (uppsetSet ? 2 : 1)) break;
94        selected.push_back(u);
95        I[u] = 0;
96        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
97            I[v] = 0;
98        }
99    }
100
101    // Sort the selected list and then remove the unselected cliques
102    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
103    auto end = bicliques.end();
104    for (const unsigned offset : selected) {
105        end = bicliques.erase(bicliques.begin() + offset + 1, end) - 1;
106    }
107    bicliques.erase(bicliques.begin(), end);
108
109    return std::move(bicliques);
110}
111
112/** ------------------------------------------------------------------------------------------------------------- *
113 * @brief enumerateBicliques
114 *
115 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
116 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
117 * bipartition A and their adjacencies to be in B.
118  ** ------------------------------------------------------------------------------------------------------------- */
119static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A, const unsigned min) {
120    using IntersectionSets = std::set<VertexSet>;
121
122    IntersectionSets B1;
123    VertexSet tmp;
124    for (auto u : A) {
125        if (in_degree(u, G) > 0) {
126            tmp.reserve(in_degree(u, G));
127            for (auto e : make_iterator_range(in_edges(u, G))) {
128                tmp.push_back(source(e, G));
129            }
130            if (tmp.size() >= min) {
131                std::sort(tmp.begin(), tmp.end());
132                B1.emplace(tmp.begin(), tmp.end());
133            }
134            tmp.clear();
135        }
136    }
137
138    IntersectionSets B(B1);
139
140    IntersectionSets Bi;
141    for (auto i = B1.begin(); i != B1.end(); ++i) {
142        for (auto j = i; ++j != B1.end(); ) {
143            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(tmp));
144            if (tmp.size() >= min) {
145                if (B.count(tmp) == 0) {
146                    Bi.emplace(tmp.begin(), tmp.end());
147                }
148            }
149            tmp.clear();
150        }
151    }
152
153    for (;;) {
154        if (Bi.empty()) {
155            break;
156        }
157        B.insert(Bi.begin(), Bi.end());
158        IntersectionSets Bk;
159        for (auto i = B1.begin(); i != B1.end(); ++i) {
160            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
161                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(tmp));
162                if (tmp.size() >= min) {
163                    if (B.count(tmp) == 0) {
164                        Bk.emplace(tmp.begin(), tmp.end());
165                    }
166                }
167                tmp.clear();
168            }
169        }
170        Bi.swap(Bk);
171    }
172
173    BicliqueSet cliques;
174    cliques.reserve(B.size());
175    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
176        assert (Bi->size() >= min);
177        VertexSet Ai(A);
178        for (const Vertex u : *Bi) {
179            VertexSet Aj;
180            Aj.reserve(out_degree(u, G));
181            for (auto e : make_iterator_range(out_edges(u, G))) {
182                Aj.push_back(target(e, G));
183            }
184            std::sort(Aj.begin(), Aj.end());
185            VertexSet Ak;
186            Ak.reserve(std::min(Ai.size(), Aj.size()));
187            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
188            Ai.swap(Ak);
189        }
190        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
191        cliques.emplace_back(std::move(Ai), std::move(*Bi));
192    }
193    return std::move(cliques);
194}
195
196/** ------------------------------------------------------------------------------------------------------------- *
197 * @brief removeUnhelpfulBicliques
198 *
199 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
200 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
201 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
202 ** ------------------------------------------------------------------------------------------------------------- */
203static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, const Graph & G) {
204    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
205        const auto cardinalityA = std::get<0>(*ci).size();
206        VertexSet & B = std::get<1>(*ci);
207        for (auto bi = B.begin(); bi != B.end(); ) {
208            if (G[*bi]->getNumUses() == cardinalityA) {
209                ++bi;
210            } else {
211                bi = B.erase(bi);
212            }
213        }
214        if (B.size() > 1) {
215            ++ci;
216        } else {
217            ci = cliques.erase(ci);
218        }
219    }
220    return std::move(cliques);
221}
222
223/** ------------------------------------------------------------------------------------------------------------- *
224 * @brief safeDistributionSets
225 ** ------------------------------------------------------------------------------------------------------------- */
226inline static DistributionSets safeDistributionSets(const Graph & G, const VertexSet & A) {
227    DistributionSets T;
228    BicliqueSet lowerSet = independentCliqueSets(removeUnhelpfulBicliques(enumerateBicliques(G, A, 1), G), false);
229    for (Biclique & lower : lowerSet) {
230        BicliqueSet upperSet = independentCliqueSets(enumerateBicliques(G, std::get<1>(lower), 2), true);
231        for (Biclique & upper : upperSet) {
232            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
233        }
234    }
235    return std::move(T);
236}
237
238/** ------------------------------------------------------------------------------------------------------------- *
239 * @brief scopeDepthOf
240 ** ------------------------------------------------------------------------------------------------------------- */
241inline unsigned scopeDepthOf(const PabloBlock * block) {
242    unsigned depth = 0;
243    for (; block; ++depth, block = block->getPredecessor());
244    return depth;
245}
246
247/** ------------------------------------------------------------------------------------------------------------- *
248 * @brief findInsertionPoint
249 ** ------------------------------------------------------------------------------------------------------------- */
250inline PabloBlock * findInsertionPoint(const VertexSet & users, const Graph & G) {
251    std::vector<PabloBlock *> scopes(0);
252    scopes.reserve(users.size());
253    for (Vertex u : users) {
254        PabloBlock * const scope = cast<Statement>(G[u])->getParent(); assert (scope);
255        if (std::find(scopes.begin(), scopes.end(), scope) == scopes.end()) {
256            scopes.push_back(scope);
257        }
258    }
259    while (scopes.size() > 1) {
260        // Find the LCA of both scopes then add the LCA back to the list of scopes.
261        PabloBlock * scope1 = scopes.back();
262        scopes.pop_back();
263        PabloBlock * scope2 = scopes.back();
264        scopes.pop_back();
265        assert (scope1 != scope2);
266        unsigned depth1 = scopeDepthOf(scope1);
267        unsigned depth2 = scopeDepthOf(scope2);
268        // If one of these scopes is nested deeper than the other, scan upwards through
269        // the scope tree until both scopes are at the same depth.
270        while (depth1 > depth2) {
271            scope1 = scope1->getPredecessor();
272            --depth1;
273        }
274        while (depth1 < depth2) {
275            scope2 = scope2->getPredecessor();
276            --depth2;
277        }
278        assert (depth1 == depth2);
279        // Then iteratively step backwards until we find a matching scopes; this must be
280        // the LCA of our original pair.
281        while (scope1 != scope2) {
282            scope1 = scope1->getPredecessor();
283            scope2 = scope2->getPredecessor();
284        }
285        assert (scope1 && scope2);
286        if (std::find(scopes.begin(), scopes.end(), scope1) == scopes.end()) {
287            scopes.push_back(scope1);
288        }
289    }
290    assert (scopes.size() == 1);
291    PabloBlock * const root = scopes.front();
292    // Now that we know the common scope of these users, test which statement is the first to require it.
293    flat_set<Statement *> usages;
294    usages.reserve(users.size());
295    for (Vertex u : users) {
296        Statement * user = cast<Statement>(G[u]);
297        PabloBlock * scope = user->getParent();
298        while (scope != root) {
299            assert (scope);
300            user = scope->getBranch();
301            scope = scope->getPredecessor();
302        }
303        usages.insert(user);
304    }
305    Statement * ip = nullptr;
306    for (Statement * stmt : *root) {
307        if (usages.count(stmt)) {
308            ip = stmt->getPrevNode();
309            break;
310        }
311    }
312    assert (ip);
313    root->setInsertPoint(ip);
314    return root;
315}
316
317/** ------------------------------------------------------------------------------------------------------------- *
318 * @brief computeDistributionGraph
319 ** ------------------------------------------------------------------------------------------------------------- */
320static inline void computeDistributionGraph(Variadic * const expr, Graph & G, VertexSet & A) {
321
322    TypeId outerTypeId = expr->getClassTypeId();
323    TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
324
325    assert (isa<And>(expr) || isa<Or>(expr));
326
327    Map M;
328    for (unsigned i = 0; i != expr->getNumOperands(); ++i) {
329        PabloAST * const op = expr->getOperand(i);
330        if (op->getClassTypeId() == innerTypeId) {
331            bool distributable = true;
332            for (PabloAST * user : op->users()) {
333                // Early check to see whether it'd be beneficial to distribute it. If this fails, we'd have
334                // to compute the operand's value anyway, so just ignore this operand.
335                if (user->getClassTypeId() != outerTypeId) {
336                    distributable = false;
337                    break;
338                }
339            }
340            if (distributable) {
341                const Vertex u = add_vertex(op, G);
342                for (PabloAST * user : op->users()) {
343                    const auto f = M.find(user);
344                    Vertex v = 0;
345                    if (LLVM_LIKELY(f != M.end())) {
346                        v = f->second;
347                    } else {
348                        v = add_vertex(user, G);
349                        M.emplace(user, v);
350                        A.push_back(v);
351                    }
352                    add_edge(u, v, G);
353                }
354                for (PabloAST * input : *cast<Variadic>(op)) {
355                    const auto f = M.find(input);
356                    Vertex v = 0;
357                    if (f != M.end()) {
358                        v = f->second;
359                    } else {
360                        v = add_vertex(input, G);
361                        M.emplace(input, v);
362                    }
363                    add_edge(v, u, G);
364                }
365            }
366        }
367    }
368}
369
370/** ------------------------------------------------------------------------------------------------------------- *
371 * @brief distribute
372 *
373 * Based on the knowledge that:
374 *
375 *   (P ∧ Q) √ (P ∧ R) ⇔ P ∧ (Q √ R) and (P √ Q) ∧ (P √ R) ⇔ P √ (Q ∧ R)
376 *
377 * Try to eliminate some of the unnecessary operations from the current Variadic expression.
378 ** ------------------------------------------------------------------------------------------------------------- */
379inline void DistributivePass::distribute(Variadic * const var) {
380
381
382
383    assert (isa<And>(var) || isa<Or>(var));
384
385    Graph G;
386    VertexSet A;
387
388    std::vector<Variadic *> Q = {var};
389
390    while (!Q.empty()) {
391        Variadic * expr = CanonicalizeDFG::canonicalize(Q.back()); Q.pop_back();
392        PabloAST * const replacement = Simplifier::fold(expr, expr->getParent());
393        if (LLVM_UNLIKELY(replacement != nullptr)) {
394            expr->replaceWith(replacement, true, true);
395            if (LLVM_UNLIKELY(isa<Variadic>(replacement))) {
396                Q.push_back(cast<Variadic>(replacement));
397            }
398            continue;
399        }
400
401        if (LLVM_LIKELY(isa<And>(expr) || isa<Or>(expr))) {
402
403            computeDistributionGraph(expr, G, A);
404
405            // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
406            if (num_vertices(G) == 0) {
407                assert (A.empty());
408                continue;
409            }
410
411            const auto S = safeDistributionSets(G, A);
412            if (S.empty()) {
413                G.clear();
414                A.clear();
415                continue;
416            }
417
418            Q.push_back(expr);
419
420            for (const DistributionSet & set : S) {
421
422                // Each distribution tuple consists of the sources, intermediary, and sink nodes.
423                const VertexSet & sources = std::get<0>(set);
424                assert (sources.size() > 0);
425                const VertexSet & intermediary = std::get<1>(set);
426                assert (intermediary.size() > 1);
427                const VertexSet & sinks = std::get<2>(set);
428                assert (sinks.size() > 0);
429
430                for (const Vertex u : intermediary) {
431                    for (const Vertex v : sources) {
432                        cast<Variadic>(G[u])->deleteOperand(G[v]);
433                    }
434                }
435                for (const Vertex u : sinks) {
436                    for (const Vertex v : intermediary) {
437                        cast<Variadic>(G[u])->deleteOperand(G[v]);
438                    }
439                }
440
441                PabloBlock * const block = findInsertionPoint(sinks, G);
442                Variadic * innerOp = nullptr;
443                Variadic * outerOp = nullptr;
444                if (isa<And>(expr)) {
445                    outerOp = block->createAnd(intermediary.size());
446                    innerOp = block->createOr(sources.size() + 1);
447                } else {
448                    outerOp = block->createOr(intermediary.size());
449                    innerOp = block->createAnd(sources.size() + 1);
450                }
451                for (const Vertex v : intermediary) {
452                    outerOp->addOperand(G[v]);
453                }
454                for (const Vertex v : sources) {
455                    innerOp->addOperand(G[v]);
456                }
457                for (const Vertex u : sinks) {
458                    cast<Variadic>(G[u])->addOperand(innerOp);
459                }
460                innerOp->addOperand(outerOp);
461                // Push our newly constructed ops into the Q
462                Q.push_back(innerOp);
463                Q.push_back(outerOp);
464
465                unmodified = false;
466            }
467
468            G.clear();
469            A.clear();
470        }
471    }
472
473}
474
475/** ------------------------------------------------------------------------------------------------------------- *
476 * @brief distribute
477 ** ------------------------------------------------------------------------------------------------------------- */
478void DistributivePass::distribute(PabloBlock * const block) {
479    Statement * stmt = block->front();
480    while (stmt) {
481        Statement * next = stmt->getNextNode();
482        if (isa<If>(stmt) || isa<While>(stmt)) {
483            distribute(isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody());
484        } else if (isa<And>(stmt) || isa<Or>(stmt)) {
485            distribute(cast<Variadic>(stmt));
486        }
487        stmt = next;
488    }
489}
490
491/** ------------------------------------------------------------------------------------------------------------- *
492 * @brief optimize
493 ** ------------------------------------------------------------------------------------------------------------- */
494bool DistributivePass::optimize(PabloFunction & function) {
495    DistributivePass dp;
496    unsigned rounds = 0;
497    for (;;) {
498        dp.unmodified = true;
499        dp.distribute(function.getEntryBlock());
500        if (dp.unmodified) {
501            break;
502        }
503        ++rounds;
504        #ifndef NDEBUG
505        PabloVerifier::verify(function, "post-distribution");
506        #endif
507        Simplifier::optimize(function);
508    }
509    return rounds > 0;
510}
511
512
513}
Note: See TracBrowser for help on using the repository browser.