source: icGREP/icgrep-devel/icgrep/pablo/optimizers/distributivepass.cpp @ 4878

Last change on this file since 4878 was 4878, checked in by nmedfort, 4 years ago

More work on n-ary operations.

File size: 14.4 KB
Line 
1#include "distributivepass.h"
2
3#include <pablo/codegenstate.h>
4#include <boost/container/flat_map.hpp>
5#include <boost/graph/adjacency_list.hpp>
6
7using namespace boost;
8using namespace boost::container;
9
10namespace pablo {
11
12using DistributionGraph = adjacency_list<hash_setS, vecS, bidirectionalS, PabloAST *>;
13using DistributionMap = flat_map<PabloAST *, DistributionGraph::vertex_descriptor>;
14
15using VertexSet = std::vector<Variadic *>;
16using Biclique = std::pair<VertexSet, VertexSet>;
17using BicliqueSet = std::vector<Biclique>;
18using DistributionSet = std::tuple<VertexSet, VertexSet, VertexSet>;
19using DistributionSets = std::vector<DistributionSet>;
20
21using TypeId = PabloAST::ClassTypeId;
22
23/** ------------------------------------------------------------------------------------------------------------- *
24 * @brief getVertex
25 ** ------------------------------------------------------------------------------------------------------------- */
26static inline DistributionGraph::vertex_descriptor getVertex(const PabloAST * value, DistributionGraph & G, DistributionMap & M) {
27    const auto f = M.find(value);
28    if (f != M.end()) {
29        return f->second;
30    }
31    const auto u = add_vertex(value, G);
32    M.emplace(value, u);
33    return u;
34}
35
36/** ------------------------------------------------------------------------------------------------------------- *
37 * @brief enumerateBicliques
38 *
39 * Adaptation of the MICA algorithm as described in "Consensus algorithms for the generation of all maximal
40 * bicliques" by Alexe et. al. (2003). Note: this implementation considers all verticies in set A to be in
41 * bipartition A and their adjacencies to be in B.
42  ** ------------------------------------------------------------------------------------------------------------- */
43template <class Graph>
44static BicliqueSet enumerateBicliques(const Graph & G, const VertexSet & A) {
45    using IntersectionSets = std::set<VertexSet>;
46
47    IntersectionSets B1;
48    for (auto u : A) {
49        if (in_degree(u, G) > 0) {
50            VertexSet incomingAdjacencies;
51            incomingAdjacencies.reserve(in_degree(u, G));
52            for (auto e : make_iterator_range(in_edges(u, G))) {
53                incomingAdjacencies.push_back(source(e, G));
54            }
55            std::sort(incomingAdjacencies.begin(), incomingAdjacencies.end());
56            B1.insert(std::move(incomingAdjacencies));
57        }
58    }
59
60    IntersectionSets B(B1);
61
62    IntersectionSets Bi;
63
64    VertexSet clique;
65    for (auto i = B1.begin(); i != B1.end(); ++i) {
66        for (auto j = i; ++j != B1.end(); ) {
67            std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
68            if (clique.size() > 0) {
69                if (B.count(clique) == 0) {
70                    Bi.insert(clique);
71                }
72                clique.clear();
73            }
74        }
75    }
76
77    for (;;) {
78        if (Bi.empty()) {
79            break;
80        }
81        B.insert(Bi.begin(), Bi.end());
82        IntersectionSets Bk;
83        for (auto i = B1.begin(); i != B1.end(); ++i) {
84            for (auto j = Bi.begin(); j != Bi.end(); ++j) {
85                std::set_intersection(i->begin(), i->end(), j->begin(), j->end(), std::back_inserter(clique));
86                if (clique.size() > 0) {
87                    if (B.count(clique) == 0) {
88                        Bk.insert(clique);
89                    }
90                    clique.clear();
91                }
92            }
93        }
94        Bi.swap(Bk);
95    }
96
97    BicliqueSet cliques;
98    cliques.reserve(B.size());
99    for (auto Bi = B.begin(); Bi != B.end(); ++Bi) {
100        VertexSet Ai(A);
101        for (const Vertex u : *Bi) {
102            VertexSet Aj;
103            Aj.reserve(out_degree(u, G));
104            for (auto e : make_iterator_range(out_edges(u, G))) {
105                Aj.push_back(target(e, G));
106            }
107            std::sort(Aj.begin(), Aj.end());
108            VertexSet Ak;
109            Ak.reserve(std::min(Ai.size(), Aj.size()));
110            std::set_intersection(Ai.begin(), Ai.end(), Aj.begin(), Aj.end(), std::back_inserter(Ak));
111            Ai.swap(Ak);
112        }
113        assert (Ai.size() > 0); // cannot happen if this algorithm is working correctly
114        cliques.emplace_back(std::move(Ai), std::move(*Bi));
115    }
116    return std::move(cliques);
117}
118
119/** ------------------------------------------------------------------------------------------------------------- *
120 * @brief intersects
121 ** ------------------------------------------------------------------------------------------------------------- */
122template <class Type>
123inline bool intersects(const Type & A, const Type & B) {
124    auto first1 = A.begin(), last1 = A.end();
125    auto first2 = B.begin(), last2 = B.end();
126    while (first1 != last1 && first2 != last2) {
127        if (*first1 < *first2) {
128            ++first1;
129        } else if (*first2 < *first1) {
130            ++first2;
131        } else {
132            return true;
133        }
134    }
135    return false;
136}
137
138/** ------------------------------------------------------------------------------------------------------------- *
139 * @brief independentCliqueSets
140 ** ------------------------------------------------------------------------------------------------------------- */
141template <unsigned side>
142inline static BicliqueSet && independentCliqueSets(BicliqueSet && cliques, const unsigned minimum) {
143    using IndependentSetGraph = adjacency_list<hash_setS, vecS, undirectedS, unsigned>;
144
145    const auto l = cliques.size();
146    IndependentSetGraph I(l);
147
148    // Initialize our weights
149    for (unsigned i = 0; i != l; ++i) {
150        I[i] = std::pow(std::get<side>(cliques[i]).size(), 2);
151    }
152
153    // Determine our constraints
154    for (unsigned i = 0; i != l; ++i) {
155        for (unsigned j = i + 1; j != l; ++j) {
156            if (intersects(std::get<side>(cliques[i]), std::get<side>(cliques[j]))) {
157                add_edge(i, j, I);
158            }
159        }
160    }
161
162    // Use the greedy algorithm to choose our independent set
163    VertexSet selected;
164    for (;;) {
165        unsigned w = 0;
166        Vertex u = 0;
167        for (unsigned i = 0; i != l; ++i) {
168            if (I[i] > w) {
169                w = I[i];
170                u = i;
171            }
172        }
173        if (w < minimum) break;
174        selected.push_back(u);
175        I[u] = 0;
176        for (auto v : make_iterator_range(adjacent_vertices(u, I))) {
177            I[v] = 0;
178        }
179    }
180
181    // Sort the selected list and then remove the unselected cliques
182    std::sort(selected.begin(), selected.end(), std::greater<Vertex>());
183    auto end = cliques.end();
184    for (const unsigned offset : selected) {
185        end = cliques.erase(cliques.begin() + offset + 1, end) - 1;
186    }
187    cliques.erase(cliques.begin(), end);
188
189    return std::move(cliques);
190}
191
192/** ------------------------------------------------------------------------------------------------------------- *
193 * @brief removeUnhelpfulBicliques
194 *
195 * An intermediary vertex could have more than one outgoing edge but if any that are not directed to vertices in
196 * the lower biclique, we'll need to compute that specific value anyway. Remove them from the clique set and if
197 * there are not enough vertices in the biclique to make distribution profitable, eliminate the clique.
198 ** ------------------------------------------------------------------------------------------------------------- */
199static BicliqueSet && removeUnhelpfulBicliques(BicliqueSet && cliques, DistributionGraph & G) {
200    for (auto ci = cliques.begin(); ci != cliques.end(); ) {
201        const auto cardinalityA = std::get<0>(*ci).size();
202        VertexSet & B = std::get<1>(*ci);
203        for (auto bi = B.begin(); bi != B.end(); ) {
204            if (G[*bi]->getNumUsers() == cardinalityA) {
205                ++bi;
206            } else {
207                bi = B.erase(bi);
208            }
209        }
210        if (B.size() > 1) {
211            ++ci;
212        } else {
213            ci = cliques.erase(ci);
214        }
215    }
216    return std::move(cliques);
217}
218
219/** ------------------------------------------------------------------------------------------------------------- *
220 * @brief safeDistributionSets
221 ** ------------------------------------------------------------------------------------------------------------- */
222static DistributionSets safeDistributionSets(DistributionGraph & G) {
223
224    VertexSet sinks;
225    for (const auto u : make_iterator_range(vertices(G))) {
226        if (out_degree(u, G) == 0 && in_degree(u, G) != 0) {
227            sinks.push_back(u);
228        }
229    }
230
231    DistributionSets T;
232    BicliqueSet lowerSet = independentCliqueSets<1>(removeUnhelpfulBicliques(enumerateBicliques(G, sinks), G), 1);
233    for (Biclique & lower : lowerSet) {
234        BicliqueSet upperSet = independentCliqueSets<0>(enumerateBicliques(G, std::get<1>(lower)), 2);
235        for (Biclique & upper : upperSet) {
236            T.emplace_back(std::move(std::get<1>(upper)), std::move(std::get<0>(upper)), std::get<0>(lower));
237        }
238    }
239    return std::move(T);
240}
241
242/** ------------------------------------------------------------------------------------------------------------- *
243 * @brief generateDistributionGraph
244 *
245 * Generate a graph G describing the potential applications of the distributive law for the given block.
246 ** ------------------------------------------------------------------------------------------------------------- */
247void generateDistributionGraph(PabloBlock * block, DistributionGraph & G) {
248    DistributionMap M;
249    for (Statement * stmt : *block) {
250        if (isa<And>(stmt) || isa<Or>(stmt)) {
251            const TypeId outerTypeId = stmt->getClassTypeId();
252            const TypeId innerTypeId = (outerTypeId == TypeId::And) ? TypeId::Or : TypeId::And;
253            flat_set<Variadic *> distributable;
254            for (PabloAST * const expr : *cast<Variadic>(stmt)) {
255                if (LLVM_UNLIKELY(expr->getClassTypeId() == innerTypeId)) {
256                    bool safe = true;
257                    for (PabloAST * const user : expr->users()) {
258                        if (user->getClassTypeId() != outerTypeId) {
259                            safe = false;
260                            break;
261                        }
262                    }
263                    if (safe) {
264                        distributable.insert(cast<Variadic>(expr));
265                    }
266                }
267            }
268            if (LLVM_LIKELY(distributable.size() > 1)) {
269                flat_map<PabloAST *, bool> observedMoreThanOnce;
270                bool anyOpportunities = false;
271                for (const Variadic * distOperation : distributable) {
272                    for (PabloAST * const distVar : *distOperation) {
273                        auto ob = observedMoreThanOnce.find(distVar);
274                        if (ob == observedMoreThanOnce.end()) {
275                            observedMoreThanOnce.emplace(distVar, false);
276                        } else {
277                            ob->second = true;
278                            anyOpportunities = true;
279                        }
280                    }
281                }
282                if (anyOpportunities) {
283                    for (const auto ob : observedMoreThanOnce) {
284                        PabloAST * const distVar = nullptr;
285                        bool observedTwice = false;
286                        std::tie(distVar, observedTwice) = *ob;
287                        if (observedTwice) {
288                            for (PabloAST * const distOperation : distVar->users()) {
289                                if (distributable.count(distOperation)) {
290                                    const Vertex y = getVertex(distOperation, G, M);
291                                    add_edge(y, getVertex(stmt, G, M), G);
292                                    add_edge(getVertex(distVar, G, M), y, G);
293                                }
294                            }
295                        }
296                    }
297                }
298            }
299        }
300    }
301}
302
303/** ------------------------------------------------------------------------------------------------------------- *
304 * @brief redistributeAST
305 *
306 * Apply the distribution law to reduce computations whenever possible.
307 ** ------------------------------------------------------------------------------------------------------------- */
308void DistributivePass::redistributeAST(PabloBlock * block) const {
309
310    std::vector<Vertex> mapping(num_vertices(G) + 16);
311    std::iota(mapping.begin(), mapping.end(), 0); // 0,1,.....n
312
313    for (;;) {
314
315        DistributionGraph H;
316
317        generateDistributionGraph(block, H);
318
319        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
320        if (num_vertices(H) == 0) {
321            break;
322        }
323
324        const DistributionSets distributionSets = safeDistributionSets(block, H);
325
326        if (LLVM_UNLIKELY(distributionSets.empty())) {
327            break;
328        }
329
330        for (const DistributionSet & set : distributionSets) {
331
332            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
333            const VertexSet & sources = std::get<0>(set);
334            const VertexSet & intermediary = std::get<1>(set);
335            const VertexSet & sinks = std::get<2>(set);
336
337            const TypeId outerTypeId = getType(G[mapping[H[sinks.front()]]]);
338            assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
339            const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
340
341            // Update G to match the desired change
342            const Vertex x = addSummaryVertex(outerTypeId, G);
343            const Vertex y = addSummaryVertex(innerTypeId, G);
344            mapping.resize(num_vertices(G));
345            mapping[x] = x;
346            mapping[y] = y;
347
348            for (const Vertex i : intermediary) {
349                const auto u = mapping[H[i]];
350                assert (getType(G[u]) == innerTypeId);
351                for (const Vertex t : sinks) {
352                    assert (getType(G[mapping[H[t]]]) == outerTypeId);
353                    remove_edge(u, mapping[H[t]], G);
354                }
355                add_edge(nullptr, u, x, G);
356            }
357
358            for (const Vertex s : sources) {
359                const auto u = mapping[H[s]];
360                for (const Vertex i : intermediary) {
361                    remove_edge(u, mapping[H[i]], G);
362                }
363                add_edge(nullptr, u, y, G);
364            }
365            add_edge(nullptr, x, y, G);
366
367            for (const Vertex t : sinks) {
368                const auto v = mapping[H[t]];
369                add_edge(getValue(G[v]), y, v, G);
370            }
371
372            summarizeGraph(G, mapping);
373        }
374    }
375}
376
377
378}
Note: See TracBrowser for help on using the repository browser.