source: proto/parabix2/Compiler/bitstream_compiler.py @ 299

Last change on this file since 299 was 299, checked in by cameron, 10 years ago

Avoid generating trivial new variables.

File size: 7.3 KB
Line 
1#
2import ast, bitexpr, py2bitexpr
3
4#
5#
6#  Code Generation
7#
8
9Nop = 'nop'
10Andc = 'simd_andc'
11And = 'simd_and'
12Or = 'simd_or'
13Xor = 'simd_xor'
14Not = 'not'
15Add = 'simd_add'
16Sub = 'simd_sub'
17
18AllOne = 'allone'
19AllZero = 'allzero'
20
21class simple_op:
22        def __init__(self, *arg):
23                self.op = arg[0]
24                self.first = arg[1]
25                if len(arg) > 2:
26                        self.second = arg[2]
27                else:
28                        assert(self.op==Nop)
29                        self.second = None
30
31        def show(self):
32                if self.op == Nop:
33                        return "%s\n"%self.first
34                else:
35                        return "%s(%s, %s)\n"%(self.op, self.first, self.second)
36        def update_var(self, old, new):
37                if self.first == old:
38                        self.first = new
39                if self.second == old:
40                        self.second = new
41
42class CodeGenObject:
43    def __init__(self, predeclared):
44        self.gensym_template = 'temp%i'
45        self.gensym_counter = 0
46        self.generated_code = []
47        self.common_expression_map = {}
48        for sym in predeclared: self.common_expression_map[sym] = sym
49    def add_stmt(self, varname, expr):
50        self.common_expression_map[expr.show()] = varname
51        self.generated_code.append(bitexpr.BitAssign(varname, expr))
52    def expr_string_to_variable(self, expr_string):
53        if self.common_expression_map.has_key(expr_string.show()): 
54            return self.common_expression_map[expr_string.show()]
55        else:
56            self.gensym_counter += 1
57            sym = self.gensym_template % self.gensym_counter
58            self.add_stmt(sym, expr_string)
59            return sym
60    def showcode(self, line_no = False):
61        s = ''
62        for index, stmt in enumerate(self.generated_code): 
63                if line_no:
64                        s+= "%i %s"%(index, stmt.show())
65                else:
66                        s += stmt.show()
67        return s
68
69    def get_code(self):
70        return self.generated_code
71
72def expr2simd(genobj, expr):
73    """Translate a Boolean expression into three-address simd code
74       using code generator object genobj.
75    """
76    if isinstance(expr, bitexpr.TrueLiteral): return simple_op(Nop, AllOne)
77    elif isinstance(expr, bitexpr.FalseLiteral): return simple_op(Nop, AllZero)
78    elif isinstance(expr, bitexpr.Var): 
79        v = simple_op(Nop, expr.varname)
80        genobj.common_expression_map[v.show()] = expr.varname
81        return v
82    elif isinstance(expr, bitexpr.Not):
83       e = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand))
84       return simple_op(Andc, AllOne, e)
85    elif isinstance(expr, bitexpr.Or):
86       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
87       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
88       return simple_op(Or, e1, e2)
89    elif isinstance(expr, bitexpr.Xor):
90       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
91       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
92       return simple_op(Xor, e1, e2)
93    elif isinstance(expr, bitexpr.And):
94       if isinstance(expr.operand1, bitexpr.Not):
95           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1.operand))
96           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
97           return simple_op(Andc, e2, e1)
98       elif isinstance(expr.operand2, bitexpr.Not):
99           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
100           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2.operand))
101           return simple_op(Andc, e1, e2)
102       else:
103           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
104           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
105           return simple_op(And, e1, e2)
106    elif isinstance(expr, bitexpr.Sel):
107       sel = genobj.expr_string_to_variable(expr2simd(genobj, expr.sel))
108       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.true_branch))
109       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.false_branch))
110       return None
111       #return 'simd_if(%s, %s, %s)' %(sel, e1, e2)
112       ## TODO; Do something for this. It should be removed
113    elif isinstance(expr, bitexpr.Add):
114       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
115       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
116       return simple_op(Add, e1, e2)
117    elif isinstance(expr, bitexpr.Sub):
118       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
119       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
120       return simple_op(Sub, e1, e2)
121
122def pybit_codegen(cgo, stmts):
123        for s in stmts:
124                e = expr2simd(cgo, s.RHS)
125                cgo.add_stmt(s.LHS, e)
126
127def gen_sym_table(code):
128        """Generate a simple symbol table for a three address code
129           each entry is of this form: var_name:[[defs][uses]]
130        """
131        table = {}
132        for index, stmt in enumerate(code):
133                if stmt.LHS in table:
134                        table[stmt.LHS][0].append(index)
135                else:
136                        table[stmt.LHS] = [[index],[]]
137               
138                if stmt.RHS.first in table:
139                        table[stmt.RHS.first][1].append(index)
140                else:
141                        table[stmt.RHS.first] = [[], [index]]
142               
143                if stmt.RHS.second is None or stmt.RHS.second==stmt.RHS.first:
144                        continue
145               
146                if stmt.RHS.second in table:
147                        table[stmt.RHS.second][1].append(index)
148                else:
149                        table[stmt.RHS.second] = [[], [index]]
150
151        return table
152
153
154def pairs(lst):
155        if lst == []:
156                return []
157        return zip(lst,lst[1:]+[lst[0]])
158
159
160def make_SSA(code, st):
161        total_lines = len(code)
162        for var in st:
163                st[var][0].append(total_lines)
164                st[var][1].append(total_lines)
165       
166        for var in st:
167                use_index = 0
168                def_index = 1
169                for current, next in pairs(st[var][0])[1:-1]:
170                        code[current].LHS = "%s_%i"%(var, current)
171                        uline = st[var][1][use_index]
172                        while uline <= next and uline < total_lines:
173                                if uline > current:
174                                        code[uline].RHS.update_var(var, "%s_%i"%(var, current))
175                                use_index += 1
176                                uline = st[var][1][use_index]
177
178
179
180
181def copy_propagation(code, st):
182        """Assumes the code is in SSA form"""
183        """TODO: THIS IS NOT COMPLETE"""
184        dic = {}
185        for stmt in code:
186                if stmt.RHS.op == Nop:
187                        dic[stmt.LHS] = stmt.RHS.first
188       
189        for i in dic:
190                if not dic[i] in dic:
191                        continue
192                while dic[i] in dic:
193                        i = dic[i]
194                dic[i] = i
195
196        return dic
197
198
199
200
201if __name__ == '__main__':
202        s=ast.parse(r"""def u8_streams(u8bit):
203        Ref2 = bitutil.Advance(lex.RefStart &~ CtCDPI_mask)
204        NumRef2 = Ref2 & lex.Hash
205        GenRef2 = Ref2 & ~lex.Hash
206        NumRef3 = bitutil.Advance(NumRef2)
207        HexRef3 = NumRef3 & lex.x
208        DecRef3 = NumRef3 &~ lex.x
209        HexRef4 = bitutil.Advance(HexRef3)
210        GenRefEnds = bitutil.ScanThru(GenRef2, lex.NameScan)
211        DecRefEnds = bitutil.ScanThru(DecRef3, lex.Digit)
212        HexRefEnds = bitutil.ScanThru(HexRef4, lex.Hex)
213        # Error checks
214        # At least one digit required for DecRef, one hex digit for HexRef.
215        Error = DecRef3 &~ lex.Digit
216        Error |= HexRef4 &~ lex.Hex
217        # Semicolon terminator required (also covers unterminated at EOF).
218        Error |= (GenRefEnds | DecRefEnds | HexRefEnds) &~ lex.Semicolon
219        CallOuts.GenRefs = GenRefEnds - GenRef2
220        CallOuts.DecRefs = DecRefEnds - DecRef3
221        CallOuts.HexRefs = HexRefEnds - HexRef4
222        # Mark references for deletion, but leave the trailing semicolon as
223        # the point for insertion of the "expansion" text (most often a
224        # single character).
225        CallOuts.delmask = (GenRefEnds | DecRefEnds | HexRefEnds) - lex.RefStart
226        CallOuts.error = Error
227""")
228        s=s.body[0].body
229
230        c=CodeGenObject([])
231
232        pybit_codegen(c, py2bitexpr.translate_stmts(s))
233        #print c.showcode(True)
234        code = c.generated_code
235        make_SSA(code, gen_sym_table(code))
236        print c.showcode(True)
Note: See TracBrowser for help on using the repository browser.