source: proto/Compiler/pablo.py @ 865

Last change on this file since 865 was 865, checked in by cameron, 9 years ago

Add @global declaration template facility; use for struct types.

File size: 14.3 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, 2011, Robert D. Cameron
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11
12# HELPER functions for AST recognition/construction
13
def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='bitutil'):
  """
  Return True if the ast.Call fncall is a plain call of builtin_fnname.

  Both the bare form `fn(...)` and the module-qualified form
  `<builtin_fnmod_noprefix>.fn(...)` are recognised.  The call must pass
  exactly builtin_arg_cnt positional arguments and no *args / **kwargs.
  Any other callee expression (e.g. `f()(x)` or `a.b.fn(x)`) is reported
  as not a builtin call.

  Works with the Python 2 AST (Call.starargs / Call.kwargs fields) and
  the Python 3.5+ AST (ast.Starred positional args, keywords with arg=None).
  """
  func = fncall.func
  if isinstance(func, ast.Name):
    iscall = func.id == builtin_fnname
  elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
    iscall = func.value.id == builtin_fnmod_noprefix and func.attr == builtin_fnname
  else:
    # No other callee shape can name a builtin.
    return False
  if not iscall or len(fncall.args) != builtin_arg_cnt:
    return False
  # Python 2 AST: explicit starargs / kwargs fields.
  if getattr(fncall, 'starargs', None) is not None or getattr(fncall, 'kwargs', None) is not None:
    return False
  # Python 3.5+ AST: *args appear as Starred in args, **kwargs as a keyword with arg None.
  starred = getattr(ast, 'Starred', ())
  if any(isinstance(a, starred) for a in fncall.args):
    return False
  return not any(getattr(k, 'arg', '') is None for k in getattr(fncall, 'keywords', []))
19
def is_simd_not(e):
  """Return True if e is a call whose callee is the bare name `simd_not`."""
  if not isinstance(e, ast.Call):
    return False
  callee = e.func
  return isinstance(callee, ast.Name) and callee.id == 'simd_not'
22
def mkQname(obj, field):
  """Build the AST for the qualified load `obj.field` from two name strings."""
  base = ast.Name(obj, ast.Load())
  return ast.Attribute(base, field, ast.Load())
25
def mkCall(fn_name, args):
  """
  Build an ast.Call node that calls fn_name with the positional args.

  fn_name is either a function name string (wrapped in a loaded ast.Name)
  or an already-built callee expression node.  No keyword, *args or
  **kwargs arguments are attached.

  The node is built with keyword fields so it is valid for both the
  Python 2 AST (Call has starargs/kwargs fields) and the Python 3.5+ AST
  (those fields no longer exist, and the old five-positional-field
  constructor call raises TypeError).
  """
  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
  call = ast.Call(func=fn_name, args=args, keywords=[])
  for legacy_field in ('starargs', 'kwargs'):
    if legacy_field in ast.Call._fields:
      setattr(call, legacy_field, None)
  return call
29
def mkCallStmt(fn_name, args):
  """
  Build an expression statement that calls fn_name with the positional args.

  fn_name is a name string or a callee expression node, as for mkCall.
  The Call node itself is built by mkCall, so both helpers always produce
  identical calls.
  """
  return ast.Expr(mkCall(fn_name, args))
33
34#
35# Reducing AugAssign, e.g.  x |= y becomes x = x | y
36#
class AugAssignRemoval(ast.NodeTransformer):
  """
  Rewrite augmented assignments as plain ones: `x op= y` becomes `x = x op y`.

  The left operand of the new BinOp is a separate copy of the target with
  Load context.  Reusing the Store-context target node there would alias
  one node in two places and produce an AST Python refuses to compile.
  """
  def xfrm(self, t):
    """Rewrite the augmented assignments inside t in place and return t."""
    return self.generic_visit(t)
  def visit_AugAssign(self, e):
    self.generic_visit(e)
    operand = copy.deepcopy(e.target)
    operand.ctx = ast.Load()
    assign = ast.Assign([e.target], ast.BinOp(operand, e.op, e.value))
    return ast.copy_location(assign, e)
