source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 4885

Last change on this file since 4885 was 4885, checked in by nmedfort, 4 years ago

More work on n-ary operations. Unresolved bug in DistributionPass?.

File size: 15.1 KB
Line 
1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
4#include <pablo/analysis/pabloverifier.hpp>
5#include <pablo/optimizers/pablo_simplifier.hpp>
6#include <boost/container/flat_set.hpp>
7#include <boost/container/flat_map.hpp>
8#include <boost/graph/adjacency_list.hpp>
9
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13using namespace boost;
14using namespace boost::container;
15
16namespace pablo {
17
18using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
19using Vertex = Graph::vertex_descriptor;
20using Map = flat_map<PabloAST *, Vertex>;
21using VertexSet = std::vector<Vertex>;
22using Biclique = std::pair<VertexSet, VertexSet>;
23using BicliqueSet = std::vector<Biclique>;
24using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
25using DistributionSets = std::vector<DistributionSet>;
26
27using TypeId = PabloAST::ClassTypeId;
28
29/** ------------------------------------------------------------------------------------------------------------- *
30 * @brief getVertex
31 ** ------------------------------------------------------------------------------------------------------------- */
32static inline Vertex getVertex(PabloAST * value, Graph & G, Map & M) {
33    const auto f = M.find(value);
34    if (f != M.end()) {
35        return f->second;
36    }
37    const auto u = add_vertex(value, G);
38    M.emplace(value, u);
39    return u;
40}
41
42/** ------------------------------------------------------------------------------------------------------------- *
43 * @brief generateDistributionGraph
44 *
45 * Generate a graph G describing the potential applications of the distributive law for the given block.
46 ** ------------------------------------------------------------------------------------------------------------- */
47VertexSet generateDistributionGraph(PabloBlock * block, Graph & G) {
48    Map M;
49    VertexSet distSet;
50    for (Statement * stmt : *block) {
51        if (isa<And>(stmt) || isa<Or>(stmt)) {
52            const TypeId outerTypeId = stmt->getClassTypeId();
53            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
54            flat_set<PabloAST *> distributable;
55            for (PabloAST * const expr : *cast<Variadic>(stmt)) {
56                if (LLVM_UNLIKELY(expr->getClassTypeId() == innerTypeId)) {
57                    bool safe = true;
58                    for (PabloAST * const user : expr->users()) {
59                        if (user->getClassTypeId() != outerTypeId) {
60                            safe = false;
61                            break;
62                        }
63                    }
64                    if (safe) {
65                        distributable.insert(expr);
66                    }
67                }
68            }
69            if (LLVM_LIKELY(distributable.size() > 1)) {
70                flat_map<PabloAST *, bool> observedMoreThanOnce;
71                bool anyOpportunities = false;
72                for (const PabloAST * distOperation : distributable) {
73                    for (PabloAST * const distVar : *cast<Variadic>(distOperation)) {
74                        auto ob = observedMoreThanOnce.find(distVar);
75                        if (ob == observedMoreThanOnce.end()) {
76                            observedMoreThanOnce.emplace(distVar, false);
77                        } else {
78                            ob->second = true;
79                            anyOpportunities = true;
80                        }
81                    }
82                }
83                if (anyOpportunities) {
84                    for (const auto ob : observedMoreThanOnce) {
85                        PabloAST * distVar = nullptr;
86                        bool observedTwice = false;
87                        std::tie(distVar, observedTwice) = ob;
88                        if (observedTwice) {
89                            const Vertex z = getVertex(stmt, G, M);
90                            distSet.push_back(z);
91                            for (PabloAST * const distOperation : distVar->users()) {
92                                if (distributable.count(distOperation)) {
93                                    const Vertex y = getVertex(distOperation, G, M);
94                                    add_edge(getVertex(distVar, G, M), y, G);
95                                    add_edge(y, z, G);
96                                }
97                            }
98                        }
99                    }
100                }
101            }
102        }
103    }
104    return distSet;
105}
106
107/** ------------------------------------------------------------------------------------------------------------- *
108 * @brief intersects
109 ** ------------------------------------------------------------------------------------------------------------- */
110template <class Type>
111inline bool intersects(const Type & A, const Type & B) {
112    auto first1 = A.begin(), last1 = A.end();
113    auto first2 = B.begin(), last2 = B.end();
114    while (first1 != last1 && first2 != last2) {
115        if (*first1 < *first2) {
116            ++first1;
117        } else if (*first2 < *first1) {
118            ++first2;
119        } else {
120            return true;
121        }
122    }
123    return false;
124}
125
126/** ------------------------------------------------------------------------------------------------------------- *
127 * @brief independentCliqueSets
128 ** ------------------------------------------------------------------------------------------------------------- */
129template <unsigned side>
130inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
131    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
132
133    const auto l = cliques.size();
134    IndependentSetGraph I(l);
135
136    // Initialize our weights
137    for (unsigned i = 0; i != l; ++i) {
138        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
139    }
140
141    // Determine our constraints
142    for (unsigned i = 0; i != l; ++i) {
143        for (unsigned j = i + 1; j != l; ++j) {
144            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
145                add_edge(i, j, I);
146            }
147        }
148    }
149
150    // Use the greedy algorithm to choose our independent set
151    VertexSet selected;
152    for (;;) {
153        unsigned w = 0;
154        Vertex u = 0;
155        for (unsigned i = 0; i != l; ++i) {
156            if (I[i] > w) {
157                w = I[i];
158                u = i;
159            }
160        }
161        if (w < minimum) break;
162        selected.push_back(u);
163        I[u] = 0;
164        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
165            I[v] = 0;
166        }
167    }
168
169    // Sort the selected list and then remove the unselected cliques
170    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
171    auto end = cliques.end();
172    for (const unsigned offset : selected) {
173        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
174    }
175    cliques.erase(cliques.begin(), end);
176
177    return std::move(cliques);
178}
179
180/** ------------------------------------------------------------------------------------------------------------- *
181 * @brief enumerateBicliques
182 *
183 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
184 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
185 * bipartition A and their adjacencies to be in B.
186  ** ------------------------------------------------------------------------------------------------------------- */
187static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
188    using IntersectionSets = std::set<VertexSet>;
189
190    IntersectionSets B1;
191    for (auto u : A) {
192        if (in_degree(u, G) > 0) {
193            VertexSet incomingAdjacencies;
194            incomingAdjacencies.reserve(in_degree(u, G));
195            for (auto e : make_iterator_range(in_edges(u, G))) {
196                incomingAdjacencies.push_back(source(e, G));
197            }
198            std::sort(incomingAdjacencies.begin(), incomingAdjacencies.end());
199            B1.insert(std::move(incomingAdjacencies));
200        }
201    }
202
203    IntersectionSets B(B1);
204
205    IntersectionSets Bi;
206
207    VertexSet clique;
208    for (auto i = B1.begin(); i != B1.end(); ++i) {
209        for (auto j = i; ++j != B1.end(); ) {
210            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
211            if (clique.size() > 0) {
212                if (B.count(clique) == 0) {
213                    Bi.insert(clique);
214                }
215                clique.clear();
216            }
217        }
218    }
219
220    for (;;) {
221        if (Bi.empty()) {
222            break;
223        }
224        B.insert(Bi.begin(), Bi.end());
225        IntersectionSets Bk;
226        for (auto i = B1.begin(); i != B1.end(); ++i) {
227            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
228                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
229                if (clique.size() > 0) {
230                    if (B.count(clique) == 0) {
231                        Bk.insert(clique);
232                    }
233                    clique.clear();
234                }
235            }
236        }
237        Bi.swap(Bk);
238    }
239
240    BicliqueSet cliques;
241    cliques.reserve(B.size());
242    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
243        VertexSet Ai(A);
244        for (const Vertex u : *Bi) {
245            VertexSet Aj;
246            Aj.reserve(out_degree(u, G));
247            for (auto e : make_iterator_range(out_edges(u, G))) {
248                Aj.push_back(target(e, G));
249            }
250            std::sort(Aj.begin(), Aj.end());
251            VertexSet Ak;
252            Ak.reserve(std::min(Ai.size(), Aj.size()));
253            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
254            Ai.swap(Ak);
255        }
256        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
257        cliques.emplace_back(std::move(Ai), std::move(*Bi));
258    }
259    return std::move(cliques);
260}
261
262/** ------------------------------------------------------------------------------------------------------------- *
263 * @brief removeUnhelpfulBicliques
264 *
265 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
266 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
267 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
268 ** ------------------------------------------------------------------------------------------------------------- */
269static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, Graph & G) {
270    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
271        const auto cardinalityA = std::get<0>(*ci).size();
272        VertexSet & B = std::get<1>(*ci);
273        for (auto bi = B.begin(); bi != B.end(); ) {
274            if (G[*bi]->getNumUsers() == cardinalityA) {
275                ++bi;
276            } else {
277                bi = B.erase(bi);
278            }
279        }
280        if (B.size() > 1) {
281            ++ci;
282        } else {
283            ci = cliques.erase(ci);
284        }
285    }
286    return std::move(cliques);
287}
288
289/** ------------------------------------------------------------------------------------------------------------- *
290 * @brief safeDistributionSets
291 ** ------------------------------------------------------------------------------------------------------------- */
292static DistributionSets safeDistributionSets(Graph & G, const VertexSet & distSet) {
293    DistributionSets T;
294    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(G, distSet), G), 1);
295    for (Biclique & lower : lowerSet) {
296        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(G, std::get<1>(lower)), 2);
297        for (Biclique & upper : upperSet) {
298            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
299        }
300    }
301    return std::move(T);
302}
303
304/** ------------------------------------------------------------------------------------------------------------- *
305 * @brief distribute
306 ** ------------------------------------------------------------------------------------------------------------- */
307inline bool DistributivePass::distribute(PabloBlock * const block) {
308
309    Graph G;
310
311    const VertexSet distSet = generateDistributionGraph(block, G);
312
313    // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
314    if (LLVM_UNLIKELY(distSet.empty())) {
315        return false;
316    }
317
318    const DistributionSets distributionSets = safeDistributionSets(G, distSet);
319
320    if (LLVM_UNLIKELY(distributionSets.empty())) {
321        return false;
322    }
323
324    for (const DistributionSet & set : distributionSets) {
325
326        // Each distribution tuple consists of the sources, intermediary, and sink nodes.
327        const VertexSet & sources = std::get<0>(set);
328        const VertexSet & intermediary = std::get<1>(set);
329        const VertexSet & sinks = std::get<2>(set);
330
331        // Find the first sink and set the insert point immediately before that.
332        Variadic * innerOp = nullptr;
333        Variadic * outerOp = nullptr;
334
335        block->setInsertPoint(cast<Variadic>(G[sinks.front()])->getPrevNode());
336        if (isa<And>(G[sinks.front()])) {
337            outerOp = block->createAnd(intermediary.size(), PabloBlock::createOnes());
338            innerOp = block->createOr(sources.size() + 1, outerOp);
339        } else {
340            outerOp = block->createOr(intermediary.size(), PabloBlock::createZeroes());
341            innerOp = block->createAnd(sources.size() + 1, outerOp);
342        }
343
344        unsigned i = 0;
345        for (const Vertex u : intermediary) {
346            for (const Vertex v : sinks) {
347                cast<Variadic>(G[v])->deleteOperand(G[u]);
348            }
349            outerOp->setOperand(i++, cast<Variadic>(G[u]));
350        }
351
352        i = 0;
353        for (const Vertex u : sources) {
354            for (const Vertex v : intermediary) {
355                cast<Variadic>(G[v])->deleteOperand(G[u]);
356            }
357            innerOp->setOperand(i++, G[u]);
358        }
359        for (const Vertex u : sinks) {
360            cast<Variadic>(G[u])->addOperand(innerOp);
361        }
362    }
363
364    return true;
365}
366
367/** ------------------------------------------------------------------------------------------------------------- *
368 * @brief process
369 ** ------------------------------------------------------------------------------------------------------------- */
370bool DistributivePass::process(PabloBlock * const block) {
371    bool modified = false;
372    for (Statement * stmt : *block) {
373        if (isa<If>(stmt) || isa<While>(stmt)) {
374            modified |= process(isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody());
375        }
376    }
377    modified |= distribute(block);
378    return modified;
379}
380
381/** ------------------------------------------------------------------------------------------------------------- *
382 * @brief process
383 ** ------------------------------------------------------------------------------------------------------------- */
384bool DistributivePass::optimize(PabloFunction & function) {
385    const bool modified = DistributivePass::process(function.getEntryBlock());
386    #ifndef NDEBUG
387    PabloVerifier::verify(function, "post-distribution");
388    #endif
389    Simplifier::optimize(function);
390    return modified;
391}
392
393
394}
Note: See TracBrowser for help on using the repository browser.