source: icGREP/icgrep-devel/icgrep/pablo/pabloAST.cpp @ 4861

Last change on this file since 4861 was 4861, checked in by nmedfort, 4 years ago

Work on better scheduling in reassociation pass.

File size: 17.4 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pabloAST.h>
8#include <pablo/codegenstate.h>
9#include <llvm/Support/Compiler.h>
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13namespace pablo {
14
15PabloAST::Allocator PabloAST::mAllocator;
16
17/*
18
19    Return true if expr1 and expr2 can be proven equivalent according to some rules,
20    false otherwise.  Note that false may be returned i some cases when the exprs are
21    equivalent.
22
23*/
24
25/** ------------------------------------------------------------------------------------------------------------- *
26 * @brief equals
27 ** ------------------------------------------------------------------------------------------------------------- */
28bool equals(const PabloAST * expr1, const PabloAST * expr2) {
29    assert (expr1 && expr2);
30    if (expr1 == expr2) {
31        return true;
32    } else if (expr1->getClassTypeId() == expr2->getClassTypeId()) {
33        if ((isa<Zeroes>(expr1)) || (isa<Ones>(expr1))) {
34            return true;
35        } else if (const Var * var1 = dyn_cast<const Var>(expr1)) {
36            if (const Var * var2 = cast<const Var>(expr2)) {
37                return (var1->getName() == var2->getName());
38            }
39        } else if (const Not* not1 = dyn_cast<const Not>(expr1)) {
40            if (const Not* not2 = cast<const Not>(expr2)) {
41                return equals(not1->getExpr(), not2->getExpr());
42            }
43        } else if (const And* and1 = dyn_cast<const And>(expr1)) {
44            if (const And* and2 = cast<const And>(expr2)) {
45                if (equals(and1->getExpr1(), and2->getExpr1())) {
46                    return equals(and1->getExpr2(), and2->getExpr2());
47                } else if (equals(and1->getExpr1(), and2->getExpr2())) {
48                    return equals(and1->getExpr2(), and2->getExpr1());
49                }
50            }
51        } else if (const Or * or1 = dyn_cast<const Or>(expr1)) {
52            if (const Or* or2 = cast<const Or>(expr2)) {
53                if (equals(or1->getExpr1(), or2->getExpr1())) {
54                    return equals(or1->getExpr2(), or2->getExpr2());
55                } else if (equals(or1->getExpr1(), or2->getExpr2())) {
56                    return equals(or1->getExpr2(), or2->getExpr1());
57                }
58            }
59        } else if (const Xor * xor1 = dyn_cast<const Xor>(expr1)) {
60            if (const Xor * xor2 = cast<const Xor>(expr2)) {
61                if (equals(xor1->getExpr1(), xor2->getExpr1())) {
62                    return equals(xor1->getExpr2(), xor2->getExpr2());
63                } else if (equals(xor1->getExpr1(), xor2->getExpr2())) {
64                    return equals(xor1->getExpr2(), xor2->getExpr1());
65                }
66            }
67        } else if (isa<Integer>(expr1) || isa<String>(expr1) || isa<Call>(expr1)) {
68            // If these weren't equivalent by address they won't be equivalent by their operands.
69            return false;
70        } else { // Non-reassociatable functions (i.e., Sel, Advance, ScanThru, MatchStar, Assign, Next)
71            const Statement * stmt1 = cast<Statement>(expr1);
72            const Statement * stmt2 = cast<Statement>(expr2);
73            assert (stmt1->getNumOperands() == stmt2->getNumOperands());
74            for (unsigned i = 0; i != stmt1->getNumOperands(); ++i) {
75                if (!equals(stmt1->getOperand(i), stmt2->getOperand(i))) {
76                    return false;
77                }
78            }
79            return true;
80        }
81    }
82    return false;
83}
84
85/** ------------------------------------------------------------------------------------------------------------- *
86 * @brief replaceAllUsesWith
87 ** ------------------------------------------------------------------------------------------------------------- */
88void PabloAST::replaceAllUsesWith(PabloAST * expr) {   
89    Statement * user[mUsers.size()];
90    Vector::size_type users = 0;
91    for (PabloAST * u : mUsers) {
92        if (isa<Statement>(u) && u != expr) {
93            user[users++] = cast<Statement>(u);
94        }
95    }
96    mUsers.clear();
97    assert (expr);
98    for (Vector::size_type i = 0; i != users; ++i) {
99        user[i]->replaceUsesOfWith(this, expr);
100    }
101}
102
103/** ------------------------------------------------------------------------------------------------------------- *
104 * @brief checkForReplacementInEscapedValueList
105 ** ------------------------------------------------------------------------------------------------------------- */
106template <class ValueType, class ValueList>
107inline void Statement::checkForReplacementInEscapedValueList(Statement * branch, PabloAST * const from, PabloAST * const to, ValueList & list) {
108    if (LLVM_LIKELY(isa<ValueType>(from))) {
109        auto f = std::find(list.begin(), list.end(), cast<ValueType>(from));
110        if (LLVM_LIKELY(f != list.end())) {
111            if (LLVM_LIKELY(isa<ValueType>(to))) {
112                if (std::find(list.begin(), list.end(), cast<ValueType>(to)) == list.end()) {
113                    *f = cast<ValueType>(to);
114                    branch->addUser(to);
115                } else {
116                    list.erase(f);
117                }
118                branch->removeUser(from);
119                assert (std::find(list.begin(), list.end(), cast<ValueType>(to)) != list.end());
120                assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(to)) != branch->user_end());
121            } else { // replacement error occured
122                std::string tmp;
123                raw_string_ostream str(tmp);
124                str << "cannot replace escaped value ";
125                PabloPrinter::print(from, str);
126                str << " with ";
127                PabloPrinter::print(to, str);
128                str << " in ";
129                PabloPrinter::print(branch, str);
130                throw std::runtime_error(str.str());
131            }
132        }               
133        assert (std::find(list.begin(), list.end(), cast<ValueType>(from)) == list.end());
134        assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(from)) == branch->user_end());
135    }
136}
137
138/** ------------------------------------------------------------------------------------------------------------- *
139 * @brief replaceUsesOfWith
140 ** ------------------------------------------------------------------------------------------------------------- */
141void Statement::replaceUsesOfWith(PabloAST * const from, PabloAST * const to) {
142    for (unsigned i = 0; i != getNumOperands(); ++i) {
143       if (getOperand(i) == from) {
144           setOperand(i, to);
145       }
146    }
147    if (LLVM_UNLIKELY(isa<If>(this))) {
148        checkForReplacementInEscapedValueList<Assign>(this, from, to, cast<If>(this)->getDefined());
149    } else if (LLVM_UNLIKELY(isa<While>(this))) {
150        checkForReplacementInEscapedValueList<Next>(this, from, to, cast<While>(this)->getVariants());
151    }
152}
153
154/** ------------------------------------------------------------------------------------------------------------- *
155 * @brief setOperand
156 ** ------------------------------------------------------------------------------------------------------------- */
157void Statement::setOperand(const unsigned index, PabloAST * const value) {
158    assert (value);
159    assert (index < getNumOperands());
160    PabloAST * const priorValue = getOperand(index);
161    if (LLVM_UNLIKELY(priorValue == value)) {
162        return;
163    }   
164    if (LLVM_LIKELY(priorValue != nullptr)) {
165        // Test just to be sure that we don't have multiple operands pointing to
166        // what we're replacing. If not, remove this from the prior value's
167        // user list.
168        unsigned count = 0;
169        for (unsigned i = 0; i != getNumOperands(); ++i) {
170            count += (getOperand(i) == priorValue) ? 1 : 0;
171        }
172        assert (count >= 1);
173        if (LLVM_LIKELY(count == 1)) {
174            priorValue->removeUser(this);
175        }
176    }
177    mOperand[index] = value;
178    value->addUser(this);
179}
180
181/** ------------------------------------------------------------------------------------------------------------- *
182 * @brief insertBefore
183 ** ------------------------------------------------------------------------------------------------------------- */
184void Statement::insertBefore(Statement * const statement) {
185    if (LLVM_UNLIKELY(statement == this)) {
186        return;
187    }
188    else if (LLVM_UNLIKELY(statement == nullptr)) {
189        throw std::runtime_error("cannot insert before null statement!");
190    }
191    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
192        throw std::runtime_error("statement is not contained in a pablo block!");
193    }
194    removeFromParent();
195    mParent = statement->mParent;
196    if (LLVM_UNLIKELY(mParent->mFirst == statement)) {
197        mParent->mFirst = this;
198    }
199    mNext = statement;
200    mPrev = statement->mPrev;
201    statement->mPrev = this;
202    if (LLVM_LIKELY(mPrev != nullptr)) {
203        mPrev->mNext = this;
204    }
205    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
206        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
207        mParent->addUser(&body);
208    }
209}
210
211/** ------------------------------------------------------------------------------------------------------------- *
212 * @brief insertAfter
213 ** ------------------------------------------------------------------------------------------------------------- */
214void Statement::insertAfter(Statement * const statement) {
215    if (LLVM_UNLIKELY(statement == this)) {
216        return;
217    }
218    else if (LLVM_UNLIKELY(statement == nullptr)) {
219        throw std::runtime_error("cannot insert after null statement!");
220    }
221    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
222        throw std::runtime_error("statement is not contained in a pablo block!");
223    }
224    removeFromParent();
225    mParent = statement->mParent;
226    if (LLVM_UNLIKELY(mParent->mLast == statement)) {
227        mParent->mLast = this;
228    }
229    mPrev = statement;
230    mNext = statement->mNext;
231    statement->mNext = this;
232    if (LLVM_LIKELY(mNext != nullptr)) {
233        mNext->mPrev = this;
234    }
235    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
236        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
237        mParent->addUser(&body);
238    }
239}
240
241/** ------------------------------------------------------------------------------------------------------------- *
242 * @brief removeFromParent
243 ** ------------------------------------------------------------------------------------------------------------- */
244Statement * Statement::removeFromParent() {
245    Statement * next = mNext;
246    if (LLVM_LIKELY(mParent != nullptr)) {
247        if (LLVM_UNLIKELY(mParent->mFirst == this)) {
248            mParent->mFirst = mNext;
249        }
250        if (LLVM_UNLIKELY(mParent->mLast == this)) {
251            mParent->mLast = mPrev;
252        }
253        if (LLVM_UNLIKELY(mParent->mInsertionPoint == this)) {
254            mParent->mInsertionPoint = mPrev;
255        }
256        if (LLVM_LIKELY(mPrev != nullptr)) {
257            mPrev->mNext = mNext;
258        }
259        if (LLVM_LIKELY(mNext != nullptr)) {
260            mNext->mPrev = mPrev;
261        }
262        if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
263            PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
264            mParent->removeUser(&body);
265        }
266    }
267    mPrev = nullptr;
268    mNext = nullptr;
269    mParent = nullptr;
270    return next;
271}
272
273/** ------------------------------------------------------------------------------------------------------------- *
274 * @brief eraseFromParent
275 ** ------------------------------------------------------------------------------------------------------------- */
276Statement * Statement::eraseFromParent(const bool recursively) {
277    // remove this statement from its operands' users list
278    for (unsigned i = 0; i != mOperands; ++i) {
279        mOperand[i]->removeUser(this);
280    }
281    Statement * redundantBranch = nullptr;
282    // If this is an If or While statement, we'll have to remove the statements within the
283    // body or we'll lose track of them.
284    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
285        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
286        Statement * stmt = body.front();
287        // Note: by erasing the body, any Assign/Next nodes will be replaced with Zero.
288        while (stmt) {
289            stmt = stmt->eraseFromParent(recursively);
290        }
291    } else if (LLVM_UNLIKELY(isa<Assign>(this))) {
292        for (PabloAST * use : mUsers) {
293            if (If * ifNode = dyn_cast<If>(use)) {
294                auto & defs = ifNode->getDefined();
295                auto f = std::find(defs.begin(), defs.end(), this);
296                if (LLVM_LIKELY(f != defs.end())) {
297                    this->removeUser(ifNode);
298                    ifNode->removeUser(this);
299                    defs.erase(f);
300                    if (LLVM_UNLIKELY(defs.empty())) {
301                        redundantBranch = ifNode;
302                    }
303                    break;
304                }
305            }
306        }
307    } else if (LLVM_UNLIKELY(isa<Next>(this))) {
308        for (PabloAST * use : mUsers) {
309            if (While * whileNode = dyn_cast<While>(use)) {
310                auto & vars = whileNode->getVariants();
311                auto f = std::find(vars.begin(), vars.end(), this);
312                if (LLVM_LIKELY(f != vars.end())) {
313                    this->removeUser(whileNode);
314                    whileNode->removeUser(this);
315                    vars.erase(f);
316                    if (LLVM_UNLIKELY(vars.empty())) {
317                        redundantBranch = whileNode;
318                    }
319                    break;
320                }
321            }
322        }
323    }
324
325    replaceAllUsesWith(PabloBlock::createZeroes());
326
327    if (recursively) {
328        for (unsigned i = 0; i != mOperands; ++i) {
329            PabloAST * const op = mOperand[i];
330            if (LLVM_LIKELY(isa<Statement>(op))) {
331                bool erase = false;
332                if (op->getNumUses() == 0) {
333                    erase = true;
334                } else if ((isa<Assign>(op) || isa<Next>(op)) && op->getNumUses() == 1) {
335                    erase = true;
336                }
337                if (erase) {
338                    cast<Statement>(op)->eraseFromParent(true);
339                }
340            }
341        }
342        if (LLVM_UNLIKELY(redundantBranch != nullptr)) {
343            redundantBranch->eraseFromParent(true);
344        }
345    }
346
347    Statement * const next = removeFromParent();
348    mAllocator.deallocate(reinterpret_cast<Allocator::pointer>(this));
349    return next;
350}
351
352/** ------------------------------------------------------------------------------------------------------------- *
353 * @brief replaceWith
354 ** ------------------------------------------------------------------------------------------------------------- */
355Statement * Statement::replaceWith(PabloAST * const expr, const bool rename, const bool recursively) {
356    assert (expr);
357    if (LLVM_UNLIKELY(expr == this)) {
358        return getNextNode();
359    }
360    if (LLVM_LIKELY(rename && isa<Statement>(expr))) {
361        Statement * const stmt = cast<Statement>(expr);
362        if (getName()->isUserDefined() && stmt->getName()->isGenerated()) {
363            stmt->setName(getName());
364        }
365    }
366    replaceAllUsesWith(expr);   
367    return eraseFromParent(recursively);
368}
369
370/** ------------------------------------------------------------------------------------------------------------- *
371 * @brief contains
372 ** ------------------------------------------------------------------------------------------------------------- */
373bool StatementList::contains(Statement * const statement) {
374    for (Statement * stmt : *this) {
375        if (statement == stmt) {
376            return true;
377        }
378    }
379    return false;
380}
381
382/** ------------------------------------------------------------------------------------------------------------- *
383 * @brief clear
384 ** ------------------------------------------------------------------------------------------------------------- */
385void StatementList::clear() {
386    Statement * stmt = front();
387    while (stmt) {
388        Statement * next = stmt->mNext;
389        if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
390            PabloBlock & body = isa<If>(stmt) ? cast<If>(stmt)->getBody() : cast<While>(stmt)->getBody();
391            stmt->mParent->removeUser(&body);
392        }
393        stmt->mPrev = nullptr;
394        stmt->mNext = nullptr;
395        stmt->mParent = nullptr;
396        stmt = next;
397    }
398    mInsertionPoint = nullptr;
399    mFirst = nullptr;
400    mLast = nullptr;
401}
402
403StatementList::~StatementList() {
404
405}
406
407/** ------------------------------------------------------------------------------------------------------------- *
408 * @brief escapes
409 *
410 * Is this statement used outside of its scope?
411 ** ------------------------------------------------------------------------------------------------------------- */
412bool escapes(const Statement * statement) {
413    const PabloBlock * const parent = statement->getParent();
414    for (const PabloAST * user : statement->users()) {
415        if (LLVM_LIKELY(isa<Statement>(user))) {
416            const PabloBlock * used = cast<Statement>(user)->getParent();
417            while (used != parent) {
418                used = used->getParent();
419                if (used == nullptr) {
420                    assert (isa<Assign>(statement) || isa<Next>(statement));
421                    return true;
422                }
423            }
424        }
425    }
426    return false;
427}
428
429}
Note: See TracBrowser for help on using the repository browser.