43
44#
45# Introducing BitBlock logical operations
46#
class Bitwise_to_SIMD(ast.NodeTransformer):
  """
  Replace Python bitwise operations on bitstreams with SIMD library calls.

  Substitutions made (children are rewritten before their parents):
     x & y     => simd_and(x, y)
     x & ~y    => simd_andc(x, y)
     ~x & y    => simd_andc(y, x)
     x | y     => simd_or(x, y)
     x ^ y     => simd_xor(x, y)
     ~x        => simd_not(x)
     0         => simd_const_1(0)
     -1        => simd_const_1(1)
     if x:     => if bitblock_has_bit(x):
     while x:  => while bitblock_has_bit(x):
  Subscript expressions are left completely unchanged, as are all other
  operators and numeric literals.

  Example, written as Python source (not a doctest):
     pfx = bit0 & bit1; sfx = bit0 &~ bit1
  becomes
     pfx = simd_and(bit0, bit1)
     sfx = simd_andc(bit0, bit1)

  Literals are recognised in the Python 2 AST (the parser folds -1 into
  Num(-1)) and the Python 3 AST (-1 is UnaryOp(USub, 1); from 3.8 numbers
  are ast.Constant).
  """
  def xfrm(self, t):
    return self.generic_visit(t)
  @staticmethod
  def _number_value(node):
    # Value of a numeric literal node, or None for any other node.
    # Python 3.8+ parses numbers as ast.Constant; older parsers use ast.Num.
    constant_type = getattr(ast, 'Constant', None)
    if constant_type is not None and isinstance(node, constant_type):
      value = node.value
      if type(value) in (int, float, complex): return value
      return None
    if type(node).__name__ == 'Num': return node.n
    return None
  def visit_UnaryOp(self, t):
    if isinstance(t.op, ast.USub) and self._number_value(t.operand) in (0, 1):
      # Python 3 form of the -1 (and -0) literal.  -1 => simd_const_1(1) and
      # -0 => simd_const_1(0): the operand already holds that constant.
      return mkCall('simd_const_1', [t.operand])
    self.generic_visit(t)
    if isinstance(t.op, ast.Invert):
      return mkCall('simd_not', [t.operand])
    else: return t
  def visit_BinOp(self, t):
    self.generic_visit(t)
    if isinstance(t.op, ast.BitOr):
      return mkCall('simd_or', [t.left, t.right])
    elif isinstance(t.op, ast.BitAnd):
      # ~ has already been rewritten to simd_not(...) by the time we get here.
      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
      else: return mkCall('simd_and', [t.left, t.right])
    elif isinstance(t.op, ast.BitXor):
      return mkCall('simd_xor', [t.left, t.right])
    else: return t
  def visit_Num(self, numnode):
    # Python 2 / pre-3.8 numeric literals.
    n = numnode.n
    if n == 0: return mkCall('simd_const_1', [numnode])
    elif n == -1: return mkCall('simd_const_1', [ast.Num(1)])
    else: return numnode
  def visit_Constant(self, node):
    # Python 3.8+ counterpart of visit_Num; non-numeric constants pass through.
    n = self._number_value(node)
    if n is None: return node
    if n == 0: return mkCall('simd_const_1', [node])
    elif n == -1: return mkCall('simd_const_1', [ast.Constant(1)])
    else: return node
  def visit_If(self, ifNode):
    self.generic_visit(ifNode)
    ifNode.test = mkCall('bitblock_has_bit', [ifNode.test])
    return ifNode
  def visit_While(self, whileNode):
    self.generic_visit(whileNode)
    whileNode.test = mkCall('bitblock_has_bit', [whileNode.test])
    return whileNode
  def visit_Subscript(self, node):
    return node  # no recursive modifications of index expressions
