source: proto/Compiler/pablo.py @ 813

Last change on this file since 813 was 813, checked in by ksherdy, 9 years ago

Refactor is_Advance_Call and is_ScanThru_Call into a single method is_BuiltIn_Call to avoid code duplication.

File size: 14.1 KB
RevLine 
[753]1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, Robert D. Cameron
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11
# HELPER functions for AST recognition/construction
# modified from old py2bitexpr.py
def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='bitutil'):
  """Return True if the ast.Call fncall is a call of a builtin with the expected arity.

  Matches both the bare form `builtin_fnname(...)` and the qualified form
  `<builtin_fnmod_noprefix>.builtin_fnname(...)`, with exactly
  builtin_arg_cnt positional arguments and no *args / **kwargs.
  Any other callee shape (e.g. `a.b.f(...)`, `f()(...)`, `x[0](...)`) is
  reported as not a builtin call.
  """
  callee = fncall.func
  if isinstance(callee, ast.Name):
    iscall = callee.id == builtin_fnname
  elif isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name):
    iscall = callee.value.id == builtin_fnmod_noprefix and callee.attr == builtin_fnname
  else:
    # Without this branch `iscall` would be unbound (UnboundLocalError)
    # for any other callee expression.
    iscall = False
  # Python 2 ASTs record *args / **kwargs in `starargs` / `kwargs`; newer
  # Python versions drop these fields, so read them defensively.
  return (iscall and len(fncall.args) == builtin_arg_cnt
          and getattr(fncall, 'kwargs', None) is None
          and getattr(fncall, 'starargs', None) is None)
[753]19
def is_Advance_Call(fncall):
  """Return True if fncall is `Advance(x)` or `bitutil.Advance(x)`.

  Kept for backward compatibility; delegates to is_BuiltIn_Call rather than
  duplicating its matching logic.
  """
  return is_BuiltIn_Call(fncall, 'Advance', 1)
25
def is_ScanThru_Call(fncall):
  """Return True if fncall is `ScanThru(x, y)` or `bitutil.ScanThru(x, y)`.

  Kept for backward compatibility; delegates to is_BuiltIn_Call rather than
  duplicating its matching logic.
  """
  return is_BuiltIn_Call(fncall, 'ScanThru', 2)
31
def is_simd_not(e):
  """Return True if e is a call whose callee is the bare name `simd_not`."""
  if not isinstance(e, ast.Call):
    return False
  callee = e.func
  return isinstance(callee, ast.Name) and callee.id == 'simd_not'
34
def mkQname(obj, field):
  """Return the AST for the qualified name `obj.field`, in load context."""
  receiver = ast.Name(obj, ast.Load())
  return ast.Attribute(receiver, field, ast.Load())
37
def mkCall(fn_name, args):
  """Return the AST for the call `fn_name(*args)` with no keyword arguments.

  fn_name is either a string (wrapped in a load-context ast.Name) or an
  existing callee expression node.  args is used directly as the node's
  argument list (it is not copied).
  """
  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
  # Python 2 ASTs give ast.Call extra starargs/kwargs fields; newer Python
  # versions reject those extra constructor arguments.
  if 'starargs' in ast.Call._fields:
    return ast.Call(fn_name, args, [], None, None)
  return ast.Call(fn_name, args, [])
41
def mkCallStmt(fn_name, args):
  """Return an expression statement wrapping the call `fn_name(*args)`.

  fn_name may be a string or a callee node, as for mkCall, which builds the
  call so that both helpers construct calls the same way.
  """
  return ast.Expr(mkCall(fn_name, args))
