source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 4896

Last change on this file since 4896 was 4896, checked in by nmedfort, 3 years ago

Work on coalescing algorithm + minor changes.

File size: 15.3 KB
Line 
1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
4#include <pablo/analysis/pabloverifier.hpp>
5#include <pablo/optimizers/pablo_simplifier.hpp>
6#include <pablo/passes/flattenassociativedfg.h>
7#include <boost/container/flat_set.hpp>
8#include <boost/container/flat_map.hpp>
9#include <boost/graph/adjacency_list.hpp>
10
11using namespace boost;
12using namespace boost::container;
13
14namespace pablo {
15
16using Graph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
17using Vertex = Graph::vertex_descriptor;
18using Map = flat_map<PabloAST *, Vertex>;
19using VertexSet = std::vector<Vertex>;
20using Biclique = std::pair<VertexSet, VertexSet>;
21using BicliqueSet = std::vector<Biclique>;
22using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
23using DistributionSets = std::vector<DistributionSet>;
24
25using TypeId = PabloAST::ClassTypeId;
26
27/** ------------------------------------------------------------------------------------------------------------- *
28 * @brief getVertex
29 ** ------------------------------------------------------------------------------------------------------------- */
30static inline Vertex getVertex(PabloAST * value, Graph & G, Map & M) {
31    const auto f = M.find(value);
32    if (f != M.end()) {
33        return f->second;
34    }
35    const auto u = add_vertex(value, G);
36    M.emplace(value, u);
37    return u;
38}
39
40/** ------------------------------------------------------------------------------------------------------------- *
41 * @brief generateDistributionGraph
42 *
43 * Generate a graph G describing the potential applications of the distributive law for the given block.
44 ** ------------------------------------------------------------------------------------------------------------- */
45VertexSet generateDistributionGraph(PabloBlock * block, Graph & G) {
46    Map M;
47    VertexSet distSet;
48    for (Statement * stmt : *block) {
49        if (isa<And>(stmt) || isa<Or>(stmt)) {
50            const TypeId outerTypeId = stmt->getClassTypeId();
51            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
52            flat_set<PabloAST *> distributable;
53            for (PabloAST * const expr : *cast<Variadic>(stmt)) {
54                if (LLVM_UNLIKELY(expr->getClassTypeId() == innerTypeId)) {
55                    bool safe = true;
56                    for (PabloAST * const user : expr->users()) {
57                        if (user->getClassTypeId() != outerTypeId) {
58                            safe = false;
59                            break;
60                        }
61                    }
62                    if (safe) {
63                        distributable.insert(expr);
64                    }
65                }
66            }
67            if (LLVM_LIKELY(distributable.size() > 1)) {
68                flat_map<PabloAST *, bool> observedMoreThanOnce;
69                bool anyOpportunities = false;
70                for (const PabloAST * distOperation : distributable) {
71                    for (PabloAST * const distVar : *cast<Variadic>(distOperation)) {
72                        auto ob = observedMoreThanOnce.find(distVar);
73                        if (ob == observedMoreThanOnce.end()) {
74                            observedMoreThanOnce.emplace(distVar, false);
75                        } else {
76                            ob->second = true;
77                            anyOpportunities = true;
78                        }
79                    }
80                }
81                if (anyOpportunities) {
82                    for (const auto ob : observedMoreThanOnce) {
83                        PabloAST * distVar = nullptr;
84                        bool observedTwice = false;
85                        std::tie(distVar, observedTwice) = ob;
86                        if (observedTwice) {
87                            const Vertex z = getVertex(stmt, G, M);
88                            distSet.push_back(z);
89                            for (PabloAST * const distOperation : distVar->users()) {
90                                if (distributable.count(distOperation)) {
91                                    const Vertex y = getVertex(distOperation, G, M);
92                                    add_edge(getVertex(distVar, G, M), y, G);
93                                    add_edge(y, z, G);
94                                }
95                            }
96                        }
97                    }
98                }
99            }
100        }
101    }
102    return distSet;
103}
104
105/** ------------------------------------------------------------------------------------------------------------- *
106 * @brief intersects
107 ** ------------------------------------------------------------------------------------------------------------- */
108template <class Type>
109inline bool intersects(const Type & A, const Type & B) {
110    auto first1 = A.begin(), last1 = A.end();
111    auto first2 = B.begin(), last2 = B.end();
112    while (first1 != last1 && first2 != last2) {
113        if (*first1 < *first2) {
114            ++first1;
115        } else if (*first2 < *first1) {
116            ++first2;
117        } else {
118            return true;
119        }
120    }
121    return false;
122}
123
124/** ------------------------------------------------------------------------------------------------------------- *
125 * @brief independentCliqueSets
126 ** ------------------------------------------------------------------------------------------------------------- */
127template <unsigned side>
128inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
129    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
130
131    const auto l = cliques.size();
132    IndependentSetGraph I(l);
133
134    // Initialize our weights
135    for (unsigned i = 0; i != l; ++i) {
136        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
137    }
138
139    // Determine our constraints
140    for (unsigned i = 0; i != l; ++i) {
141        for (unsigned j = i + 1; j != l; ++j) {
142            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
143                add_edge(i, j, I);
144            }
145        }
146    }
147
148    // Use the greedy algorithm to choose our independent set
149    VertexSet selected;
150    for (;;) {
151        unsigned w = 0;
152        Vertex u = 0;
153        for (unsigned i = 0; i != l; ++i) {
154            if (I[i] > w) {
155                w = I[i];
156                u = i;
157            }
158        }
159        if (w < minimum) break;
160        selected.push_back(u);
161        I[u] = 0;
162        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
163            I[v] = 0;
164        }
165    }
166
167    // Sort the selected list and then remove the unselected cliques
168    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
169    auto end = cliques.end();
170    for (const unsigned offset : selected) {
171        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
172    }
173    cliques.erase(cliques.begin(), end);
174
175    return std::move(cliques);
176}
177
178/** ------------------------------------------------------------------------------------------------------------- *
179 * @brief enumerateBicliques
180 *
181 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
182 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
183 * bipartition A and their adjacencies to be in B.
184  ** ------------------------------------------------------------------------------------------------------------- */
185static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
186    using IntersectionSets = std::set<VertexSet>;
187
188    IntersectionSets B1;
189    for (auto u : A) {
190        if (in_degree(u, G) > 0) {
191            VertexSet incomingAdjacencies;
192            incomingAdjacencies.reserve(in_degree(u, G));
193            for (auto e : make_iterator_range(in_edges(u, G))) {
194                incomingAdjacencies.push_back(source(e, G));
195            }
196            std::sort(incomingAdjacencies.begin(), incomingAdjacencies.end());
197            B1.insert(std::move(incomingAdjacencies));
198        }
199    }
200
201    IntersectionSets B(B1);
202
203    IntersectionSets Bi;
204
205    VertexSet clique;
206    for (auto i = B1.begin(); i != B1.end(); ++i) {
207        for (auto j = i; ++j != B1.end(); ) {
208            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
209            if (clique.size() > 0) {
210                if (B.count(clique) == 0) {
211                    Bi.insert(clique);
212                }
213                clique.clear();
214            }
215        }
216    }
217
218    for (;;) {
219        if (Bi.empty()) {
220            break;
221        }
222        B.insert(Bi.begin(), Bi.end());
223        IntersectionSets Bk;
224        for (auto i = B1.begin(); i != B1.end(); ++i) {
225            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
226                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
227                if (clique.size() > 0) {
228                    if (B.count(clique) == 0) {
229                        Bk.insert(clique);
230                    }
231                    clique.clear();
232                }
233            }
234        }
235        Bi.swap(Bk);
236    }
237
238    BicliqueSet cliques;
239    cliques.reserve(B.size());
240    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
241        VertexSet Ai(A);
242        for (const Vertex u : *Bi) {
243            VertexSet Aj;
244            Aj.reserve(out_degree(u, G));
245            for (auto e : make_iterator_range(out_edges(u, G))) {
246                Aj.push_back(target(e, G));
247            }
248            std::sort(Aj.begin(), Aj.end());
249            VertexSet Ak;
250            Ak.reserve(std::min(Ai.size(), Aj.size()));
251            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
252            Ai.swap(Ak);
253        }
254        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
255        cliques.emplace_back(std::move(Ai), std::move(*Bi));
256    }
257    return std::move(cliques);
258}
259
260/** ------------------------------------------------------------------------------------------------------------- *
261 * @brief removeUnhelpfulBicliques
262 *
263 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
264 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
265 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
266 ** ------------------------------------------------------------------------------------------------------------- */
267static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, Graph & G) {
268    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
269        const auto cardinalityA = std::get<0>(*ci).size();
270        VertexSet & B = std::get<1>(*ci);
271        for (auto bi = B.begin(); bi != B.end(); ) {
272            if (G[*bi]->getNumUses() == cardinalityA) {
273                ++bi;
274            } else {
275                bi = B.erase(bi);
276            }
277        }
278        if (B.size() > 1) {
279            ++ci;
280        } else {
281            ci = cliques.erase(ci);
282        }
283    }
284    return std::move(cliques);
285}
286
287/** ------------------------------------------------------------------------------------------------------------- *
288 * @brief safeDistributionSets
289 ** ------------------------------------------------------------------------------------------------------------- */
290static DistributionSets safeDistributionSets(Graph & G, const VertexSet & distSet) {
291    DistributionSets T;
292    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(G, distSet), G), 1);
293    for (Biclique & lower : lowerSet) {
294        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(G, std::get<1>(lower)), 2);
295        for (Biclique & upper : upperSet) {
296            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
297        }
298    }
299    return std::move(T);
300}
301
302/** ------------------------------------------------------------------------------------------------------------- *
303 * @brief process
304 ** ------------------------------------------------------------------------------------------------------------- */
305inline void DistributivePass::process(PabloBlock * const block) {
306
307    for (;;) {
308
309        FlattenAssociativeDFG::coalesce(block, false);
310
311        Graph G;
312
313        const VertexSet distSet = generateDistributionGraph(block, G);
314
315        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
316        if (LLVM_UNLIKELY(distSet.empty())) {
317            break;
318        }
319
320        const DistributionSets distributionSets = safeDistributionSets(G, distSet);
321
322        if (LLVM_UNLIKELY(distributionSets.empty())) {
323            break;
324        }
325
326        for (const DistributionSet & set : distributionSets) {
327
328            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
329            const VertexSet & sources = std::get<0>(set);
330            const VertexSet & intermediary = std::get<1>(set);
331            const VertexSet & sinks = std::get<2>(set);
332
333            // Find the first sink and set the insert point immediately before that.
334            Variadic * innerOp = nullptr;
335            Variadic * outerOp = nullptr;
336
337            block->setInsertPoint(cast<Variadic>(G[sinks.front()])->getPrevNode());
338            if (isa<And>(G[sinks.front()])) {
339                outerOp = block->createAnd(intermediary.size());
340                innerOp = block->createOr(sources.size() + 1);
341            } else {
342                outerOp = block->createOr(intermediary.size());
343                innerOp = block->createAnd(sources.size() + 1);
344            }
345
346            for (const Vertex u : intermediary) {
347                for (const Vertex v : sinks) {
348                    cast<Variadic>(G[v])->deleteOperand(G[u]);
349                }
350                outerOp->addOperand(G[u]);
351            }
352
353            for (const Vertex u : sources) {
354                for (const Vertex v : intermediary) {
355                    cast<Variadic>(G[v])->deleteOperand(G[u]);
356                }
357                innerOp->addOperand(G[u]);
358            }
359            innerOp->addOperand(outerOp);
360
361            for (const Vertex u : sinks) {
362                cast<Variadic>(G[u])->addOperand(innerOp);
363            }
364        }
365
366    }
367}
368
369/** ------------------------------------------------------------------------------------------------------------- *
370 * @brief distribute
371 ** ------------------------------------------------------------------------------------------------------------- */
372void DistributivePass::distribute(PabloBlock * const block) {
373    for (Statement * stmt : *block) {
374        if (isa<If>(stmt) || isa<While>(stmt)) {
375            distribute(isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody());
376        }
377    }
378    process(block);
379}
380
381/** ------------------------------------------------------------------------------------------------------------- *
382 * @brief optimize
383 ** ------------------------------------------------------------------------------------------------------------- */
384void DistributivePass::optimize(PabloFunction & function) {
385    DistributivePass::distribute(function.getEntryBlock());
386    #ifndef NDEBUG
387    PabloVerifier::verify(function, "post-distribution");
388    #endif
389    Simplifier::optimize(function);
390    FlattenAssociativeDFG::deMorgansReduction(function.getEntryBlock());
391    #ifndef NDEBUG
392    PabloVerifier::verify(function, "post-demorgans-reduction");
393    #endif
394    Simplifier::optimize(function);
395}
396
397
398}
Note: See TracBrowser for help on using the repository browser.