source: proto/parabix2/Compiler/bitstream_compiler.py @ 301

Last change on this file since 301 was 301, checked in by eamiri, 10 years ago

simple_op supports any number of variables

File size: 6.8 KB
Line 
1#
2import ast, bitexpr, py2bitexpr
3
4#
5#
6#  Code Generation
7#
8
9def pairs(lst):
10        if lst == []:
11                return []
12        return zip(lst,lst[1:]+[lst[0]])
13
14Nop = 'nop'
15Andc = 'simd_andc'
16And = 'simd_and'
17Or = 'simd_or'
18Xor = 'simd_xor'
19Not = 'not'
20Sel = 'simd_if'
21Add = 'simd_add'
22Sub = 'simd_sub'
23
24AllOne = 'allone'
25AllZero = 'allzero'
26
27class simple_op:
28        def __init__(self, *arg):
29                self.op = arg[0]
30                self.vars = list(arg[1:])
31               
32        def show(self):
33                if self.op == Nop:
34                        return "%s\n"%self.vars[0]
35                else:
36                        line_of_code = self.op + "("
37                       
38                        for index, var in enumerate(self.vars):
39                                line_of_code += str(var)
40                                if index+1 < len(self.vars):
41                                        line_of_code += ', '
42                       
43                        line_of_code += ')\n'
44                       
45                        return line_of_code
46       
47        def update_var(self, old, new):
48                while old in self.vars:
49                        i = self.vars.index(old)
50                        self.vars[i] = new
51
52class CodeGenObject:
53    def __init__(self, predeclared):
54        self.gensym_template = 'temp%i'
55        self.gensym_counter = 0
56        self.generated_code = []
57        self.common_expression_map = {}
58        for sym in predeclared: self.common_expression_map[sym] = sym
59    def add_stmt(self, varname, expr):
60        self.common_expression_map[expr.show()] = varname
61        self.generated_code.append(bitexpr.BitAssign(varname, expr))
62    def expr_string_to_variable(self, expr_string):
63        if self.common_expression_map.has_key(expr_string.show()): 
64            return self.common_expression_map[expr_string.show()]
65        else:
66            self.gensym_counter += 1
67            sym = self.gensym_template % self.gensym_counter
68            self.add_stmt(sym, expr_string)
69            return sym
70    def showcode(self, line_no = False):
71        s = ''
72        for index, stmt in enumerate(self.generated_code): 
73                if line_no:
74                        s+= "%i %s"%(index, stmt.show())
75                else:
76                        s += stmt.show()
77        return s
78
79    def get_code(self):
80        return self.generated_code
81
82def expr2simd(genobj, expr):
83    """Translate a Boolean expression into three-address simd code
84       using code generator object genobj.
85    """
86    if isinstance(expr, bitexpr.TrueLiteral): return simple_op(Nop, AllOne)
87    elif isinstance(expr, bitexpr.FalseLiteral): return simple_op(Nop, AllZero)
88    elif isinstance(expr, bitexpr.Var): 
89        v = simple_op(Nop, expr.varname)
90        genobj.common_expression_map[v.show()] = expr.varname
91        return v
92    elif isinstance(expr, bitexpr.Not):
93       e = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand))
94       return simple_op(Andc, AllOne, e)
95    elif isinstance(expr, bitexpr.Or):
96       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
97       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
98       return simple_op(Or, e1, e2)
99    elif isinstance(expr, bitexpr.Xor):
100       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
101       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
102       return simple_op(Xor, e1, e2)
103    elif isinstance(expr, bitexpr.And):
104       if isinstance(expr.operand1, bitexpr.Not):
105           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1.operand))
106           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
107           return simple_op(Andc, e2, e1)
108       elif isinstance(expr.operand2, bitexpr.Not):
109           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
110           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2.operand))
111           return simple_op(Andc, e1, e2)
112       else:
113           e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
114           e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
115           return simple_op(And, e1, e2)
116    elif isinstance(expr, bitexpr.Sel):
117       sel = genobj.expr_string_to_variable(expr2simd(genobj, expr.sel))
118       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.true_branch))
119       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.false_branch))
120       return simple_op(Sel, sel, e1, e2)
121    elif isinstance(expr, bitexpr.Add):
122       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
123       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
124       return simple_op(Add, e1, e2)
125    elif isinstance(expr, bitexpr.Sub):
126       e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
127       e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
128       return simple_op(Sub, e1, e2)
129
130def pybit_codegen(cgo, stmts):
131        for s in stmts:
132                e = expr2simd(cgo, s.RHS)
133                cgo.add_stmt(s.LHS, e)
134
135def gen_sym_table(code):
136        """Generate a simple symbol table for a three address code
137           each entry is of this form: var_name:[[defs][uses]]
138        """
139        table = {}
140        for index, stmt in enumerate(code):
141                if stmt.LHS in table:
142                        table[stmt.LHS][0].append(index)
143                else:
144                        table[stmt.LHS] = [[index],[]]
145               
146                for var in stmt.RHS.vars:
147                        if var in table:
148                                table[var][1].append(index)
149                        else:
150                                table[var] = [[], [index]]
151                               
152        return table
153
154def make_SSA(code, st):
155        total_lines = len(code)
156        for var in st:
157                st[var][0].append(total_lines)
158                st[var][1].append(total_lines)
159       
160        for var in st:
161                use_index = 0
162                def_index = 1
163                for current, next in pairs(st[var][0])[1:-1]:
164                        code[current].LHS = "%s_%i"%(var, current)
165                        uline = st[var][1][use_index]
166                        while uline <= next and uline < total_lines:
167                                if uline > current:
168                                        code[uline].RHS.update_var(var, "%s_%i"%(var, current))
169                                use_index += 1
170                                uline = st[var][1][use_index]
171
172if __name__ == '__main__':
173        s=ast.parse(r"""def u8_streams(u8bit):
174        Ref2 = bitutil.Advance(lex.RefStart &~ CtCDPI_mask)
175        NumRef2 = Ref2 & lex.Hash
176        GenRef2 = Ref2 & ~lex.Hash
177        NumRef3 = bitutil.Advance(NumRef2)
178        HexRef3 = NumRef3 & lex.x
179        DecRef3 = NumRef3 &~ lex.x
180        HexRef4 = bitutil.Advance(HexRef3)
181        GenRefEnds = bitutil.ScanThru(GenRef2, lex.NameScan)
182        DecRefEnds = bitutil.ScanThru(DecRef3, lex.Digit)
183        HexRefEnds = bitutil.ScanThru(HexRef4, lex.Hex)
184        # Error checks
185        # At least one digit required for DecRef, one hex digit for HexRef.
186        Error = DecRef3 &~ lex.Digit
187        Error |= HexRef4 &~ lex.Hex
188        # Semicolon terminator required (also covers unterminated at EOF).
189        Error |= (GenRefEnds | DecRefEnds | HexRefEnds) &~ lex.Semicolon
190        CallOuts.GenRefs = GenRefEnds - GenRef2
191        CallOuts.DecRefs = DecRefEnds - DecRef3
192        CallOuts.HexRefs = HexRefEnds - HexRef4
193        # Mark references for deletion, but leave the trailing semicolon as
194        # the point for insertion of the "expansion" text (most often a
195        # single character).
196        CallOuts.delmask = (GenRefEnds | DecRefEnds | HexRefEnds) - lex.RefStart
197        CallOuts.error = Error
198""")
199        s=s.body[0].body
200
201        c=CodeGenObject([])
202
203        pybit_codegen(c, py2bitexpr.translate_stmts(s))
204        #print c.showcode(True)
205        code = c.generated_code
206        make_SSA(code, gen_sym_table(code))
207        print c.showcode(True)
Note: See TracBrowser for help on using the repository browser.