45
#
# Reducing AugAssign
#
class AugAssignRemoval(ast.NodeTransformer):
  """Rewrite augmented assignments `x op= y` as `x = x op y`.

  The left operand of the new BinOp is a deep copy of the target with a
  Load context, so one node is never shared between the store position and
  an expression position of the tree.
  """
  def xfrm(self, t):
    """Transform the descendants of t in place and return t."""
    return self.generic_visit(t)
  def visit_AugAssign(self, e):
    self.generic_visit(e)
    # Reusing e.target here would alias a Store-context node into an
    # expression position: compile() rejects that, and later in-place
    # passes would visit the same object twice.
    operand = copy.deepcopy(e.target)
    operand.ctx = ast.Load()
    return ast.Assign([e.target], ast.BinOp(operand, e.op, e.value))
55
#
# Introducing BitBlock logical operations
#
class Bitwise_to_SIMD(ast.NodeTransformer):
  """
  Make the following substitutions:
     x & y    => simd_and(x, y)
     x & ~y   => simd_andc(x, y)   (likewise ~y & x => simd_andc(x, y))
     x | y    => simd_or(x, y)
     x ^ y    => simd_xor(x, y)
     ~x       => simd_not(x)
     0        => simd_const_1(0)
     -1       => simd_const_1(1)   (when -1 parses as a single Num node)
     if x:    => if bitblock_has_bit(x):
     while x: => while bitblock_has_bit(x):
  Index expressions inside subscripts are left untouched.

  Example (illustrative only -- not a doctest, since this module has no
  AST pretty-printer):

      pfx = bit0 & bit1; sfx = bit0 &~ bit1

  becomes

      pfx = simd_and(bit0, bit1)
      sfx = simd_andc(bit0, bit1)
  """
  def xfrm(self, t):
    """Transform the descendants of t in place and return t."""
    return self.generic_visit(t)
  def visit_UnaryOp(self, t):
    self.generic_visit(t)
    if isinstance(t.op, ast.Invert):
      return mkCall('simd_not', [t.operand])
    else: return t
  def visit_BinOp(self, t):
    self.generic_visit(t)
    if isinstance(t.op, ast.BitOr):
      return mkCall('simd_or', [t.left, t.right])
    elif isinstance(t.op, ast.BitAnd):
      # Operands are already rewritten, so a `~y` operand is simd_not(y)
      # by now; fold it into a single and-complement.
      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
      else: return mkCall('simd_and', [t.left, t.right])
    elif isinstance(t.op, ast.BitXor):
      return mkCall('simd_xor', [t.left, t.right])
    else: return t
  def visit_Num(self, numnode):
    n = numnode.n
    if n == 0: return mkCall('simd_const_1', [numnode])
    elif n == -1: return mkCall('simd_const_1', [ast.Num(1)])
    else: return numnode
  def visit_If(self, ifNode):
    self.generic_visit(ifNode)
    ifNode.test = mkCall('bitblock_has_bit', [ifNode.test])
    return ifNode
  def visit_While(self, whileNode):
    self.generic_visit(whileNode)
    whileNode.test = mkCall('bitblock_has_bit', [whileNode.test])
    return whileNode
  def visit_Subscript(self, subnode):
    return subnode  # no recursive modifications of index expressions
110
111
#
#  Generating BitBlock declarations for Local Variables
#
#
class LocalVars(ast.NodeVisitor):
  """Collect a function's local variables: names assigned to (Store
  context) that are not parameters, in the order first encountered."""
  # Python 2 marks parameters as Name nodes with a Param context; the
  # class may be missing from newer `ast` modules, in which case the
  # isinstance test below simply never matches.
  _param_ctx = getattr(ast, 'Param', ())
  def visit_Name(self, nm):
    if isinstance(nm.ctx, self._param_ctx):
      self.params.append(nm.id)
    if isinstance(nm.ctx, ast.Store):
      if nm.id not in self.stores: self.stores.append(nm.id)
  def visit_arg(self, a):
    # Python 3 represents parameters as ast.arg nodes instead.
    self.params.append(a.arg)
  def get(self, node):
    """Return the assigned, non-parameter names found below node."""
    self.params = []
    self.stores = []
    self.generic_visit(node)
    params = set(self.params)
    return [v for v in self.stores if v not in params]
