source: icGREP/icgrep-devel/icgrep/pablo/pabloAST.cpp @ 4871

Last change on this file since 4871 was 4871, checked in by nmedfort, 4 years ago

Minor improvements to the optimizers and AST manipulation.

File size: 16.7 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pabloAST.h>
8#include <pablo/codegenstate.h>
9#include <llvm/Support/Compiler.h>
10#include <pablo/printer_pablos.h>
11#include <llvm/ADT/SmallVector.h>
12
13namespace pablo {
14
15PabloAST::Allocator PabloAST::mAllocator;
16
17/*
18
19    Return true if expr1 and expr2 can be proven equivalent according to some rules,
20    false otherwise.  Note that false may be returned i some cases when the exprs are
21    equivalent.
22
23*/
24
25/** ------------------------------------------------------------------------------------------------------------- *
26 * @brief equals
27 ** ------------------------------------------------------------------------------------------------------------- */
28bool equals(const PabloAST * expr1, const PabloAST * expr2) {
29    assert (expr1 && expr2);
30    if (expr1 == expr2) {
31        return true;
32    } else if (expr1->getClassTypeId() == expr2->getClassTypeId()) {
33        if ((isa<Zeroes>(expr1)) || (isa<Ones>(expr1))) {
34            return true;
35        } else if (const Var * var1 = dyn_cast<const Var>(expr1)) {
36            if (const Var * var2 = cast<const Var>(expr2)) {
37                return (var1->getName() == var2->getName());
38            }
39        } else if (const Not* not1 = dyn_cast<const Not>(expr1)) {
40            if (const Not* not2 = cast<const Not>(expr2)) {
41                return equals(not1->getExpr(), not2->getExpr());
42            }
43        } else if (const And* and1 = dyn_cast<const And>(expr1)) {
44            if (const And* and2 = cast<const And>(expr2)) {
45                if (equals(and1->getExpr1(), and2->getExpr1())) {
46                    return equals(and1->getExpr2(), and2->getExpr2());
47                } else if (equals(and1->getExpr1(), and2->getExpr2())) {
48                    return equals(and1->getExpr2(), and2->getExpr1());
49                }
50            }
51        } else if (const Or * or1 = dyn_cast<const Or>(expr1)) {
52            if (const Or* or2 = cast<const Or>(expr2)) {
53                if (equals(or1->getExpr1(), or2->getExpr1())) {
54                    return equals(or1->getExpr2(), or2->getExpr2());
55                } else if (equals(or1->getExpr1(), or2->getExpr2())) {
56                    return equals(or1->getExpr2(), or2->getExpr1());
57                }
58            }
59        } else if (const Xor * xor1 = dyn_cast<const Xor>(expr1)) {
60            if (const Xor * xor2 = cast<const Xor>(expr2)) {
61                if (equals(xor1->getExpr1(), xor2->getExpr1())) {
62                    return equals(xor1->getExpr2(), xor2->getExpr2());
63                } else if (equals(xor1->getExpr1(), xor2->getExpr2())) {
64                    return equals(xor1->getExpr2(), xor2->getExpr1());
65                }
66            }
67        } else if (isa<Integer>(expr1) || isa<String>(expr1) || isa<Call>(expr1)) {
68            // If these weren't equivalent by address they won't be equivalent by their operands.
69            return false;
70        } else { // Non-reassociatable functions (i.e., Sel, Advance, ScanThru, MatchStar, Assign, Next)
71            const Statement * stmt1 = cast<Statement>(expr1);
72            const Statement * stmt2 = cast<Statement>(expr2);
73            assert (stmt1->getNumOperands() == stmt2->getNumOperands());
74            for (unsigned i = 0; i != stmt1->getNumOperands(); ++i) {
75                if (!equals(stmt1->getOperand(i), stmt2->getOperand(i))) {
76                    return false;
77                }
78            }
79            return true;
80        }
81    }
82    return false;
83}
84
85/** ------------------------------------------------------------------------------------------------------------- *
86 * @brief replaceAllUsesWith
87 ** ------------------------------------------------------------------------------------------------------------- */
88void PabloAST::replaceAllUsesWith(PabloAST * expr) {
89    if (LLVM_UNLIKELY(this == expr)) {
90        return;
91    }
92    Statement * replacements[mUsers.size()];
93    Vector::size_type users = 0;
94    bool exprIsAUser = false;
95    assert (expr);
96    for (PabloAST * user : mUsers) {
97        if (LLVM_UNLIKELY(user == expr)) {
98            exprIsAUser = true;
99            continue;
100        }
101        replacements[users++] = cast<Statement>(user);
102    }
103    mUsers.clear();
104    if (LLVM_UNLIKELY(exprIsAUser)) {
105        mUsers.push_back(expr);
106    }
107    for (Vector::size_type i = 0; i != users; ++i) {
108        replacements[i]->replaceUsesOfWith(this, expr);
109    }
110}
111
112/** ------------------------------------------------------------------------------------------------------------- *
113 * @brief checkEscapedValueList
114 ** ------------------------------------------------------------------------------------------------------------- */
115template <class ValueType, class ValueList>
116inline void Statement::checkEscapedValueList(Statement * branch, PabloAST * const from, PabloAST * const to, ValueList & list) {
117    if (LLVM_LIKELY(isa<ValueType>(from))) {
118        auto f = std::find(list.begin(), list.end(), cast<ValueType>(from));
119        if (LLVM_LIKELY(f != list.end())) {
120            if (LLVM_LIKELY(isa<ValueType>(to))) {
121                if (std::find(list.begin(), list.end(), cast<ValueType>(to)) == list.end()) {
122                    *f = cast<ValueType>(to);
123                    branch->addUser(to);
124                } else {
125                    list.erase(f);
126                }
127                branch->removeUser(from);
128                assert (std::find(list.begin(), list.end(), cast<ValueType>(to)) != list.end());
129                assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(to)) != branch->user_end());
130            } else {
131                list.erase(f);
132                branch->removeUser(from);
133            }
134        }               
135        assert (std::find(list.begin(), list.end(), cast<ValueType>(from)) == list.end());
136        assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(from)) == branch->user_end());
137    }
138}
139
140/** ------------------------------------------------------------------------------------------------------------- *
141 * @brief replaceUsesOfWith
142 ** ------------------------------------------------------------------------------------------------------------- */
143void Statement::replaceUsesOfWith(PabloAST * const from, PabloAST * const to) {
144    if (LLVM_UNLIKELY(from == to)) {
145        return;
146    }
147    for (unsigned i = 0; i != getNumOperands(); ++i) {
148       if (getOperand(i) == from) {
149           setOperand(i, to);
150       }
151    }
152    if (LLVM_UNLIKELY(isa<If>(this))) {
153        checkEscapedValueList<Assign>(this, from, to, cast<If>(this)->getDefined());
154    } else if (LLVM_UNLIKELY(isa<While>(this))) {
155        checkEscapedValueList<Next>(this, from, to, cast<While>(this)->getVariants());
156    }
157}
158
159/** ------------------------------------------------------------------------------------------------------------- *
160 * @brief setOperand
161 ** ------------------------------------------------------------------------------------------------------------- */
162void Statement::setOperand(const unsigned index, PabloAST * const value) {
163    assert (value);
164    assert (index < getNumOperands());
165    PabloAST * const priorValue = getOperand(index);
166    if (LLVM_UNLIKELY(priorValue == value)) {
167        return;
168    }   
169    if (LLVM_LIKELY(priorValue != nullptr)) {
170        // Test just to be sure that we don't have multiple operands pointing to
171        // what we're replacing. If not, remove this from the prior value's
172        // user list.
173        unsigned count = 0;
174        for (unsigned i = 0; i != getNumOperands(); ++i) {
175            count += (getOperand(i) == priorValue) ? 1 : 0;
176        }
177        assert (count >= 1);
178        if (LLVM_LIKELY(count == 1)) {
179            priorValue->removeUser(this);
180        }
181    }
182    mOperand[index] = value;
183    value->addUser(this);
184}
185
186/** ------------------------------------------------------------------------------------------------------------- *
187 * @brief insertBefore
188 ** ------------------------------------------------------------------------------------------------------------- */
189void Statement::insertBefore(Statement * const statement) {
190    if (LLVM_UNLIKELY(statement == this)) {
191        return;
192    }
193    else if (LLVM_UNLIKELY(statement == nullptr)) {
194        throw std::runtime_error("cannot insert before null statement!");
195    }
196    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
197        throw std::runtime_error("statement is not contained in a pablo block!");
198    }
199    removeFromParent();
200    mParent = statement->mParent;
201    if (LLVM_UNLIKELY(mParent->mFirst == statement)) {
202        mParent->mFirst = this;
203    }
204    mNext = statement;
205    mPrev = statement->mPrev;
206    statement->mPrev = this;
207    if (LLVM_LIKELY(mPrev != nullptr)) {
208        mPrev->mNext = this;
209    }
210    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
211        PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
212        body->setParent(mParent);
213        mParent->addUser(body);
214    }
215}
216
217/** ------------------------------------------------------------------------------------------------------------- *
218 * @brief insertAfter
219 ** ------------------------------------------------------------------------------------------------------------- */
220void Statement::insertAfter(Statement * const statement) {
221    if (LLVM_UNLIKELY(statement == this)) {
222        return;
223    } else if (LLVM_UNLIKELY(statement == nullptr)) {
224        throw std::runtime_error("cannot insert after null statement!");
225    } else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
226        throw std::runtime_error("statement is not contained in a pablo block!");
227    }
228    removeFromParent();
229    mParent = statement->mParent;
230    if (LLVM_UNLIKELY(mParent->mLast == statement)) {
231        mParent->mLast = this;
232    }
233    mPrev = statement;
234    mNext = statement->mNext;
235    statement->mNext = this;
236    if (LLVM_LIKELY(mNext != nullptr)) {
237        mNext->mPrev = this;
238    }
239    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
240        PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
241        body->setParent(mParent);
242        mParent->addUser(body);
243    }
244}
245
246/** ------------------------------------------------------------------------------------------------------------- *
247 * @brief removeFromParent
248 ** ------------------------------------------------------------------------------------------------------------- */
249Statement * Statement::removeFromParent() {
250    Statement * next = mNext;
251    if (LLVM_LIKELY(mParent != nullptr)) {
252        if (LLVM_UNLIKELY(mParent->mFirst == this)) {
253            mParent->mFirst = mNext;
254        }
255        if (LLVM_UNLIKELY(mParent->mLast == this)) {
256            mParent->mLast = mPrev;
257        }
258        if (LLVM_UNLIKELY(mParent->mInsertionPoint == this)) {
259            mParent->mInsertionPoint = mPrev;
260        }
261        if (LLVM_LIKELY(mPrev != nullptr)) {
262            mPrev->mNext = mNext;
263        }
264        if (LLVM_LIKELY(mNext != nullptr)) {
265            mNext->mPrev = mPrev;
266        }
267        if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
268            PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
269            body->setParent(nullptr);
270            mParent->removeUser(body);
271        }
272    }
273    mPrev = nullptr;
274    mNext = nullptr;
275    mParent = nullptr;
276    return next;
277}
278
279/** ------------------------------------------------------------------------------------------------------------- *
280 * @brief eraseFromParent
281 ** ------------------------------------------------------------------------------------------------------------- */
282Statement * Statement::eraseFromParent(const bool recursively) {
283    // remove this statement from its operands' users list
284    for (unsigned i = 0; i != mOperands; ++i) {
285        mOperand[i]->removeUser(this);
286    }
287    SmallVector<Statement *, 1> redundantBranches;
288    // If this is an If or While statement, we'll have to remove the statements within the
289    // body or we'll lose track of them.
290    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
291        PabloBlock * const body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
292        body->eraseFromParent(recursively);
293    } else if (LLVM_UNLIKELY(isa<Assign>(this))) {
294        for (PabloAST * use : mUsers) {
295            if (If * ifNode = dyn_cast<If>(use)) {
296                auto & defs = ifNode->getDefined();
297                auto f = std::find(defs.begin(), defs.end(), this);
298                if (LLVM_LIKELY(f != defs.end())) {
299                    this->removeUser(ifNode);
300                    ifNode->removeUser(this);
301                    defs.erase(f);
302                    if (LLVM_UNLIKELY(defs.empty())) {
303                        redundantBranches.push_back(ifNode);
304                    }
305                }
306            }
307        }
308    } else if (LLVM_UNLIKELY(isa<Next>(this))) {
309        for (PabloAST * use : mUsers) {
310            if (While * whileNode = dyn_cast<While>(use)) {
311                auto & vars = whileNode->getVariants();
312                auto f = std::find(vars.begin(), vars.end(), this);
313                if (LLVM_LIKELY(f != vars.end())) {
314                    this->removeUser(whileNode);
315                    whileNode->removeUser(this);
316                    vars.erase(f);
317                    if (LLVM_UNLIKELY(vars.empty())) {
318                        redundantBranches.push_back(whileNode);
319                    }
320                }
321            }
322        }
323    }
324
325    replaceAllUsesWith(PabloBlock::createZeroes());
326
327    if (recursively) {
328        for (unsigned i = 0; i != mOperands; ++i) {
329            PabloAST * const op = mOperand[i];
330            if (LLVM_LIKELY(isa<Statement>(op))) {
331                if (op->getNumUses() == 0) {
332                    cast<Statement>(op)->eraseFromParent(true);
333                }
334            }
335        }
336        if (LLVM_UNLIKELY(redundantBranches.size() != 0)) {
337            // By eliminating this redundant branch, we may inadvertantly delete the scope block this statement
338            // resides within. Check and return null if so.
339            bool eliminatedScope = false;
340            for (Statement * br : redundantBranches) {
341                const PabloBlock * const body = isa<If>(br) ? cast<If>(br)->getBody() : cast<While>(br)->getBody();
342                if (LLVM_UNLIKELY(body == getParent())) {
343                    eliminatedScope = true;
344                }
345                br->eraseFromParent(true);
346            }
347            if (eliminatedScope) {
348                return nullptr;
349            }
350        }
351    }
352    Statement * const next = removeFromParent();
353    mAllocator.deallocate(reinterpret_cast<Allocator::pointer>(this));
354    return next;
355}
356
357/** ------------------------------------------------------------------------------------------------------------- *
358 * @brief replaceWith
359 ** ------------------------------------------------------------------------------------------------------------- */
360Statement * Statement::replaceWith(PabloAST * const expr, const bool rename, const bool recursively) {
361    assert (expr);
362    if (LLVM_UNLIKELY(expr == this)) {
363        return getNextNode();
364    }
365    if (LLVM_LIKELY(rename && isa<Statement>(expr))) {
366        Statement * const stmt = cast<Statement>(expr);
367        if (getName()->isUserDefined() && stmt->getName()->isGenerated()) {
368            stmt->setName(getName());
369        }
370    }
371    replaceAllUsesWith(expr);   
372    return eraseFromParent(recursively);
373}
374
375/** ------------------------------------------------------------------------------------------------------------- *
376 * @brief contains
377 ** ------------------------------------------------------------------------------------------------------------- */
378bool StatementList::contains(Statement * const statement) {
379    for (Statement * stmt : *this) {
380        if (statement == stmt) {
381            return true;
382        }
383    }
384    return false;
385}
386
387StatementList::~StatementList() {
388
389}
390
391/** ------------------------------------------------------------------------------------------------------------- *
392 * @brief escapes
393 *
394 * Is this statement used outside of its scope?
395 ** ------------------------------------------------------------------------------------------------------------- */
396bool escapes(const Statement * statement) {
397    const PabloBlock * const parent = statement->getParent();
398    for (const PabloAST * user : statement->users()) {
399        if (LLVM_LIKELY(isa<Statement>(user))) {
400            const PabloBlock * used = cast<Statement>(user)->getParent();
401            while (used != parent) {
402                used = used->getParent();
403                if (used == nullptr) {
404                    assert (isa<Assign>(statement) || isa<Next>(statement));
405                    return true;
406                }
407            }
408        }
409    }
410    return false;
411}
412
413}
Note: See TracBrowser for help on using the repository browser.