98
99
100#
101#  Generating BitBlock declarations for Local Variables
102#
103#
class LocalVars(ast.NodeVisitor):
  """
  Collect the names a function assigns to, excluding its parameters.

  get(fndef) returns the stored names in first-assignment order, without
  duplicates, minus every parameter name.  Parameters are recognised in
  the Python 2 AST (ast.Name with Param context) and the Python 3 AST
  (ast.arg nodes).
  """
  def visit_Name(self, nm):
    param_ctx = getattr(ast, 'Param', ())
    if isinstance(nm.ctx, param_ctx):
      self.params.append(nm.id)
    elif isinstance(nm.ctx, ast.Store) and nm.id not in self._seen:
      self._seen.add(nm.id)
      self.stores.append(nm.id)
  def visit_arg(self, node):
    # Python 3 parameter node.
    self.params.append(node.arg)
    self.generic_visit(node)
  def get(self, node):
    self.params = []
    self.stores = []
    self._seen = set()
    self.generic_visit(node)
    params = set(self.params)
    return [v for v in self.stores if v not in params]
115
MAX_LINE_LENGTH = 80

def BitBlock_decls_from_vars(varlist, max_line_length=None):
  """
  Return C declarations of a BitBlock variable for each name in varlist.

  Names are packed into comma-separated "  BitBlock a, b, c;" lines.  A new
  "  BitBlock" declaration is started when the current line plus ", " and
  the next name would exceed max_line_length (default: MAX_LINE_LENGTH,
  read at call time).  A name is never wrapped onto a new line if the
  current declaration has no names yet, so no empty "BitBlock;" is emitted.
  Every declaration ends with ";\n".  An empty varlist yields "".
  """
  if max_line_length is None: max_line_length = MAX_LINE_LENGTH
  varlist = list(varlist)
  if not varlist: return ""
  decls = "  BitBlock"
  pending = ""
  linelgth = len(decls)
  for v in varlist:
    # The +2 is the width of the ", " separator placed before v.
    if not pending or linelgth + len(v) + 2 <= max_line_length:
      decls += pending + " " + v
      linelgth += len(pending) + 1 + len(v)
    else:
      decls += ";\n  BitBlock " + v
      linelgth = len("  BitBlock ") + len(v)
    pending = ","
  return decls + ";\n"
133
def BitBlock_decls_of_fn(fndef):
  """
  Return BitBlock declarations for the local variables of fndef.

  The locals are the names fndef assigns to, minus its parameters, as
  computed by LocalVars; formatting is done by BitBlock_decls_from_vars.
  """
  local_names = LocalVars().get(fndef)
  return BitBlock_decls_from_vars(local_names)
136
def BitBlock_header_of_fn(fndef):
  """
  Return the C function header (without body) for a stream function.

  Produces "static inline void NAME(P1 & p1, ..., CarryQtype & carryQ)".
  Each parameter p is passed by reference with type name P, which is p
  with its first letter upper-cased.  The trailing carry-queue parameter
  is added only when CarryCounter finds carry-producing operations in fndef.
  Parameters are read from the Python 2 AST (ast.Name) and the Python 3
  AST (ast.arg); other parameter forms are skipped.
  """
  params = []
  for arg in fndef.args.args:
    if isinstance(arg, ast.Name): name = arg.id
    elif isinstance(getattr(arg, 'arg', None), str): name = arg.arg
    else: continue
    params.append(name.upper()[0] + name[1:] + " & " + name)
  if CarryCounter().count(fndef) > 0:
    params.append("CarryQtype & carryQ")
  return "static inline void " + fndef.name + "(" + ", ".join(params) + ")"
