source: icGREP/icgrep-devel/icgrep/pablo/pabloAST.cpp @ 4856

Last change on this file since 4856 was 4856, checked in by nmedfort, 3 years ago

Bug fix for use-def correctness regarding escaping values of If and While nodes.

File size: 16.5 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pabloAST.h>
8#include <pablo/codegenstate.h>
9#include <llvm/Support/Compiler.h>
10#include <pablo/printer_pablos.h>
11#include <iostream>
12
13namespace pablo {
14
15PabloAST::Allocator PabloAST::mAllocator;
16PabloAST::VectorAllocator PabloAST::mVectorAllocator;
17
18/*
19
20    Return true if expr1 and expr2 can be proven equivalent according to some rules,
21    false otherwise.  Note that false may be returned i some cases when the exprs are
22    equivalent.
23
24*/
25
26/** ------------------------------------------------------------------------------------------------------------- *
27 * @brief equals
28 ** ------------------------------------------------------------------------------------------------------------- */
29bool equals(const PabloAST * expr1, const PabloAST * expr2) {
30    assert (expr1 && expr2);
31    if (expr1 == expr2) {
32        return true;
33    } else if (expr1->getClassTypeId() == expr2->getClassTypeId()) {
34        if ((isa<Zeroes>(expr1)) || (isa<Ones>(expr1))) {
35            return true;
36        } else if (const Var * var1 = dyn_cast<const Var>(expr1)) {
37            if (const Var * var2 = cast<const Var>(expr2)) {
38                return (var1->getName() == var2->getName());
39            }
40        } else if (const Not* not1 = dyn_cast<const Not>(expr1)) {
41            if (const Not* not2 = cast<const Not>(expr2)) {
42                return equals(not1->getExpr(), not2->getExpr());
43            }
44        } else if (const And* and1 = dyn_cast<const And>(expr1)) {
45            if (const And* and2 = cast<const And>(expr2)) {
46                if (equals(and1->getExpr1(), and2->getExpr1())) {
47                    return equals(and1->getExpr2(), and2->getExpr2());
48                } else if (equals(and1->getExpr1(), and2->getExpr2())) {
49                    return equals(and1->getExpr2(), and2->getExpr1());
50                }
51            }
52        } else if (const Or * or1 = dyn_cast<const Or>(expr1)) {
53            if (const Or* or2 = cast<const Or>(expr2)) {
54                if (equals(or1->getExpr1(), or2->getExpr1())) {
55                    return equals(or1->getExpr2(), or2->getExpr2());
56                } else if (equals(or1->getExpr1(), or2->getExpr2())) {
57                    return equals(or1->getExpr2(), or2->getExpr1());
58                }
59            }
60        } else if (const Xor * xor1 = dyn_cast<const Xor>(expr1)) {
61            if (const Xor * xor2 = cast<const Xor>(expr2)) {
62                if (equals(xor1->getExpr1(), xor2->getExpr1())) {
63                    return equals(xor1->getExpr2(), xor2->getExpr2());
64                } else if (equals(xor1->getExpr1(), xor2->getExpr2())) {
65                    return equals(xor1->getExpr2(), xor2->getExpr1());
66                }
67            }
68        } else if (isa<Integer>(expr1) || isa<String>(expr1) || isa<Call>(expr1)) {
69            // If these weren't equivalent by address they won't be equivalent by their operands.
70            return false;
71        } else { // Non-reassociatable functions (i.e., Sel, Advance, ScanThru, MatchStar, Assign, Next)
72            const Statement * stmt1 = cast<Statement>(expr1);
73            const Statement * stmt2 = cast<Statement>(expr2);
74            assert (stmt1->getNumOperands() == stmt2->getNumOperands());
75            for (unsigned i = 0; i != stmt1->getNumOperands(); ++i) {
76                if (!equals(stmt1->getOperand(i), stmt2->getOperand(i))) {
77                    return false;
78                }
79            }
80            return true;
81        }
82    }
83    return false;
84}
85
86/** ------------------------------------------------------------------------------------------------------------- *
87 * @brief replaceAllUsesWith
88 ** ------------------------------------------------------------------------------------------------------------- */
89void PabloAST::replaceAllUsesWith(PabloAST * expr) {   
90    Statement * user[mUsers.size()];
91    Vector::size_type users = 0;
92    for (PabloAST * u : mUsers) {
93        if (isa<Statement>(u) && u != expr) {
94            user[users++] = cast<Statement>(u);
95        }
96    }
97    mUsers.clear();
98    assert (expr);
99    for (Vector::size_type i = 0; i != users; ++i) {
100        user[i]->replaceUsesOfWith(this, expr);
101    }
102}
103
104/** ------------------------------------------------------------------------------------------------------------- *
105 * @brief checkForReplacementInEscapedValueList
106 ** ------------------------------------------------------------------------------------------------------------- */
107template <class ValueType, class ValueList>
108inline void Statement::checkForReplacementInEscapedValueList(Statement * branch, PabloAST * const from, PabloAST * const to, ValueList & list) {
109    if (LLVM_LIKELY(isa<ValueType>(from))) {
110        auto f = std::find(list.begin(), list.end(), cast<ValueType>(from));
111        if (LLVM_LIKELY(f != list.end())) {
112            if (LLVM_LIKELY(isa<ValueType>(to))) {
113                if (std::find(list.begin(), list.end(), cast<ValueType>(to)) == list.end()) {
114                    *f = cast<ValueType>(to);
115                    branch->addUser(to);
116                } else {
117                    list.erase(f);
118                }
119                branch->removeUser(from);
120                assert (std::find(list.begin(), list.end(), cast<ValueType>(to)) != list.end());
121                assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(to)) != branch->user_end());
122            } else { // replacement error occured
123                std::string tmp;
124                raw_string_ostream str(tmp);
125                str << "cannot replace escaped value ";
126                PabloPrinter::print(from, str);
127                str << " with ";
128                PabloPrinter::print(to, str);
129                str << " in ";
130                PabloPrinter::print(branch, str);
131                throw std::runtime_error(str.str());
132            }
133        }               
134        assert (std::find(list.begin(), list.end(), cast<ValueType>(from)) == list.end());
135        assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(from)) == branch->user_end());
136    }
137}
138
139/** ------------------------------------------------------------------------------------------------------------- *
140 * @brief replaceUsesOfWith
141 ** ------------------------------------------------------------------------------------------------------------- */
142void Statement::replaceUsesOfWith(PabloAST * const from, PabloAST * const to) {
143    for (unsigned i = 0; i != getNumOperands(); ++i) {
144       if (getOperand(i) == from) {
145           setOperand(i, to);
146       }
147    }
148    if (LLVM_UNLIKELY(isa<If>(this))) {
149        checkForReplacementInEscapedValueList<Assign>(this, from, to, cast<If>(this)->getDefined());
150    } else if (LLVM_UNLIKELY(isa<While>(this))) {
151        checkForReplacementInEscapedValueList<Next>(this, from, to, cast<While>(this)->getVariants());
152    }
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief setOperand
157 ** ------------------------------------------------------------------------------------------------------------- */
158void Statement::setOperand(const unsigned index, PabloAST * const value) {
159    assert (value);
160    assert (index < getNumOperands());
161    PabloAST * const priorValue = getOperand(index);
162    if (LLVM_UNLIKELY(priorValue == value)) {
163        return;
164    }   
165    if (LLVM_LIKELY(priorValue != nullptr)) {
166        // Test just to be sure that we don't have multiple operands pointing to
167        // what we're replacing. If not, remove this from the prior value's
168        // user list.
169        unsigned count = 0;
170        for (unsigned i = 0; i != getNumOperands(); ++i) {
171            count += (getOperand(i) == priorValue) ? 1 : 0;
172        }
173        assert (count >= 1);
174        if (LLVM_LIKELY(count == 1)) {
175            priorValue->removeUser(this);
176        }
177    }
178    mOperand[index] = value;
179    value->addUser(this);
180}
181
182/** ------------------------------------------------------------------------------------------------------------- *
183 * @brief insertBefore
184 ** ------------------------------------------------------------------------------------------------------------- */
185void Statement::insertBefore(Statement * const statement) {
186    if (LLVM_UNLIKELY(statement == this)) {
187        return;
188    }
189    else if (LLVM_UNLIKELY(statement == nullptr)) {
190        throw std::runtime_error("cannot insert before null statement!");
191    }
192    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
193        throw std::runtime_error("statement is not contained in a pablo block!");
194    }
195    removeFromParent();
196    mParent = statement->mParent;
197    if (LLVM_UNLIKELY(mParent->mFirst == statement)) {
198        mParent->mFirst = this;
199    }
200    mNext = statement;
201    mPrev = statement->mPrev;
202    statement->mPrev = this;
203    if (LLVM_LIKELY(mPrev != nullptr)) {
204        mPrev->mNext = this;
205    }
206    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
207        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
208        mParent->addUser(&body);
209    }
210}
211
212/** ------------------------------------------------------------------------------------------------------------- *
213 * @brief insertAfter
214 ** ------------------------------------------------------------------------------------------------------------- */
215void Statement::insertAfter(Statement * const statement) {
216    if (LLVM_UNLIKELY(statement == this)) {
217        return;
218    }
219    else if (LLVM_UNLIKELY(statement == nullptr)) {
220        throw std::runtime_error("cannot insert after null statement!");
221    }
222    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
223        throw std::runtime_error("statement is not contained in a pablo block!");
224    }
225    removeFromParent();
226    mParent = statement->mParent;
227    if (LLVM_UNLIKELY(mParent->mLast == statement)) {
228        mParent->mLast = this;
229    }
230    mPrev = statement;
231    mNext = statement->mNext;
232    statement->mNext = this;
233    if (LLVM_LIKELY(mNext != nullptr)) {
234        mNext->mPrev = this;
235    }
236    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
237        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
238        mParent->addUser(&body);
239    }
240}
241
242/** ------------------------------------------------------------------------------------------------------------- *
243 * @brief removeFromParent
244 ** ------------------------------------------------------------------------------------------------------------- */
245Statement * Statement::removeFromParent() {
246    Statement * next = mNext;
247    if (LLVM_LIKELY(mParent != nullptr)) {
248        if (LLVM_UNLIKELY(mParent->mFirst == this)) {
249            mParent->mFirst = mNext;
250        }
251        if (LLVM_UNLIKELY(mParent->mLast == this)) {
252            mParent->mLast = mPrev;
253        }
254        if (LLVM_UNLIKELY(mParent->mInsertionPoint == this)) {
255            mParent->mInsertionPoint = mPrev;
256        }
257        if (LLVM_LIKELY(mPrev != nullptr)) {
258            mPrev->mNext = mNext;
259        }
260        if (LLVM_LIKELY(mNext != nullptr)) {
261            mNext->mPrev = mPrev;
262        }
263        if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
264            PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
265            mParent->removeUser(&body);
266        }
267    }
268    mPrev = nullptr;
269    mNext = nullptr;
270    mParent = nullptr;
271    return next;
272}
273
274/** ------------------------------------------------------------------------------------------------------------- *
275 * @brief eraseFromParent
276 ** ------------------------------------------------------------------------------------------------------------- */
277Statement * Statement::eraseFromParent(const bool recursively) {
278    // remove this statement from its operands' users list
279    for (unsigned i = 0; i != mOperands; ++i) {
280        mOperand[i]->removeUser(this);
281    }
282    Statement * redundantBranch = nullptr;
283    // If this is an If or While statement, we'll have to remove the statements within the
284    // body or we'll lose track of them.
285    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
286        PabloBlock & body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
287        Statement * stmt = body.front();
288        // Note: by erasing the body, any Assign/Next nodes will be replaced with Zero.
289        while (stmt) {
290            stmt = stmt->eraseFromParent(recursively);
291        }
292    } else if (LLVM_UNLIKELY(isa<Assign>(this))) {
293        for (PabloAST * use : mUsers) {
294            if (If * ifNode = dyn_cast<If>(use)) {
295                auto & defs = ifNode->getDefined();
296                auto f = std::find(defs.begin(), defs.end(), this);
297                if (LLVM_LIKELY(f != defs.end())) {
298                    this->removeUser(ifNode);
299                    ifNode->removeUser(this);
300                    defs.erase(f);
301                    if (LLVM_UNLIKELY(defs.empty())) {
302                        redundantBranch = ifNode;
303                    }
304                    break;
305                }
306            }
307        }
308    } else if (LLVM_UNLIKELY(isa<Next>(this))) {
309        for (PabloAST * use : mUsers) {
310            if (While * whileNode = dyn_cast<While>(use)) {
311                auto & vars = whileNode->getVariants();
312                auto f = std::find(vars.begin(), vars.end(), this);
313                if (LLVM_LIKELY(f != vars.end())) {
314                    this->removeUser(whileNode);
315                    whileNode->removeUser(this);
316                    vars.erase(f);
317                    if (LLVM_UNLIKELY(vars.empty())) {
318                        redundantBranch = whileNode;
319                    }
320                    break;
321                }
322            }
323        }
324    }
325
326    replaceAllUsesWith(PabloBlock::createZeroes());
327
328    if (recursively) {
329        for (unsigned i = 0; i != mOperands; ++i) {
330            PabloAST * const op = mOperand[i];
331            if (LLVM_LIKELY(isa<Statement>(op))) {
332                bool erase = false;
333                if (op->getNumUses() == 0) {
334                    erase = true;
335                } else if ((isa<Assign>(op) || isa<Next>(op)) && op->getNumUses() == 1) {
336                    erase = true;
337                }
338                if (erase) {
339                    cast<Statement>(op)->eraseFromParent(true);
340                }
341            }
342        }
343        if (LLVM_UNLIKELY(redundantBranch != nullptr)) {
344            redundantBranch->eraseFromParent(true);
345        }
346    }
347
348    return removeFromParent();
349}
350
351/** ------------------------------------------------------------------------------------------------------------- *
352 * @brief replaceWith
353 ** ------------------------------------------------------------------------------------------------------------- */
354Statement * Statement::replaceWith(PabloAST * const expr, const bool rename, const bool recursively) {
355    assert (expr);
356    if (LLVM_UNLIKELY(expr == this)) {
357        return getNextNode();
358    }
359    if (LLVM_LIKELY(rename && isa<Statement>(expr))) {
360        Statement * const stmt = cast<Statement>(expr);
361        if (getName()->isUserDefined() && stmt->getName()->isGenerated()) {
362            stmt->setName(getName());
363        }
364    }
365    replaceAllUsesWith(expr);   
366    return eraseFromParent(recursively);
367}
368
369/** ------------------------------------------------------------------------------------------------------------- *
370 * @brief contains
371 ** ------------------------------------------------------------------------------------------------------------- */
372bool StatementList::contains(Statement * const statement) {
373    for (Statement * stmt : *this) {
374        if (statement == stmt) {
375            return true;
376        }
377    }
378    return false;
379}
380
381StatementList::~StatementList() {
382
383}
384
385/** ------------------------------------------------------------------------------------------------------------- *
386 * @brief escapes
387 *
388 * Is this statement used outside of its scope?
389 ** ------------------------------------------------------------------------------------------------------------- */
390bool escapes(const Statement * statement) {
391    const PabloBlock * const parent = statement->getParent();
392    for (const PabloAST * user : statement->users()) {
393        if (LLVM_LIKELY(isa<Statement>(user))) {
394            const PabloBlock * used = cast<Statement>(user)->getParent();
395            while (used != parent) {
396                used = used->getParent();
397                if (used == nullptr) {
398                    assert (isa<Assign>(statement) || isa<Next>(statement));
399                    return true;
400                }
401            }
402        }
403    }
404    return false;
405}
406
407}
Note: See TracBrowser for help on using the repository browser.