source: proto/Compiler/pablo_util.py @ 4619

Last change on this file since 4619 was 3571, checked in by nmedfort, 5 years ago

start of error rewriting work. some clean up done to pablo.py; a few classes in it were moved to pablo_util.py.

File size: 5.7 KB
Line 
1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3
4import ast
5import mkast
6from carryInfo import *
7
8#
9# Transform any augmented assignments into standard assignments (e.g., a |= b -> a = a | b)
10#
11class AugAssignRemoval(ast.NodeTransformer):
12
13    def xfrm(self, t):
14        return self.generic_visit(t)
15
16    def visit_AugAssign(self, e):
17        self.generic_visit(e)
18        return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
19
20#
21# Removes any atEOF or inFile expressions from the ast
22#
23class RewriteEOF(ast.NodeTransformer): 
24    def __init__(self, EOFMaskName = 'EOF_mask'):       
25        self._EOFMaskName = EOFMaskName
26
27    def xfrm(self, root, FinalBlockMode=False):
28        self._finalBlockMode = FinalBlockMode
29        return self.generic_visit(root)
30               
31    def visit_Call(self, node): 
32        self.generic_visit(node)
33        if is_BuiltIn_Call(node, 'atEOF', 1):
34            if (self._finalBlockMode):
35                mask1 = mkast.call('bitblock::slli<1>', [mkast.call('simd_not', [ast.Name(self._EOFMaskName, ast.Load())])])
36                node = mkast.call('simd_andc', [mkast.call('simd_andc', [node.args[0], ast.Name(self._EOFMaskName, ast.Load())]), mask1])
37            else:
38                node = mkast.call('simd<1>::constant<0>', [])
39        elif is_BuiltIn_Call(node, 'inFile', 1):
40            if (self._finalBlockMode):
41                node = mkast.call('simd_and', [node.args[0], ast.Name(self._EOFMaskName, ast.Load())])           
42            else:
43                node = node.args[0] 
44        return node
45       
46class TempifyBuiltins(ast.NodeTransformer):
47
48    def __init__(self, tempVarpfx='tempvar'):
49        self.tempVarCount = 0
50        self.newTempList = []
51        self.tempVarPrefix = tempVarpfx
52
53    def genVar(self):
54        newTemp = self.tempVarPrefix + repr(self.tempVarCount)
55        self.newTempList.append(newTemp)
56        self.tempVarCount += 1
57        return newTemp
58
59    def tempVars(self):
60        return self.newTempList
61
62    def xfrm(self, t):
63        self.setUpStmts = []
64        self.assigNode = None
65        return self.generic_visit(t)
66
67    def is_Assign_value(self, node):
68        return self.assigNode != None and self.assigNode.value == node
69
70    def visit_If(self, ifNode):
71        self.setUpStmts = []
72        self.generic_visit(ifNode.test)
73        ifSetUpStmts = self.setUpStmts
74        self.generic_visit(ifNode)
75        if ifSetUpStmts == []:
76            return ifNode
77        else:
78            return ifSetUpStmts + [ifNode]
79
80    def visit_While(self, whileNode):
81        self.setUpStmts = []
82        self.generic_visit(whileNode.test)
83        whileSetUpStmts = self.setUpStmts
84        self.generic_visit(whileNode)
85        whileNode.body = whileNode.body + whileSetUpStmts
86        return whileSetUpStmts + [whileNode]
87
88    def visit_Assign(self, node):
89        self.assigNode = node
90        self.setUpStmts = []
91        self.generic_visit(node)
92        return self.setUpStmts + [node]
93
94    def visit_AugAssign(self, node):
95        self.setUpStmts = []
96        self.generic_visit(node)
97        return self.setUpStmts + [node]
98
99    def visit_Call(self, callnode):
100        self.generic_visit(callnode)
101        if CheckForBuiltin(callnode) and not self.is_Assign_value(callnode):
102            tempVar = ast.Name(self.genVar(), ast.Load())
103            self.setUpStmts.append(ast.Assign([tempVar], callnode))
104            return tempVar
105        else:
106            return callnode
107
108           
109           
110           
111def is_simd_not(e):
112    return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
113               
114
115               
116class Bitwise_to_SIMD(ast.NodeTransformer):
117
118    """
119  Make the following substitutions:
120     x & y => simd_and(x, y)
121     x & ~y => simd_andc(x, y)
122     x | y => simd_or(x, y)
123     x ^ y => simd_xor(x, y)
124     ~x    => simd_not(x)
125     0     => simd_const_1(0)
126     -1    => simd_const_1(1)
127     if x: => if bitblock::any(x):
128  while x: => while bitblock::any(x):
129  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse(\"pfx = bit0 & bit1; sfx = bit0 &~ bit1\")))
130 
131  pfx = simd_and(bit0, bit1)
132  sfx = simd_and(bit0, simd_not(bit1))
133  >>>
134  """
135 
136    def xfrm(self, t):
137        return self.generic_visit(t)
138
139    def visit_UnaryOp(self, t):
140        self.generic_visit(t)
141        if isinstance(t.op, ast.Invert):
142            return mkast.call('simd_not', [t.operand])
143        else:
144            return t
145
146    def visit_BinOp(self, t):
147        self.generic_visit(t)
148        if isinstance(t.op, ast.BitOr):
149            return mkast.call('simd_or', [t.left, t.right])
150        elif isinstance(t.op, ast.BitAnd):
151            if is_simd_not(t.right):
152                return mkast.call('simd_andc', [t.left, t.right.args[0]])
153            elif is_simd_not(t.left):
154                return mkast.call('simd_andc', [t.right, t.left.args[0]])
155            else:
156                return mkast.call('simd_and', [t.left, t.right])
157        elif isinstance(t.op, ast.BitXor):
158            return mkast.call('simd_xor', [t.left, t.right])
159        else:
160            return t
161
162    def visit_Num(self, numnode):
163        n = numnode.n
164        if n == 0:
165            return mkast.call('simd<1>::constant<0>', [])
166        elif n == -1:
167            return mkast.call('simd<1>::constant<1>', [])
168        else:
169            return numnode
170
171    def visit_If(self, ifNode):
172        self.generic_visit(ifNode)
173        ifNode.test = mkast.call('bitblock::any', [ifNode.test])
174        return ifNode
175
176    def visit_While(self, whileNode):
177        self.generic_visit(whileNode)
178        whileNode.test = mkast.call('bitblock::any', [whileNode.test])
179        return whileNode
180
181    def visit_Subscript(self, numnode):
182        return numnode  # no recursive modifications of index expressions
Note: See TracBrowser for help on using the repository browser.