148
149
150
151#
152#  Stream Initialization Statement Extraction
153#
154#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
class StreamInitializations(ast.NodeTransformer):
  """
  Pull nonzero numeric stream initializations out of a module AST.

  An assignment `v = n`, where n is a numeric literal other than 0 and -1,
  is removed from its statement list and replaced by two statements:
    * `v = sisd_from_int(n)`, collected in self.stream_stmts.  xfrm()
      returns these as C code via Cgen.py2C (MainLoopTransformer places it
      in its initializations, Cinits).
    * `v = 0`, collected in self.loop_post_inits and appended to the end
      of function bodies by visit_FunctionDef (presumably so the nonzero
      value applies only to the first block - confirm).
  Assignments of 0 and -1 are left in place.

  NOTE(review): loop_post_inits is never reset between functions, so each
  FunctionDef receives every post-init collected so far in the traversal
  (including module-level ones and those of earlier functions).  Confirm
  this is intended for modules with several functions.
  """
  def xfrm(self, node):
    self.stream_stmts, self.loop_post_inits = [], []
    self.generic_visit(node)
    return Cgen.py2C().gen(self.stream_stmts)
  def visit_Assign(self, node):
    literal = node.value
    if not isinstance(literal, ast.Num) or literal.n in (0, -1):
      return node
    initial = copy.deepcopy(node)
    initial.value = mkCall('sisd_from_int', [literal])
    cleared = copy.deepcopy(node)
    cleared.value.n = 0
    self.stream_stmts.append(initial)
    self.loop_post_inits.append(cleared)
    return None  # drop the original assignment from its statement list
  def visit_FunctionDef(self, node):
    self.generic_visit(node)
    node.body = node.body + self.loop_post_inits
    return node
178
179
180#
181# Carry Introduction Transformation
182#
183
class CarryCounter(ast.NodeVisitor):
  """
  Count the carry-producing operations below an AST node.

  Each Advance(x) call, ScanThru(x, y) call, `+` / `-` binary operation,
  and `+=` / `-=` augmented assignment counts as one carry.  Augmented
  assignments are counted because AugAssignRemoval turns them into the
  `+` / `-` BinOps that CarryIntro assigns carries to, and some callers
  (MainLoopTransformer.__init__) count before that rewrite has run.
  The node passed to count() is not itself counted, only its descendants.
  """
  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    if is_BuiltIn_Call(callnode, 'Advance', 1) or is_BuiltIn_Call(callnode, 'ScanThru', 2):
      self.carry_count += 1
  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, (ast.Add, ast.Sub)):
      self.carry_count += 1
  def visit_AugAssign(self, stmtnode):
    self.generic_visit(stmtnode)
    if isinstance(stmtnode.op, (ast.Add, ast.Sub)):
      self.carry_count += 1
  def count(self, nodeToVisit):
    self.carry_count = 0
    self.generic_visit(nodeToVisit)
    return self.carry_count