127
MAX_LINE_LENGTH = 80

def BitBlock_decls_from_vars(varlist):
  """Return C declarations `  BitBlock v1, v2, ...;` for the names in varlist.

  Names are packed onto as few lines as possible.  Each line, including its
  terminating ';', is at most MAX_LINE_LENGTH characters, except that a
  single over-long name still gets a line of its own.  An empty varlist
  yields "" rather than the ill-formed declaration `BitBlock;`.
  """
  varlist = list(varlist)
  if not varlist:
    return ""
  head = "  BitBlock "
  lines = []
  names = []
  width = len(head)   # length of the current line, excluding its ';'
  for v in varlist:
    extra = len(v) + (2 if names else 0)   # ", " precedes all but the first name
    # +1 reserves room for the terminating ';'.  Never wrap before the first
    # name of a line, which would emit an empty `BitBlock;` declaration.
    if names and width + extra + 1 > MAX_LINE_LENGTH:
      lines.append(head + ", ".join(names) + ";\n")
      names = []
      width = len(head)
      extra = len(v)
    names.append(v)
    width += extra
  lines.append(head + ", ".join(names) + ";\n")
  return "".join(lines)
145
def BitBlock_decls_of_fn(fndef):
  """Return BitBlock declarations for the local (assigned, non-parameter)
  variables of the function definition fndef."""
  local_names = LocalVars().get(fndef)
  return BitBlock_decls_from_vars(local_names)
148
def BitBlock_header_of_fn(fndef):
  """Return the C signature `static inline void name(P & p, ...)` for fndef.

  Each parameter `p` is passed by reference, typed by its own name with the
  first letter upper-cased (e.g. `lex` -> `Lex & lex`).  Parameters are Name
  nodes in Python 2 ASTs and ast.arg nodes in Python 3 ASTs; any other
  parameter form (e.g. a Python 2 tuple parameter) is skipped.
  """
  params = []
  for arg in fndef.args.args:
    if isinstance(arg, ast.Name): name = arg.id
    else: name = getattr(arg, 'arg', None)
    if name:
      params.append(name.upper()[0] + name[1:] + " & " + name)
  return "static inline void " + fndef.name + "(" + ", ".join(params) + ")"
158
159
[760]160
#
#  Stream Initialization Statement Extraction
#
#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
class StreamInitializations(ast.NodeTransformer):
  """Hoist one-time stream initializations out of a tree.

  An assignment `v = n` whose value is a numeric literal other than 0 or -1
  is removed in place.  In its stead, `v = sisd_from_int(n)` is collected
  into self.stream_stmts, and `v = 0` is appended to the end of the body of
  the function containing the assignment.  All other assignments are kept.
  """
  def xfrm(self, node):
    """Transform node in place; return the collected initializations as C
    code generated by Cgen."""
    self.stream_stmts = []
    self.loop_post_inits = []
    self.generic_visit(node)
    return Cgen.py2C().gen(self.stream_stmts)
  def visit_Assign(self, node):
    if not isinstance(node.value, ast.Num): return node
    if node.value.n == 0 or node.value.n == -1: return node
    stream_init = copy.deepcopy(node)
    stream_init.value = mkCall('sisd_from_int', [node.value])
    loop_init = copy.deepcopy(node)
    loop_init.value.n = 0
    self.stream_stmts.append(stream_init)
    self.loop_post_inits.append(loop_init)
    return None   # drop the original assignment
  def visit_FunctionDef(self, node):
    # self.loop_post_inits accumulates over the whole tree; append only the
    # resets collected while visiting this function so that resets from one
    # function do not leak into another.
    # NOTE(review): resets from a nested function still also reach its
    # enclosing function -- confirm nesting never occurs in Pablo sources.
    first = len(self.loop_post_inits)
    self.generic_visit(node)
    node.body = node.body + self.loop_post_inits[first:]
    return node
