source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 4886

Last change on this file since 4886 was 4886, checked in by nmedfort, 4 years ago

Bug fixes

File size: 15.0 KB
RevLine 
[4878]1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
[4880]4#include <pablo/analysis/pabloverifier.hpp>
5#include <pablo/optimizers/pablo_simplifier.hpp>
6#include <boost/container/flat_set.hpp>
[4878]7#include <boost/container/flat_map.hpp>
8#include <boost/graph/adjacency_list.hpp>
9
10using namespace boost;
11using namespace boost::container;
12
13namespace pablo {
14
[4880]15using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
16using Vertex = Graph::vertex_descriptor;
17using Map = flat_map<PabloAST *, Vertex>;
18using VertexSet = std::vector<Vertex>;
[4878]19using Biclique = std::pair<VertexSet, VertexSet>;
20using BicliqueSet = std::vector<Biclique>;
21using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
22using DistributionSets = std::vector<DistributionSet>;
23
24using TypeId = PabloAST::ClassTypeId;
25
26/** ------------------------------------------------------------------------------------------------------------- *
27 * @brief getVertex
28 ** ------------------------------------------------------------------------------------------------------------- */
[4880]29static inline Vertex getVertex(PabloAST * value, Graph & G, Map & M) {
[4878]30    const auto f = M.find(value);
31    if (f != M.end()) {
32        return f->second;
33    }
34    const auto u = add_vertex(value, G);
35    M.emplace(value, u);
36    return u;
37}
38
39/** ------------------------------------------------------------------------------------------------------------- *
[4880]40 * @brief generateDistributionGraph
[4878]41 *
[4880]42 * Generate a graph G describing the potential applications of the distributive law for the given block.
43 ** ------------------------------------------------------------------------------------------------------------- */
44VertexSet generateDistributionGraph(PabloBlock * block, Graph & G) {
45    Map M;
46    VertexSet distSet;
47    for (Statement * stmt : *block) {
48        if (isa<And>(stmt) || isa<Or>(stmt)) {
49            const TypeId outerTypeId = stmt->getClassTypeId();
50            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
51            flat_set<PabloAST *> distributable;
52            for (PabloAST * const expr : *cast<Variadic>(stmt)) {
53                if (LLVM_UNLIKELY(expr->getClassTypeId() == innerTypeId)) {
54                    bool safe = true;
55                    for (PabloAST * const user : expr->users()) {
56                        if (user->getClassTypeId() != outerTypeId) {
57                            safe = false;
58                            break;
59                        }
60                    }
61                    if (safe) {
62                        distributable.insert(expr);
63                    }
[4878]64                }
65            }
[4880]66            if (LLVM_LIKELY(distributable.size() > 1)) {
67                flat_map<PabloAST *, bool> observedMoreThanOnce;
68                bool anyOpportunities = false;
69                for (const PabloAST * distOperation : distributable) {
70                    for (PabloAST * const distVar : *cast<Variadic>(distOperation)) {
71                        auto ob = observedMoreThanOnce.find(distVar);
72                        if (ob == observedMoreThanOnce.end()) {
73                            observedMoreThanOnce.emplace(distVar, false);
74                        } else {
75                            ob->second = true;
76                            anyOpportunities = true;
77                        }
[4878]78                    }
79                }
[4880]80                if (anyOpportunities) {
81                    for (const auto ob : observedMoreThanOnce) {
82                        PabloAST * distVar = nullptr;
83                        bool observedTwice = false;
84                        std::tie(distVar, observedTwice) = ob;
85                        if (observedTwice) {
86                            const Vertex z = getVertex(stmt, G, M);
87                            distSet.push_back(z);
88                            for (PabloAST * const distOperation : distVar->users()) {
89                                if (distributable.count(distOperation)) {
90                                    const Vertex y = getVertex(distOperation, G, M);
91                                    add_edge(getVertex(distVar, G, M), y, G);
92                                    add_edge(y, z, G);
93                                }
94                            }
95                        }
96                    }
97                }
[4878]98            }
99        }
100    }
[4880]101    return distSet;
[4878]102}
103
104/** ------------------------------------------------------------------------------------------------------------- *
105 * @brief intersects
106 ** ------------------------------------------------------------------------------------------------------------- */
107template <class Type>
108inline bool intersects(const Type & A, const Type & B) {
109    auto first1 = A.begin(), last1 = A.end();
110    auto first2 = B.begin(), last2 = B.end();
111    while (first1 != last1 && first2 != last2) {
112        if (*first1 < *first2) {
113            ++first1;
114        } else if (*first2 < *first1) {
115            ++first2;
116        } else {
117            return true;
118        }
119    }
120    return false;
121}
122
123/** ------------------------------------------------------------------------------------------------------------- *
124 * @brief independentCliqueSets
125 ** ------------------------------------------------------------------------------------------------------------- */
126template <unsigned side>
127inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
128    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
129
130    const auto l = cliques.size();
131    IndependentSetGraph I(l);
132
133    // Initialize our weights
134    for (unsigned i = 0; i != l; ++i) {
135        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
136    }
137
138    // Determine our constraints
139    for (unsigned i = 0; i != l; ++i) {
140        for (unsigned j = i + 1; j != l; ++j) {
141            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
142                add_edge(i, j, I);
143            }
144        }
145    }
146
147    // Use the greedy algorithm to choose our independent set
148    VertexSet selected;
149    for (;;) {
150        unsigned w = 0;
151        Vertex u = 0;
152        for (unsigned i = 0; i != l; ++i) {
153            if (I[i] > w) {
154                w = I[i];
155                u = i;
156            }
157        }
158        if (w < minimum) break;
159        selected.push_back(u);
160        I[u] = 0;
161        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
162            I[v] = 0;
163        }
164    }
165
166    // Sort the selected list and then remove the unselected cliques
167    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
168    auto end = cliques.end();
169    for (const unsigned offset : selected) {
170        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
171    }
172    cliques.erase(cliques.begin(), end);
173
174    return std::move(cliques);
175}
176
177/** ------------------------------------------------------------------------------------------------------------- *
[4880]178 * @brief enumerateBicliques
179 *
180 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
181 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
182 * bipartition A and their adjacencies to be in B.
183  ** ------------------------------------------------------------------------------------------------------------- */
184static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
185    using IntersectionSets = std::set<VertexSet>;
186
187    IntersectionSets B1;
188    for (auto u : A) {
189        if (in_degree(u, G) > 0) {
190            VertexSet incomingAdjacencies;
191            incomingAdjacencies.reserve(in_degree(u, G));
192            for (auto e : make_iterator_range(in_edges(u, G))) {
193                incomingAdjacencies.push_back(source(e, G));
194            }
195            std::sort(incomingAdjacencies.begin(), incomingAdjacencies.end());
196            B1.insert(std::move(incomingAdjacencies));
197        }
198    }
199
200    IntersectionSets B(B1);
201
202    IntersectionSets Bi;
203
204    VertexSet clique;
205    for (auto i = B1.begin(); i != B1.end(); ++i) {
206        for (auto j = i; ++j != B1.end(); ) {
207            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
208            if (clique.size() > 0) {
209                if (B.count(clique) == 0) {
210                    Bi.insert(clique);
211                }
212                clique.clear();
213            }
214        }
215    }
216
217    for (;;) {
218        if (Bi.empty()) {
219            break;
220        }
221        B.insert(Bi.begin(), Bi.end());
222        IntersectionSets Bk;
223        for (auto i = B1.begin(); i != B1.end(); ++i) {
224            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
225                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
226                if (clique.size() > 0) {
227                    if (B.count(clique) == 0) {
228                        Bk.insert(clique);
229                    }
230                    clique.clear();
231                }
232            }
233        }
234        Bi.swap(Bk);
235    }
236
237    BicliqueSet cliques;
238    cliques.reserve(B.size());
239    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
240        VertexSet Ai(A);
241        for (const Vertex u : *Bi) {
242            VertexSet Aj;
243            Aj.reserve(out_degree(u, G));
244            for (auto e : make_iterator_range(out_edges(u, G))) {
245                Aj.push_back(target(e, G));
246            }
247            std::sort(Aj.begin(), Aj.end());
248            VertexSet Ak;
249            Ak.reserve(std::min(Ai.size(), Aj.size()));
250            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
251            Ai.swap(Ak);
252        }
253        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
254        cliques.emplace_back(std::move(Ai), std::move(*Bi));
255    }
256    return std::move(cliques);
257}
258
259/** ------------------------------------------------------------------------------------------------------------- *
[4878]260 * @brief removeUnhelpfulBicliques
261 *
262 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
263 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
264 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
265 ** ------------------------------------------------------------------------------------------------------------- */
[4880]266static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, Graph & G) {
[4878]267    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
268        const auto cardinalityA = std::get<0>(*ci).size();
269        VertexSet & B = std::get<1>(*ci);
270        for (auto bi = B.begin(); bi != B.end(); ) {
271            if (G[*bi]->getNumUsers() == cardinalityA) {
272                ++bi;
273            } else {
274                bi = B.erase(bi);
275            }
276        }
277        if (B.size() > 1) {
278            ++ci;
279        } else {
280            ci = cliques.erase(ci);
281        }
282    }
283    return std::move(cliques);
284}
285
286/** ------------------------------------------------------------------------------------------------------------- *
287 * @brief safeDistributionSets
288 ** ------------------------------------------------------------------------------------------------------------- */
[4880]289static DistributionSets safeDistributionSets(Graph & G, const VertexSet & distSet) {
[4878]290    DistributionSets T;
[4880]291    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(G, distSet), G), 1);
[4878]292    for (Biclique & lower : lowerSet) {
293        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(G, std::get<1>(lower)), 2);
294        for (Biclique & upper : upperSet) {
295            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
296        }
297    }
298    return std::move(T);
299}
300
301/** ------------------------------------------------------------------------------------------------------------- *
[4880]302 * @brief distribute
[4878]303 ** ------------------------------------------------------------------------------------------------------------- */
[4880]304inline bool DistributivePass::distribute(PabloBlock * const block) {
[4878]305
[4880]306    Graph G;
[4878]307
[4880]308    const VertexSet distSet = generateDistributionGraph(block, G);
[4878]309
[4880]310    // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
311    if (LLVM_UNLIKELY(distSet.empty())) {
312        return false;
313    }
[4878]314
[4880]315    const DistributionSets distributionSets = safeDistributionSets(G, distSet);
[4878]316
[4880]317    if (LLVM_UNLIKELY(distributionSets.empty())) {
318        return false;
319    }
[4878]320
[4880]321    for (const DistributionSet & set : distributionSets) {
[4878]322
[4880]323        // Each distribution tuple consists of the sources, intermediary, and sink nodes.
324        const VertexSet & sources = std::get<0>(set);
325        const VertexSet & intermediary = std::get<1>(set);
326        const VertexSet & sinks = std::get<2>(set);
[4878]327
[4880]328        // Find the first sink and set the insert point immediately before that.
329        Variadic * innerOp = nullptr;
330        Variadic * outerOp = nullptr;
[4878]331
[4880]332        block->setInsertPoint(cast<Variadic>(G[sinks.front()])->getPrevNode());
333        if (isa<And>(G[sinks.front()])) {
[4886]334            outerOp = block->createAnd(intermediary.size());
335            innerOp = block->createOr(sources.size() + 1);
[4880]336        } else {
[4886]337            outerOp = block->createOr(intermediary.size());
338            innerOp = block->createAnd(sources.size() + 1);
[4880]339        }
[4878]340
[4880]341        for (const Vertex u : intermediary) {
342            for (const Vertex v : sinks) {
[4885]343                cast<Variadic>(G[v])->deleteOperand(G[u]);
[4878]344            }
[4886]345            outerOp->addOperand(G[u]);
[4880]346        }
[4878]347
[4880]348        for (const Vertex u : sources) {
349            for (const Vertex v : intermediary) {
[4885]350                cast<Variadic>(G[v])->deleteOperand(G[u]);
[4878]351            }
[4886]352            innerOp->addOperand(G[u]);
[4880]353        }
[4886]354        innerOp->addOperand(outerOp);
355
[4880]356        for (const Vertex u : sinks) {
357            cast<Variadic>(G[u])->addOperand(innerOp);
358        }
359    }
[4878]360
[4880]361    return true;
362}
[4878]363
[4880]364/** ------------------------------------------------------------------------------------------------------------- *
365 * @brief process
366 ** ------------------------------------------------------------------------------------------------------------- */
367bool DistributivePass::process(PabloBlock * const block) {
368    bool modified = false;
369    for (Statement * stmt : *block) {
370        if (isa<If>(stmt) || isa<While>(stmt)) {
371            modified |= process(isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody());
[4878]372        }
373    }
[4880]374    modified |= distribute(block);
375    return modified;
[4878]376}
377
[4880]378/** ------------------------------------------------------------------------------------------------------------- *
379 * @brief process
380 ** ------------------------------------------------------------------------------------------------------------- */
381bool DistributivePass::optimize(PabloFunction & function) {
382    const bool modified = DistributivePass::process(function.getEntryBlock());
383    #ifndef NDEBUG
384    PabloVerifier::verify(function, "post-distribution");
385    #endif
[4886]386
[4880]387    Simplifier::optimize(function);
388    return modified;
389}
[4878]390
[4880]391
[4878]392}
Note: See TracBrowser for help on using the repository browser.