199
200
class CarryIntro(ast.NodeTransformer):
  """
  Thread an explicit carry queue through carry-producing operations.

  Carry-producing operations are numbered in traversal order, starting at 0
  for each xfrm_fndef / generic_xfrm call, and rewritten as:
     Advance(x)      => BitBlock_advance_<mode>(x, carryvar, i)
     ScanThru(x, y)  => BitBlock_scanthru_<mode>(x, y, carryvar, i)
     x + y           => BitBlock_add_<mode>(x, y, carryvar, i)
     x - y           => BitBlock_sub_<mode>(x, y, carryvar, i)
  StreamScan(x, f), where both arguments are plain names, becomes
     StreamScan((ScanBlock *) &x, sizeof(BitBlock)/sizeof(ScanBlock), f)
  and uses no carry.  If statements and while loops that contain carries
  get extra carry-queue bookkeeping; see visit_If and visit_While.

  carryvar -- name of the carry queue variable.
  mode     -- suffix of the generated BitBlock_* routine names.  In "co"
              mode (used for the loop copies made by visit_While),
              visit_If adds no bookkeeping and visit_While adds no
              CarryTest to its test.
  """
  def __init__(self, carryvar="carryQ", mode = "ci_co"):
    self.carryvar = ast.Name(carryvar, ast.Load())
    self.mode = mode
  def xfrm_fndef(self, fndef):
    """Rewrite the body of fndef in place, numbering carries from 0; return fndef."""
    return self.generic_xfrm(fndef)
  def generic_xfrm(self, node):
    """Rewrite the children of node in place, numbering carries from 0; return node."""
    self.current_carry = 0
    # Always traverse, even if CarryCounter finds no carries: carry-free code
    # may still contain StreamScan calls that must be rewritten, and it has
    # none of the Advance/ScanThru/+/- operations that would otherwise change.
    self.generic_visit(node)
    return node
  def _carry_call(self, rtn, args):
    # rtn(args..., carryvar, i) using the next carry index i.
    c = mkCall(rtn, args + [self.carryvar, ast.Num(self.current_carry)])
    self.current_carry += 1
    return c
  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    if is_BuiltIn_Call(callnode, 'Advance', 1):
      return self._carry_call("BitBlock_advance_%s" % self.mode, [callnode.args[0]])
    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
      return self._carry_call("BitBlock_scanthru_%s" % self.mode, [callnode.args[0], callnode.args[1]])
    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
      # The cast and sizeof expressions are emitted verbatim as Name nodes.
      return mkCall("StreamScan", [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()),
                                   ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
                                   ast.Name(callnode.args[1].id, ast.Load())])
    else: return callnode
  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, ast.Sub): rtn = "BitBlock_sub_%s" % self.mode
    elif isinstance(exprnode.op, ast.Add): rtn = "BitBlock_add_%s" % self.mode
    else: return exprnode
    return self._carry_call(rtn, [exprnode.left, exprnode.right])
  def visit_If(self, ifNode):
    """
    Unless mode is "co", an if statement containing n > 0 carries (indices
    base .. base+n-1) becomes
       if test or CarryTest(carryvar, base, n): <body>
       else: <orelse>; CarryDequeueEnqueue(carryvar, base, n)
    """
    carry_base = self.current_carry
    carries = CarryCounter().count(ifNode)
    self.generic_visit(ifNode)
    if carries == 0 or self.mode == "co": return ifNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('CarryTest', carry_arglist)])
    new_else_part = ifNode.orelse + [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    return ast.If(new_test, ifNode.body, new_else_part)
  def visit_While(self, whileNode):
    """
    A while loop containing n > 0 carries (indices base .. base+n-1) becomes
       if test [or CarryTest(carryvar, base, n)]:      # CarryTest unless mode "co"
         <body, using carryvar entries base .. base+n-1>
         while test:
           CarryDeclare(subcarryvar, n); CarryInit(subcarryvar, n)
           <body, rewritten in "co" mode against subcarryvar from index 0>
           CarryCombine(carryvar, subcarryvar, base, n)
       else:
         CarryDequeueEnqueue(carryvar, base, n)
    where subcarryvar is 'sub' + carryvar.  The loop's own else branch is
    not carried over.  Loops without carries are only traversed.
    """
    carry_base = self.current_carry
    carries = CarryCounter().count(whileNode)
    if carries == 0:
      # Still traverse so StreamScan calls inside the loop are rewritten.
      self.generic_visit(whileNode)
      return whileNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    local_carryvar = 'sub' + self.carryvar.id
    # The iterating copy must be taken from the untransformed loop, before
    # generic_visit below rewrites whileNode in place.
    inner_while = CarryIntro(local_carryvar, 'co').generic_xfrm(copy.deepcopy(whileNode))
    self.generic_visit(whileNode)
    local_decl = mkCallStmt('CarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
    local_init = mkCallStmt('CarryInit', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
    inner_while.body.insert(0, local_decl)
    inner_while.body.insert(1, local_init)
    final_combine = mkCallStmt('CarryCombine', [self.carryvar, ast.Name(local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
    inner_while.body.append(final_combine)
    if self.mode == "co": new_test = whileNode.test
    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('CarryTest', carry_arglist)])
    else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    return ast.If(new_test, whileNode.body + [inner_while], else_part)
278
class StreamStructGen(ast.NodeVisitor):
  """
  Generate C struct code for the BitStreamSet classes of a module.

  Each class definition found by the traversal (nested classes are not
  visited) is treated as a stream set.  Each simple `name = ...` assignment
  in its body becomes a BitBlock field.  The C type name is the class name
  with its first letter upper-cased; the instance name is the class name
  with its first letter lower-cased.  With asType true, gen() emits struct
  type definitions; otherwise it emits one struct variable per class.

  >>> obj = ast.parse(r'''
  ... class S1(BitStreamSet):
  ...   a1 = 0
  ...   a2 = 0
  ...   a3 = 0
  ...
  ... class S2(BitStreamSet):
  ...   x1 = 0
  ...   x2 = 0
  ... ''')
  >>> print(StreamStructGen(True).gen(obj))
  struct S1 {
    BitBlock a1;
    BitBlock a2;
    BitBlock a3;
  };
  <BLANKLINE>
  struct S2 {
    BitBlock x1;
    BitBlock x2;
  };
  <BLANKLINE>
  <BLANKLINE>
  >>> print(StreamStructGen().gen(obj))
  struct S1 s1;
  <BLANKLINE>
  struct S2 s2;
  <BLANKLINE>
  <BLANKLINE>
  """
  def __init__(self, asType=False):
    self.asType = asType
  def gen(self, tree):
    """Return the C code for every class in tree, per the current asType setting."""
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode
  def gen_struct_types(self, tree):
    """Switch to type mode (asType = True) and return the struct definitions."""
    self.asType = True
    return self.gen(tree)
  def gen_struct_vars(self, tree):
    """Switch to variable mode (asType = False) and return the struct variables."""
    self.asType = False
    return self.gen(tree)
  def visit_ClassDef(self, node):
    class_name = node.name[0].upper() + node.name[1:]
    instance_name = node.name[0].lower() + node.name[1:]
    if self.asType:
      fields = [v.id for stmt in node.body if isinstance(stmt, ast.Assign)
                     for v in stmt.targets if isinstance(v, ast.Name)]
      body = " {\n" + "".join("  BitBlock " + f + ";\n" for f in fields) + "}"
    else:
      body = " " + instance_name
    self.Ccode += "struct " + class_name + body + ";\n\n"
334 
class StreamFunctionDecl(ast.NodeVisitor):
  """
  Generate C forward declarations (prototypes) for stream functions.

  The parameter list follows the same rules as BitBlock_header_of_fn uses
  for the definition, so declaration and definition agree.  Each
  parameter p becomes "P & p" (P is p with its first letter upper-cased).
  It is followed by "CarryQtype & carryQ" when CarryCounter finds
  carry-producing operations in the function.  Parameters are read from
  the Python 2 AST (ast.Name) and the Python 3 AST (ast.arg).
  """
  def gen(self, tree):
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode
  def visit_FunctionDef(self, node):
    params = []
    for arg in node.args.args:
      if isinstance(arg, ast.Name): name = arg.id
      elif isinstance(getattr(arg, 'arg', None), str): name = arg.arg
      else: continue
      params.append(name.upper()[0] + name[1:] + " & " + name)
    if CarryCounter().count(node) > 0:
      params.append("CarryQtype & carryQ")
    self.Ccode += "static inline void " + node.name + "(" + ", ".join(params) + ");\n"
350
351
352
353#
354# Adding Debugging Statements
355#
class Add_SIMD_Register_Dump(ast.NodeTransformer):
  """
  Insert a debugging dump after every assignment.

  Each Assign is followed by print_simd_register("<target>", target), where
  "<target>" is the first assignment target rendered by Cgen.py2C.  Only
  the first target of a multiple assignment is dumped.
  """
  def xfrm(self, t):
    return self.generic_visit(t)
  def visit_Assign(self, t):
    self.generic_visit(t)
    target = t.targets[0]
    label = ast.Str(Cgen.py2C().gen(target))
    # Returning a list splices both statements into the enclosing body.
    return [t, mkCallStmt('print_simd_register', [label, target])]
364   
365#
366#  Translate a function
367#
368
class FunctionXlator(ast.NodeVisitor):
  """
  Translate the function definitions found in a module AST into C.

  Each function is lowered in place (AugAssignRemoval, Bitwise_to_SIMD,
  then CarryIntro).  It is emitted as its header, its BitBlock local
  declarations and the Cgen translation of its body.  Functions nested
  inside a function are not visited separately.
  """
  def xlat(self, node):
    """Return the C code for every function definition reached from node."""
    self.Ccode = ""
    self.generic_visit(node)
    return self.Ccode
  def visit_FunctionDef(self, fndef):
    AugAssignRemoval().xfrm(fndef)
    Bitwise_to_SIMD().xfrm(fndef)
    # The header is built before CarryIntro runs: its carry-parameter check
    # uses CarryCounter, which recognises the Advance/ScanThru/+/- forms
    # that CarryIntro replaces.
    header = BitBlock_header_of_fn(fndef)
    local_decls = BitBlock_decls_of_fn(fndef)
    CarryIntro().xfrm_fndef(fndef)
    c_body = Cgen.py2C().gen(fndef.body)
    self.Ccode += header + " {\n" + local_decls + c_body + "\n}\n"
382
def main(infilename, outfile = sys.stdout):
  """
  Compile the Pablo source file infilename and write the C code to outfile.

  Writes the struct type definitions for the file's classes, followed by
  the C translation of its function definitions.  The input file is
  closed as soon as it has been read.
  """
  with open(infilename) as infile:
    t = ast.parse(infile.read())
  outfile.write(StreamStructGen(True).gen(t))
  outfile.write(FunctionXlator().xlat(t))
387
388#
389#
390#  Routines for compatibility with the old compiler/template.
391#  Quick and dirty hacks for now - Dec. 2010.
392#
393
class MainLoopTransformer:
  """
  Compile the main function of a Pablo module into pieces of C code for the
  old compiler's template.

  main_module is the parsed module; its last statement must be the main
  function definition.  carry_var names the C carry queue variable; it is
  used for the declaration, initialization, carry tests and for the carry
  arguments CarryIntro threads through the main function.

  gen_declarations() and gen_initializations() compute the text returned
  by getCglobals() / getCdecls() and getCinits().  xfrm_block_stmts() and
  add_loop_carryQ_adjust() rewrite the main function in place, and
  getCstmts() returns its body as C.
  """
  def __init__(self, main_module, carry_var = "carryQ"):
    self.main_module = main_module
    self.main_fn = main_module.body[-1]
    assert isinstance(self.main_fn, ast.FunctionDef), "last statement of the module must be the main function"
    # Counted on the untransformed body, before xfrm_block_stmts runs.
    self.carry_count = CarryCounter().count(self.main_fn)
    self.carry_var = carry_var
  def any_carry_expr(self):
    """Return "CarryTest(<carry_var>, 0, <count>)\\n", or "1" if the function has no carries."""
    if self.carry_count == 0: return "1"
    return "CarryTest(%s, 0, %i)\n" % (self.carry_var, self.carry_count)
  def gen_declarations(self):
    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
    self.Cdecls += BitBlock_decls_of_fn(self.main_fn)
    if self.carry_count > 0: self.Cdecls += "CarryDeclare(%s, %i);\n" % (self.carry_var, self.carry_count)
  def gen_initializations(self):
    self.Cinits = ""
    if self.carry_count > 0: self.Cinits += "CarryInit(%s, %i);\n" % (self.carry_var, self.carry_count)
    self.Cinits += StreamInitializations().xfrm(self.main_module)
  def xfrm_block_stmts(self, add_dump_stmts=False):
    AugAssignRemoval().xfrm(self.main_fn)
    Bitwise_to_SIMD().xfrm(self.main_fn)
    # Thread the same carry queue variable that gen_declarations declares.
    CarryIntro(self.carry_var).xfrm_fndef(self.main_fn)
    if add_dump_stmts: Add_SIMD_Register_Dump().xfrm(self.main_fn)
  def add_loop_carryQ_adjust(self):
    # The carry queue is only declared when there are carries; with none,
    # there is nothing to adjust, and the call could name an undeclared variable.
    if self.carry_count > 0:
      self.main_fn.body += [mkCallStmt('CarryQ_Adjust', [ast.Name(self.carry_var, ast.Load()), ast.Num(self.carry_count)])]
  def getCglobals(self):
    return self.Cglobals
  def getCdecls(self):
    return self.Cdecls
  def getCinits(self):
    return self.Cinits
  def getCstmts(self):
    return Cgen.py2C().gen(self.main_fn.body)
428
if __name__ == "__main__":
  # Running the module directly executes the docstring examples (doctests).
  import doctest
  doctest.testmod()
432
433
Note: See TracBrowser for help on using the repository browser.