188
189
#
# Carry Introduction Transformation
#

class CarryCounter(ast.NodeVisitor):
  """Count carry-generating operations in a syntax tree.

  Counted: calls recognised by is_BuiltIn_Call as Advance (1 argument) or
  ScanThru (2 arguments), and every binary `+` or `-`.
  """
  def count(self, nodeToVisit):
    """Return the number of carry-generating operations among the
    descendants of nodeToVisit (the node itself is not examined)."""
    self.carry_count = 0
    self.generic_visit(nodeToVisit)
    return self.carry_count
  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    generates_carry = (is_BuiltIn_Call(callnode, 'Advance', 1)
                       or is_BuiltIn_Call(callnode, 'ScanThru', 2))
    if generates_carry:
      self.carry_count += 1
  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, (ast.Sub, ast.Add)):
      self.carry_count += 1
209
210
211class CarryIntro(ast.NodeTransformer):
212  def __init__(self, carryvar="carryQ", mode = "ci_co"):
213    self.carryvar = ast.Name(carryvar, ast.Load())
214    self.mode = mode
215  def xfrm_fndef(self, fndef):
216    self.current_carry = 0
217    carry_count = CarryCounter().count(fndef)
218    if carry_count == 0: return fndef
219    self.generic_visit(fndef)
[762]220#   
221#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
[753]222    return fndef
223  def generic_xfrm(self, node):
224    self.current_carry = 0
225    carry_count = CarryCounter().count(node)
226    if carry_count == 0: return node
227    self.generic_visit(node)
228    return node
229  def visit_Call(self, callnode):
230    self.generic_visit(callnode)
[813]231    if is_BuiltIn_Call(callnode, 'Advance', 1):         
[753]232      rtn = "BitBlock_advance_%s" % self.mode
233      c = mkCall(rtn, [callnode.args[0], self.carryvar, ast.Num(self.current_carry)])
234      self.current_carry += 1
235      return c
[813]236    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
[753]237      rtn = "BitBlock_scanthru_%s" % self.mode
238      c = mkCall(rtn, [callnode.args[0], callnode.args[1], self.carryvar, ast.Num(self.current_carry)])
239      self.current_carry += 1
240      return c
241    else: return callnode
[765]242  def visit_BinOp(self, exprnode):
[753]243    self.generic_visit(exprnode)
244    if isinstance(exprnode.op, ast.Sub):
245      rtn = "BitBlock_sub_%s" % self.mode
246      c = mkCall(rtn, [exprnode.left, exprnode.right, self.carryvar, ast.Num(self.current_carry)])
247      self.current_carry += 1
248      return c
249    elif isinstance(exprnode.op, ast.Add):
250      rtn = "BitBlock_add_%s" % self.mode
251      c = mkCall(rtn, [exprnode.left, exprnode.right, self.carryvar, ast.Num(self.current_carry)])
252      self.current_carry += 1
253      return c
254    else: return exprnode
255  def visit_If(self, ifNode):
256    carry_base = self.current_carry
257    carries = CarryCounter().count(ifNode)
258    self.generic_visit(ifNode)
259    if carries == 0 or self.mode == "co": return ifNode
260    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
261    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('CarryTest', carry_arglist)])
262    new_else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
263    return ast.If(new_test, ifNode.body, new_else_part)
264  def visit_While(self, whileNode):
265    carry_base = self.current_carry
266    carries = CarryCounter().count(whileNode)
267    if carries == 0: return whileNode
268    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
269    local_carryvar = 'sub'+self.carryvar.id
270    inner_while = CarryIntro(local_carryvar, 'co').generic_xfrm(copy.deepcopy(whileNode))
271    self.generic_visit(whileNode)
272    local_decl = mkCallStmt('CarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
[777]273    local_init = mkCallStmt('CarryInit', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
[753]274    inner_while.body.insert(0, local_decl)
[777]275    inner_while.body.insert(1, local_init)
[753]276    final_combine = mkCallStmt('CarryCombine', [self.carryvar, ast.Name(local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
277    inner_while.body.append(final_combine)
278    if self.mode == "co": new_test = whileNode.test
279    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('CarryTest', carry_arglist)])
280    else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
281    return ast.If(new_test, whileNode.body + [inner_while], else_part)
282
class StreamStructGen(ast.NodeVisitor):
  """
  For each class definition (e.g. a BitStreamSet subclass) reached in a
  tree, emit `struct { BitBlock f; ... } ClassName;` -- one BitBlock member
  per simple-name assignment target in the class body.
  >>> obj = ast.parse(r'''
  ... class S1(BitStreamSet):
  ...   a1 = 0
  ...   a2 = 0
  ...   a3 = 0
  ...
  ... class S2(BitStreamSet):
  ...   x1 = 0
  ...   x2 = 0
  ... ''')
  >>> print(StreamStructGen().gen(obj))
  struct {
    BitBlock a1;
    BitBlock a2;
    BitBlock a3;
  } S1;
  <BLANKLINE>
  struct {
    BitBlock x1;
    BitBlock x2;
  } S2;
  <BLANKLINE>
  <BLANKLINE>
  """
  def gen(self, tree):
    """Return the C struct declarations for the classes found in tree."""
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode
  def visit_ClassDef(self, node):
    fields = ""
    for stmt in node.body:
      if isinstance(stmt, ast.Assign):
        for v in stmt.targets:
          if isinstance(v, ast.Name):
            fields += "  BitBlock " + v.id + ";\n"
    self.Ccode += "struct {\n" + fields + "} " + node.name + ";\n\n"
323 
class StreamFunctionDecl(ast.NodeVisitor):
  """Generate C prototypes for the function definitions in a tree."""
  def gen(self, tree):
    """Return one `static inline void name(...);` prototype per function
    definition reached in tree (nested functions are not descended into)."""
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode
  def visit_FunctionDef(self, node):
    # Reuse the shared signature builder so prototypes always match the
    # definitions produced by FunctionXlator.
    self.Ccode += BitBlock_header_of_fn(node) + ";\n"
339
340
341
#
# Adding Debugging Statements
#
class Add_SIMD_Register_Dump(ast.NodeTransformer):
  """Instrument assignments for debugging: after each assignment statement,
  insert print_simd_register("<first target rendered by Cgen>", <first target>)."""
  def xfrm(self, t):
    """Transform the descendants of t in place and return t."""
    return self.generic_visit(t)
  def visit_Assign(self, t):
    self.generic_visit(t)
    target = t.targets[0]
    label = ast.Str(Cgen.py2C().gen(target))
    dump = mkCallStmt('print_simd_register', [label, target])
    return [t, dump]
353   
#
#  Translate a function
#

class FunctionXlator(ast.NodeVisitor):
  """Translate each function definition reached in a tree into a C function
  definition (nested functions are not descended into)."""
  def xlat(self, node):
    """Return the C code for the function definitions below node."""
    self.Ccode = ""
    self.generic_visit(node)
    return self.Ccode
  def visit_FunctionDef(self, fndef):
    # Normalise the Python AST in place first.
    AugAssignRemoval().xfrm(fndef)
    Bitwise_to_SIMD().xfrm(fndef)
    # The signature and local declarations are taken before carry
    # introduction rewrites the body.
    signature = BitBlock_header_of_fn(fndef)
    local_decls = BitBlock_decls_of_fn(fndef)
    CarryIntro().xfrm_fndef(fndef)
    body = Cgen.py2C().gen(fndef.body)
    self.Ccode += signature + " {\n" + local_decls + body + "\n}\n"
371
def main(infilename, outfile = sys.stdout):
  """Translate the Pablo source file infilename, writing the generated C
  struct declarations followed by the C function definitions to outfile."""
  # open() + with: portable (file() is Python-2-only) and closes the handle.
  with open(infilename) as infile:
    t = ast.parse(infile.read())
  outfile.write(StreamStructGen().gen(t))
  outfile.write(FunctionXlator().xlat(t))
376
#
#
#  Routines for compatibility with the old compiler/template.
#  Quick and dirty hacks for now - Dec. 2010.
#

class MainLoopTransformer:
  """Translate the last function of a module (the main loop body) into the
  C fragments the old compiler template expects: declarations,
  initializations and per-block statements.

  main_module -- parsed module whose last statement is a FunctionDef; that
                 function is transformed in place.
  carry_var   -- name of the carry queue variable in the generated C.
  """
  def __init__(self, main_module, carry_var = "carryQ"):
    self.main_module = main_module
    self.main_fn = main_module.body[-1]
    assert (isinstance(self.main_fn, ast.FunctionDef))
    # Count on a copy normalised the way xfrm_block_stmts normalises
    # main_fn: `x += y` is an AugAssign (not counted) until AugAssignRemoval
    # turns it into a BinOp, which CarryIntro then gives a carry slot.
    normalized = AugAssignRemoval().xfrm(copy.deepcopy(self.main_fn))
    self.carry_count = CarryCounter().count(normalized)
    self.carry_var = carry_var
  def any_carry_expr(self):
    """Return "1" when there are no carries, else a CarryTest(...) expression
    over all carry slots (followed by a newline)."""
    if self.carry_count == 0: return "1"
    else: return "CarryTest(%s, 0, %i)\n" % (self.carry_var, self.carry_count)
  def gen_declarations(self):
    """Build self.Cdecls: stream structs, local BitBlock declarations and,
    when there are carries, the carry queue declaration."""
    self.Cdecls = StreamStructGen().gen(self.main_module)
    self.Cdecls += BitBlock_decls_of_fn(self.main_fn)
    if self.carry_count > 0: self.Cdecls += "CarryDeclare(%s, %i);\n" % (self.carry_var, self.carry_count)
  def gen_initializations(self):
    """Build self.Cinits: carry queue initialization (when needed) followed
    by the stream initializations extracted (and removed) from the module."""
    self.Cinits = ""
    if self.carry_count > 0: self.Cinits += "CarryInit(%s, %i);\n" % (self.carry_var, self.carry_count)
    self.Cinits += StreamInitializations().xfrm(self.main_module)
  def xfrm_block_stmts(self, add_dump_stmts=False):
    """Rewrite main_fn in place into SIMD / carry-queue form, optionally
    adding a register dump after every assignment."""
    AugAssignRemoval().xfrm(self.main_fn)
    Bitwise_to_SIMD().xfrm(self.main_fn)
    # Use the same queue name that gen_declarations/gen_initializations emit.
    CarryIntro(self.carry_var).xfrm_fndef(self.main_fn)
    if add_dump_stmts: Add_SIMD_Register_Dump().xfrm(self.main_fn)
  def add_loop_carryQ_adjust(self):
    """Append CarryQ_Adjust(<carry_var>, <carry_count>) to main_fn's body."""
    # With no carries no queue is declared (see gen_declarations), so a
    # CarryQ_Adjust call would presumably reference an undeclared variable.
    if self.carry_count == 0: return
    self.main_fn.body += [mkCallStmt('CarryQ_Adjust', [ast.Name(self.carry_var, ast.Load()), ast.Num(self.carry_count)])]
  def getCdecls(self):
    return self.Cdecls
  def getCinits(self):
    return self.Cinits
  def getCstmts(self):
    """Return the C code for main_fn's (transformed) body."""
    return Cgen.py2C().gen(self.main_fn.body)
[753]414
if __name__ == "__main__":
  # Run this module's docstring examples when it is executed directly.
  from doctest import testmod
  testmod()
418
419
Note: See TracBrowser for help on using the repository browser.