source: icGREP/icgrep-devel/icgrep/pablo/pabloAST.cpp @ 4878

Last change on this file since 4878 was 4878, checked in by nmedfort, 4 years ago

More work on n-ary operations.

File size: 17.6 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pabloAST.h>
8#include <pablo/codegenstate.h>
9#include <llvm/Support/Compiler.h>
10#include <pablo/printer_pablos.h>
11#include <llvm/ADT/SmallVector.h>
12#include <iostream>
13
14namespace pablo {
15
16PabloAST::Allocator PabloAST::mAllocator;
17
18/** ------------------------------------------------------------------------------------------------------------- *
19 * @brief equals
20 *
21 *  Return true if expr1 and expr2 can be proven equivalent according to some rules, false otherwise.  Note that
22 *  false may be returned i some cases when the exprs are equivalent.
23 ** ------------------------------------------------------------------------------------------------------------- */
24bool equals(const PabloAST * expr1, const PabloAST * expr2) {
25    assert (expr1 && expr2);
26    if (expr1 == expr2) {
27        return true;
28    } else if (expr1->getClassTypeId() == expr2->getClassTypeId()) {
29        if ((isa<Zeroes>(expr1)) || (isa<Ones>(expr1))) {
30            return true;
31        } else if (const Var * var1 = dyn_cast<const Var>(expr1)) {
32            if (const Var * var2 = cast<const Var>(expr2)) {
33                return (var1->getName() == var2->getName());
34            }
35        } else if (const Not* not1 = dyn_cast<const Not>(expr1)) {
36            if (const Not* not2 = cast<const Not>(expr2)) {
37                return equals(not1->getExpr(), not2->getExpr());
38            }
39        } else if (const And* and1 = dyn_cast<const And>(expr1)) {
40            if (const And* and2 = cast<const And>(expr2)) {
41                if (equals(and1->getOperand(0), and2->getOperand(0))) {
42                    return equals(and1->getOperand(1), and2->getOperand(1));
43                } else if (equals(and1->getOperand(0), and2->getOperand(1))) {
44                    return equals(and1->getOperand(1), and2->getOperand(0));
45                }
46            }
47        } else if (const Or * or1 = dyn_cast<const Or>(expr1)) {
48            if (const Or* or2 = cast<const Or>(expr2)) {
49                if (equals(or1->getOperand(0), or2->getOperand(0))) {
50                    return equals(or1->getOperand(1), or2->getOperand(1));
51                } else if (equals(or1->getOperand(0), or2->getOperand(1))) {
52                    return equals(or1->getOperand(1), or2->getOperand(0));
53                }
54            }
55        } else if (const Xor * xor1 = dyn_cast<const Xor>(expr1)) {
56            if (const Xor * xor2 = cast<const Xor>(expr2)) {
57                if (equals(xor1->getOperand(0), xor2->getOperand(0))) {
58                    return equals(xor1->getOperand(1), xor2->getOperand(1));
59                } else if (equals(xor1->getOperand(0), xor2->getOperand(1))) {
60                    return equals(xor1->getOperand(1), xor2->getOperand(0));
61                }
62            }
63        } else if (isa<Integer>(expr1) || isa<String>(expr1) || isa<Call>(expr1)) {
64            // If these weren't equivalent by address they won't be equivalent by their operands.
65            return false;
66        } else { // Non-reassociatable functions (i.e., Sel, Advance, ScanThru, MatchStar, Assign, Next)
67            const Statement * stmt1 = cast<Statement>(expr1);
68            const Statement * stmt2 = cast<Statement>(expr2);
69            assert (stmt1->getNumOperands() == stmt2->getNumOperands());
70            for (unsigned i = 0; i != stmt1->getNumOperands(); ++i) {
71                if (!equals(stmt1->getOperand(i), stmt2->getOperand(i))) {
72                    return false;
73                }
74            }
75            return true;
76        }
77    }
78    return false;
79}
80
81/** ------------------------------------------------------------------------------------------------------------- *
82 * @brief replaceAllUsesWith
83 ** ------------------------------------------------------------------------------------------------------------- */
84void PabloAST::replaceAllUsesWith(PabloAST * expr) {
85    assert (expr);
86    if (LLVM_UNLIKELY(this == expr)) {
87        return;
88    }
89    Statement * replacement[mUsers.size()];
90    Users::size_type replacements = 0;
91    PabloAST * last = nullptr;
92    for (PabloAST * user : mUsers) {
93        if (LLVM_UNLIKELY(user == expr || user == last)) {
94            continue;
95        } else if (isa<Statement>(user)) {
96            replacement[replacements++] = cast<Statement>(user);
97            last = user;
98        }
99    }
100    for (Users::size_type i = 0; i != replacements; ++i) {
101        replacement[i]->replaceUsesOfWith(this, expr);
102    }
103}
104
105/** ------------------------------------------------------------------------------------------------------------- *
106 * @brief checkEscapedValueList
107 ** ------------------------------------------------------------------------------------------------------------- */
108template <class ValueType, class ValueList>
109inline void Statement::checkEscapedValueList(Statement * branch, PabloAST * const from, PabloAST * const to, ValueList & list) {
110    if (LLVM_LIKELY(isa<ValueType>(from))) {
111        auto f = std::find(list.begin(), list.end(), cast<ValueType>(from));
112        if (LLVM_LIKELY(f != list.end())) {
113            if (LLVM_LIKELY(isa<ValueType>(to))) {
114                if (std::find(list.begin(), list.end(), cast<ValueType>(to)) == list.end()) {
115                    *f = cast<ValueType>(to);
116                    branch->addUser(to);
117                } else {
118                    list.erase(f);
119                }
120                branch->removeUser(from);
121                assert (std::find(list.begin(), list.end(), cast<ValueType>(to)) != list.end());
122                assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(to)) != branch->user_end());
123            } else {
124                list.erase(f);
125                branch->removeUser(from);
126            }
127        }               
128        assert (std::find(list.begin(), list.end(), cast<ValueType>(from)) == list.end());
129        assert (std::find(branch->user_begin(), branch->user_end(), cast<ValueType>(from)) == branch->user_end());
130    }
131}
132
133/** ------------------------------------------------------------------------------------------------------------- *
134 * @brief replaceUsesOfWith
135 ** ------------------------------------------------------------------------------------------------------------- */
136void Statement::replaceUsesOfWith(PabloAST * const from, PabloAST * const to) {
137    if (LLVM_UNLIKELY(from == to)) {
138        return;
139    }
140    for (unsigned i = 0; i != getNumOperands(); ++i) {
141       if (getOperand(i) == from) {
142           setOperand(i, to);
143       }
144    }
145    if (LLVM_UNLIKELY(isa<If>(this))) {
146        checkEscapedValueList<Assign>(this, from, to, cast<If>(this)->getDefined());
147    } else if (LLVM_UNLIKELY(isa<While>(this))) {
148        checkEscapedValueList<Next>(this, from, to, cast<While>(this)->getVariants());
149    }
150}
151
152/** ------------------------------------------------------------------------------------------------------------- *
153 * @brief setOperand
154 ** ------------------------------------------------------------------------------------------------------------- */
155void Statement::setOperand(const unsigned index, PabloAST * const value) {
156    assert ("No operand can be null!" && value);
157    assert (index < getNumOperands());
158    PabloAST * const prior = getOperand(index);
159    assert ("No operand can be null!" && prior);
160    if (LLVM_UNLIKELY(prior == value)) {
161        return;
162    }   
163    prior->removeUser(this);
164    mOperand[index] = value;
165    value->addUser(this);
166}
167
168/** ------------------------------------------------------------------------------------------------------------- *
169 * @brief insertBefore
170 ** ------------------------------------------------------------------------------------------------------------- */
171void Statement::insertBefore(Statement * const statement) {
172    if (LLVM_UNLIKELY(statement == this)) {
173        return;
174    }
175    else if (LLVM_UNLIKELY(statement == nullptr)) {
176        throw std::runtime_error("cannot insert before null statement!");
177    }
178    else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
179        throw std::runtime_error("statement is not contained in a pablo block!");
180    }
181    removeFromParent();
182    mParent = statement->mParent;
183    if (LLVM_UNLIKELY(mParent->mFirst == statement)) {
184        mParent->mFirst = this;
185    }
186    mNext = statement;
187    mPrev = statement->mPrev;
188    statement->mPrev = this;
189    if (LLVM_LIKELY(mPrev != nullptr)) {
190        mPrev->mNext = this;
191    }
192    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
193        PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
194        body->setParent(mParent);
195    }
196}
197
198/** ------------------------------------------------------------------------------------------------------------- *
199 * @brief insertAfter
200 ** ------------------------------------------------------------------------------------------------------------- */
201void Statement::insertAfter(Statement * const statement) {
202    if (LLVM_UNLIKELY(statement == this)) {
203        return;
204    } else if (LLVM_UNLIKELY(statement == nullptr)) {
205        throw std::runtime_error("cannot insert after null statement!");
206    } else if (LLVM_UNLIKELY(statement->mParent == nullptr)) {
207        throw std::runtime_error("statement is not contained in a pablo block!");
208    }
209    removeFromParent();
210    mParent = statement->mParent;
211    if (LLVM_UNLIKELY(mParent->mLast == statement)) {
212        mParent->mLast = this;
213    }
214    mPrev = statement;
215    mNext = statement->mNext;
216    statement->mNext = this;
217    if (LLVM_LIKELY(mNext != nullptr)) {
218        mNext->mPrev = this;
219    }
220    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
221        PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
222        body->setParent(mParent);
223    }
224}
225
226/** ------------------------------------------------------------------------------------------------------------- *
227 * @brief removeFromParent
228 ** ------------------------------------------------------------------------------------------------------------- */
229Statement * Statement::removeFromParent() {
230    Statement * next = mNext;
231    if (LLVM_LIKELY(mParent != nullptr)) {
232        if (LLVM_UNLIKELY(mParent->mFirst == this)) {
233            mParent->mFirst = mNext;
234        }
235        if (LLVM_UNLIKELY(mParent->mLast == this)) {
236            mParent->mLast = mPrev;
237        }
238        if (LLVM_UNLIKELY(mParent->mInsertionPoint == this)) {
239            mParent->mInsertionPoint = mPrev;
240        }
241        if (LLVM_LIKELY(mPrev != nullptr)) {
242            mPrev->mNext = mNext;
243        }
244        if (LLVM_LIKELY(mNext != nullptr)) {
245            mNext->mPrev = mPrev;
246        }
247        if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
248            PabloBlock * body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
249            body->setParent(nullptr);
250        }
251    }
252    mPrev = nullptr;
253    mNext = nullptr;
254    mParent = nullptr;
255    return next;
256}
257
258/** ------------------------------------------------------------------------------------------------------------- *
259 * @brief eraseFromParent
260 ** ------------------------------------------------------------------------------------------------------------- */
261Statement * Statement::eraseFromParent(const bool recursively) {
262    // remove this statement from its operands' users list
263    for (unsigned i = 0; i != mOperands; ++i) {
264        mOperand[i]->removeUser(this);
265    }
266    SmallVector<Statement *, 1> redundantBranches;
267    // If this is an If or While statement, we'll have to remove the statements within the
268    // body or we'll lose track of them.
269    if (LLVM_UNLIKELY(isa<If>(this) || isa<While>(this))) {
270        PabloBlock * const body = isa<If>(this) ? cast<If>(this)->getBody() : cast<While>(this)->getBody();
271        body->eraseFromParent(recursively);
272    } else if (LLVM_UNLIKELY(isa<Assign>(this))) {
273        for (PabloAST * use : mUsers) {
274            if (If * ifNode = dyn_cast<If>(use)) {
275                auto & defs = ifNode->getDefined();
276                auto f = std::find(defs.begin(), defs.end(), this);
277                if (LLVM_LIKELY(f != defs.end())) {
278                    this->removeUser(ifNode);
279                    ifNode->removeUser(this);
280                    defs.erase(f);
281                    if (LLVM_UNLIKELY(defs.empty())) {
282                        redundantBranches.push_back(ifNode);
283                    }
284                }
285            }
286        }
287    } else if (LLVM_UNLIKELY(isa<Next>(this))) {
288        for (PabloAST * use : mUsers) {
289            if (While * whileNode = dyn_cast<While>(use)) {
290                auto & vars = whileNode->getVariants();
291                auto f = std::find(vars.begin(), vars.end(), this);
292                if (LLVM_LIKELY(f != vars.end())) {
293                    this->removeUser(whileNode);
294                    whileNode->removeUser(this);
295                    vars.erase(f);
296                    if (LLVM_UNLIKELY(vars.empty())) {
297                        redundantBranches.push_back(whileNode);
298                    }
299                }
300            }
301        }
302    }
303
304    replaceAllUsesWith(PabloBlock::createZeroes());
305
306    if (recursively) {
307        for (unsigned i = 0; i != mOperands; ++i) {
308            PabloAST * const op = mOperand[i];
309            if (LLVM_LIKELY(isa<Statement>(op))) {
310                if (op->getNumUses() == 0) {
311                    cast<Statement>(op)->eraseFromParent(true);
312                }
313            }
314        }
315        if (LLVM_UNLIKELY(redundantBranches.size() != 0)) {
316            // By eliminating this redundant branch, we may inadvertantly delete the scope block this statement
317            // resides within. Return null if so.
318            bool eliminatedScope = false;
319            for (Statement * br : redundantBranches) {
320                const PabloBlock * const body = isa<If>(br) ? cast<If>(br)->getBody() : cast<While>(br)->getBody();
321                if (LLVM_UNLIKELY(body == getParent())) {
322                    eliminatedScope = true;
323                }
324                br->eraseFromParent(true);
325            }
326            if (eliminatedScope) {
327                return nullptr;
328            }
329        }
330    }
331    Statement * const next = removeFromParent();
332    mAllocator.deallocate(reinterpret_cast<Allocator::pointer>(this));
333    return next;
334}
335
336/** ------------------------------------------------------------------------------------------------------------- *
337 * @brief replaceWith
338 ** ------------------------------------------------------------------------------------------------------------- */
339Statement * Statement::replaceWith(PabloAST * const expr, const bool rename, const bool recursively) {
340    assert (expr);
341    if (LLVM_UNLIKELY(expr == this)) {
342        return getNextNode();
343    }
344    if (LLVM_LIKELY(rename && isa<Statement>(expr))) {
345        Statement * const stmt = cast<Statement>(expr);
346        if (getName()->isUserDefined() && stmt->getName()->isGenerated()) {
347            stmt->setName(getName());
348        }
349    }
350    replaceAllUsesWith(expr);   
351    return eraseFromParent(recursively);
352}
353
354/** ------------------------------------------------------------------------------------------------------------- *
355 * @brief addOperand
356 ** ------------------------------------------------------------------------------------------------------------- */
357void Variadic::addOperand(PabloAST * const expr) {
358    if (LLVM_UNLIKELY(mOperands == mCapacity)) {
359        mCapacity = std::max<unsigned>(mCapacity * 2, 2);
360        PabloAST ** expandedOperandSpace = reinterpret_cast<PabloAST**>(mAllocator.allocate(mCapacity * sizeof(PabloAST *)));
361        for (unsigned i = 0; i != mOperands; ++i) {
362            expandedOperandSpace[i] = mOperand[i];
363        }
364        mAllocator.deallocate(reinterpret_cast<Allocator::pointer>(mOperand));
365        mOperand = expandedOperandSpace;
366    }
367    mOperand[mOperands++] = expr;
368    expr->addUser(this);
369}
370
371/** ------------------------------------------------------------------------------------------------------------- *
372 * @brief removeOperand
373 ** ------------------------------------------------------------------------------------------------------------- */
374PabloAST * Variadic::removeOperand(const unsigned index) {
375    assert (index < mOperands);
376    PabloAST * const expr = mOperand[index];
377    --mOperands;
378    for (unsigned i = index; i != mOperands; ++i) {
379        mOperand[i] = mOperand[i + 1];
380    }
381    mOperand[mOperands] = nullptr;
382    expr->removeUser(this);
383    return expr;
384}
385
386/** ------------------------------------------------------------------------------------------------------------- *
387 * @brief contains
388 ** ------------------------------------------------------------------------------------------------------------- */
389bool StatementList::contains(Statement * const statement) {
390    for (Statement * stmt : *this) {
391        if (statement == stmt) {
392            return true;
393        }
394    }
395    return false;
396}
397
398/** ------------------------------------------------------------------------------------------------------------- *
399 * @brief escapes
400 *
401 * Is this statement used outside of its scope?
402 ** ------------------------------------------------------------------------------------------------------------- */
403bool escapes(const Statement * statement) {
404    const PabloBlock * const parent = statement->getParent();
405    for (const PabloAST * user : statement->users()) {
406        if (LLVM_LIKELY(isa<Statement>(user))) {
407            const PabloBlock * used = cast<Statement>(user)->getParent();
408            while (used != parent) {
409                used = used->getParent();
410                if (used == nullptr) {
411                    assert (isa<Assign>(statement) || isa<Next>(statement));
412                    return true;
413                }
414            }
415        }
416    }
417    return false;
418}
419
420}
Note: See TracBrowser for help on using the repository browser.