source: proto/Compiler/pablo.py @ 787

Last change on this file since 787 was 787, checked in by cameron, 9 years ago

Fix semicolon bug.

File size: 13.2 KB
Line 
#
# Pablo.py - parallel bitstream to bitblock
#  2nd generation compiler
#
# Copyright 2010, Robert D. Cameron
# All rights reserved.
#
8import ast, copy, sys
9import Cgen
10
11
# HELPER functions for AST recognition/construction
13
def is_Advance_Call(fncall):   # extracted from old py2bitexpr.py
  """
  Return True when `fncall` is a call of Advance with exactly one argument.

  The callee may be written as the plain name `Advance` or as the qualified
  name `bitutil.Advance`.  Calls carrying *args or **kwargs are rejected.
  Any other callee shape (a subscript, a nested attribute, the result of
  another call, ...) is simply not an Advance call and yields False.
  """
  fn = fncall.func
  if isinstance(fn, ast.Name):
    iscall = fn.id == 'Advance'
  elif isinstance(fn, ast.Attribute) and isinstance(fn.value, ast.Name):
    iscall = fn.value.id == 'bitutil' and fn.attr == 'Advance'
  else:
    # Without this branch `iscall` would be unbound for other callee shapes.
    iscall = False
  return iscall and len(fncall.args) == 1 and fncall.kwargs is None and fncall.starargs is None
19
def is_ScanThru_Call(fncall):  # extracted from old py2bitexpr.py
  """
  Return True when `fncall` is a call of ScanThru with exactly two arguments.

  The callee may be written as the plain name `ScanThru` or as the qualified
  name `bitutil.ScanThru`.  Calls carrying *args or **kwargs are rejected.
  Any other callee shape yields False.
  """
  fn = fncall.func
  if isinstance(fn, ast.Name):
    iscall = fn.id == 'ScanThru'
  elif isinstance(fn, ast.Attribute) and isinstance(fn.value, ast.Name):
    iscall = fn.value.id == 'bitutil' and fn.attr == 'ScanThru'
  else:
    # Without this branch `iscall` would be unbound for other callee shapes.
    iscall = False
  return iscall and len(fncall.args) == 2 and fncall.kwargs is None and fncall.starargs is None
25
def is_simd_not(e):
  """Return True when `e` is a call whose callee is the plain name simd_not."""
  if not isinstance(e, ast.Call):
    return False
  callee = e.func
  return isinstance(callee, ast.Name) and callee.id == 'simd_not'
28
def mkQname(obj, field):
  """Build the AST for the qualified name `obj.field` (both given as strings)."""
  owner = ast.Name(obj, ast.Load())
  return ast.Attribute(owner, field, ast.Load())
31
def mkCall(fn_name, args):
  """
  Build an ast.Call node for `fn_name(*args)` with no keyword arguments.

  fn_name -- either a string (turned into a Name node in Load context) or
             an already constructed callee expression node.
  args    -- list of argument expression nodes.
  """
  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
  # Older ast.Call nodes carry separate starargs/kwargs fields; newer ones
  # only have (func, args, keywords) and reject the extra positional values.
  if 'starargs' in ast.Call._fields:
    return ast.Call(fn_name, args, [], None, None)
  return ast.Call(fn_name, args, [])
35
def mkCallStmt(fn_name, args):
  """
  Build an expression statement that calls `fn_name(*args)`.

  Accepts the same arguments as mkCall and wraps the resulting call node
  in an ast.Expr so it can be placed in a statement list.
  """
  return ast.Expr(mkCall(fn_name, args))
39
#
# Reducing AugAssign
#
class AugAssignRemoval(ast.NodeTransformer):
  """
  Rewrite every augmented assignment `x op= e` as `x = x op e`.

  >>> t = AugAssignRemoval().xfrm(ast.parse("x |= y"))
  >>> t.body[0].__class__.__name__, t.body[0].value.op.__class__.__name__
  ('Assign', 'BitOr')
  """
  def xfrm(self, t):
    """Transform the children of `t` in place and return `t`."""
    return self.generic_visit(t)
  def visit_AugAssign(self, e):
    self.generic_visit(e)
    # The target is in Store context; the left operand needs its own node in
    # Load context so that no node object sits at two places in the tree.
    operand = copy.deepcopy(e.target)
    operand.ctx = ast.Load()
    assign = ast.Assign([e.target], ast.BinOp(operand, e.op, e.value))
    # Give the new nodes the source position of the statement they replace.
    return ast.fix_missing_locations(ast.copy_location(assign, e))
49
#
# Introducing BitBlock logical operations
#
class Bitwise_to_SIMD(ast.NodeTransformer):
  """
  Replace Python bitwise syntax by calls of the SIMD library routines.

  Substitutions (operands are transformed first, bottom-up):
     x & y    => simd_and(x, y)
     x & ~y   => simd_andc(x, y)
     ~x & y   => simd_andc(y, x)
     x | y    => simd_or(x, y)
     x ^ y    => simd_xor(x, y)
     ~x       => simd_not(x)
     0        => simd_const_1(0)
     -1       => simd_const_1(1)
     if x:    => if bitblock_has_bit(x):
     while x: => while bitblock_has_bit(x):
  Subscript expressions are left completely untouched.

  >>> t = Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1"))
  >>> [s.value.func.id for s in t.body]
  ['simd_and', 'simd_andc']
  >>> [a.id for a in t.body[1].value.args]
  ['bit0', 'bit1']
  """
  def xfrm(self, t):
    """Transform the children of `t` in place and return `t`."""
    return self.generic_visit(t)

  def visit_UnaryOp(self, t):
    self.generic_visit(t)
    if not isinstance(t.op, ast.Invert):
      return t
    return mkCall('simd_not', [t.operand])

  def visit_BinOp(self, t):
    self.generic_visit(t)
    lhs, rhs, op = t.left, t.right, t.op
    if isinstance(op, ast.BitAnd):
      # An already-rewritten negation on either side folds into and-complement;
      # the right operand is checked first.
      if is_simd_not(rhs): return mkCall('simd_andc', [lhs, rhs.args[0]])
      if is_simd_not(lhs): return mkCall('simd_andc', [rhs, lhs.args[0]])
      return mkCall('simd_and', [lhs, rhs])
    if isinstance(op, ast.BitOr):
      return mkCall('simd_or', [lhs, rhs])
    if isinstance(op, ast.BitXor):
      return mkCall('simd_xor', [lhs, rhs])
    return t

  def visit_Num(self, numnode):
    n = numnode.n
    if n == 0: return mkCall('simd_const_1', [numnode])
    if n == -1: return mkCall('simd_const_1', [ast.Num(1)])
    return numnode

  def visit_If(self, ifNode):
    self.generic_visit(ifNode)
    ifNode.test = mkCall('bitblock_has_bit', [ifNode.test])
    return ifNode

  def visit_While(self, whileNode):
    self.generic_visit(whileNode)
    whileNode.test = mkCall('bitblock_has_bit', [whileNode.test])
    return whileNode

  def visit_Subscript(self, numnode):
    # `numnode` is the Subscript node; it is returned without visiting its
    # children, so index expressions are never rewritten.
    return numnode
104
105
#
#  Generating BitBlock declarations for Local Variables
#
#
class LocalVars(ast.NodeVisitor):
  """
  Collect the names assigned inside a subtree that are not parameters.

  Every Name node below the given node is inspected, including those in
  nested definitions; no scoping is applied.
  """
  def visit_Name(self, nm):
    ctx, ident = nm.ctx, nm.id
    if isinstance(ctx, ast.Param):
      self.params.append(ident)
    # Stored names are kept once each, in order of first appearance.
    if isinstance(ctx, ast.Store) and ident not in self.stores:
      self.stores.append(ident)

  def get(self, node):
    """Return the list of stored, non-parameter names found below `node`."""
    self.params = []
    self.stores = []
    self.generic_visit(node)
    return [name for name in self.stores if name not in self.params]
121
MAX_LINE_LENGTH = 80

def BitBlock_decls_from_vars(varlist):
  """
  Return C declarations of the form "  BitBlock a, b, c;\\n" for `varlist`.

  Names are packed onto as few lines as possible, in order, starting a new
  declaration whenever the current line (including its terminating ';')
  would grow beyond MAX_LINE_LENGTH.  A name that is too long to fit on any
  line still gets a declaration line of its own.  An empty `varlist` yields
  the empty string rather than a declaration without declarators.
  """
  prefix = "  BitBlock"
  rows = []              # completed groups of names, one group per output line
  row = []               # names collected for the line being built
  width = len(prefix)    # length of the line being built, without the ';'
  for v in varlist:
    needed = len(v) + (2 if row else 1)   # ", name" or " name"
    if row and width + needed + 1 > MAX_LINE_LENGTH:
      rows.append(row)
      row = []
      width = len(prefix)
      needed = len(v) + 1
    row.append(v)
    width += needed
  if row: rows.append(row)
  return "".join("%s %s;\n" % (prefix, ", ".join(r)) for r in rows)
139
def BitBlock_decls_of_fn(fndef):
  """Return the BitBlock declarations for the names LocalVars finds in `fndef`."""
  local_names = LocalVars().get(fndef)
  return BitBlock_decls_from_vars(local_names)
142
def BitBlock_header_of_fn(fndef):
  """
  Return the C header "static inline void NAME(PARAMS)" for `fndef`.

  Each argument that is an ast.Name becomes a reference parameter whose type
  is the argument name with its first letter upper-cased, e.g. `lex` gives
  "Lex & lex".  Arguments of any other node type are skipped.
  """
  params = []
  for arg in fndef.args.args:
    if not isinstance(arg, ast.Name):
      continue
    ident = arg.id
    params.append(ident.upper()[0] + ident[1:] + " & " + ident)
  return "static inline void " + fndef.name + "(" + ", ".join(params) + ")"
152
153
154
#
#  Stream Initialization Statement Extraction
#
#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
class StreamInitializations(ast.NodeTransformer):
  """
  Pull assignments of numeric constants other than 0 and -1 out of the tree.

  For each such statement `v = k`:
    - it is removed from where it stands,
    - a copy `v = sisd_from_int(k)` is collected as a stream initialization,
    - a copy `v = 0` is collected and later appended to the end of every
      function body that is visited.
  """
  def xfrm(self, node):
    """Transform `node` in place; return C code for the collected stream initializations."""
    self.stream_stmts = []
    self.loop_post_inits = []
    self.generic_visit(node)
    return Cgen.py2C().gen(self.stream_stmts)

  def visit_Assign(self, node):
    rhs = node.value
    # Only numeric constants other than 0 and -1 are extracted.
    if not isinstance(rhs, ast.Num): return node
    if rhs.n == 0 or rhs.n == -1: return node
    stream_init = copy.deepcopy(node)
    stream_init.value = mkCall('sisd_from_int', [node.value])
    loop_init = copy.deepcopy(node)
    loop_init.value.n = 0
    self.stream_stmts.append(stream_init)
    self.loop_post_inits.append(loop_init)
    return None   # drop the original statement from the tree

  def visit_FunctionDef(self, node):
    self.generic_visit(node)
    # NOTE(review): loop_post_inits is never reset between functions, so the
    # resets collected so far (from any earlier part of the tree) are appended
    # to each function visited - confirm this is intended.
    node.body = node.body + self.loop_post_inits
    return node
182
183
#
# Carry Introduction Transformation
#
187
class CarryCounter(ast.NodeVisitor):
  """
  Count the carry-generating operations strictly below a node.

  One carry is counted for each Advance or ScanThru call and for each
  binary + or - operation.  The node handed to count() is not itself
  examined, only its descendants.
  """
  def count(self, nodeToVisit):
    """Return the number of carry-generating operations below `nodeToVisit`."""
    self.carry_count = 0
    self.generic_visit(nodeToVisit)
    return self.carry_count

  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    if is_Advance_Call(callnode) or is_ScanThru_Call(callnode):
      self.carry_count += 1

  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, (ast.Sub, ast.Add)):
      self.carry_count += 1
203
204
class CarryIntro(ast.NodeTransformer):
  """
  Rewrite carry-generating operations into explicit carry-queue calls.

  Every Advance/ScanThru call and every binary + or - is replaced by a call
  named BitBlock_<op>_<mode> that receives two extra arguments: the carry
  variable and the index of the carry slot.  Slots are numbered from 0 in
  the order the operations are rewritten.  If and While statements that
  contain carries are wrapped with CarryTest / CarryDequeueEnqueue calls.
  """
  def __init__(self, carryvar="carryQ", mode = "ci_co"):
    self.carryvar = ast.Name(carryvar, ast.Load())
    self.mode = mode

  def xfrm_fndef(self, fndef):
    """Transform a function definition in place and return it."""
    # No CarryDeclare statement is inserted into the function body here.
    return self.generic_xfrm(fndef)

  def generic_xfrm(self, node):
    """Restart slot numbering at 0 and transform the children of `node`."""
    self.current_carry = 0
    if CarryCounter().count(node) == 0:
      return node
    self.generic_visit(node)
    return node

  def _carry_call(self, rtn, operands):
    # Build rtn(operands..., carryvar, slot) and claim the next carry slot.
    c = mkCall(rtn, operands + [self.carryvar, ast.Num(self.current_carry)])
    self.current_carry += 1
    return c

  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    if is_Advance_Call(callnode):
      return self._carry_call("BitBlock_advance_%s" % self.mode, [callnode.args[0]])
    if is_ScanThru_Call(callnode):
      return self._carry_call("BitBlock_scanthru_%s" % self.mode,
                              [callnode.args[0], callnode.args[1]])
    return callnode

  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, ast.Sub):
      rtn = "BitBlock_sub_%s" % self.mode
    elif isinstance(exprnode.op, ast.Add):
      rtn = "BitBlock_add_%s" % self.mode
    else:
      return exprnode
    return self._carry_call(rtn, [exprnode.left, exprnode.right])

  def visit_If(self, ifNode):
    # Record the slot range [carry_base, carry_base + carries) used inside
    # the statement before its contents are rewritten.
    carry_base = self.current_carry
    carries = CarryCounter().count(ifNode)
    self.generic_visit(ifNode)
    if carries == 0 or self.mode == "co": return ifNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('CarryTest', carry_arglist)])
    new_else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    # NOTE(review): the rebuilt If keeps only the original body; an original
    # else-branch (ifNode.orelse) is discarded - confirm inputs never have one.
    return ast.If(new_test, ifNode.body, new_else_part)

  def visit_While(self, whileNode):
    carry_base = self.current_carry
    carries = CarryCounter().count(whileNode)
    if carries == 0: return whileNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    local_carryvar = 'sub'+self.carryvar.id
    # The loop proper is an independent copy rewritten in "co" mode against
    # its own carry variable; it must be copied before the original is
    # rewritten below.
    inner_while = CarryIntro(local_carryvar, 'co').generic_xfrm(copy.deepcopy(whileNode))
    self.generic_visit(whileNode)
    local_decl = mkCallStmt('CarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
    local_init = mkCallStmt('CarryInit', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
    final_combine = mkCallStmt('CarryCombine', [self.carryvar, ast.Name(local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
    # Each iteration of the inner loop: declare, init, <body>, combine.
    inner_while.body.insert(0, local_decl)
    inner_while.body.insert(1, local_init)
    inner_while.body.append(final_combine)
    if self.mode == "co":
      new_test = whileNode.test
    else:
      new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('CarryTest', carry_arglist)])
    else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    # Result: one guarded pass through the rewritten original body, followed
    # by the inner loop.
    return ast.If(new_test, whileNode.body + [inner_while], else_part)
276
class StreamStructGen(ast.NodeVisitor):
  """
  Generate a C struct declaration for each class definition in a module.

  Every name assigned in the class body becomes a BitBlock member; the
  class name is emitted after the closing brace.
  >>> obj = ast.parse(r'''
  ... class S1(BitStreamSet):
  ...   a1 = 0
  ...   a2 = 0
  ...   a3 = 0
  ...
  ... class S2(BitStreamSet):
  ...   x1 = 0
  ...   x2 = 0
  ... ''')
  >>> print(StreamStructGen().gen(obj))
  struct {
    BitBlock a1;
    BitBlock a2;
    BitBlock a3;
  } S1;
  <BLANKLINE>
  struct {
    BitBlock x1;
    BitBlock x2;
  } S2;
  <BLANKLINE>
  <BLANKLINE>
  """
  def __init__(self):
    pass

  def gen(self, tree):
    """Return the C text for all class definitions found below `tree`."""
    self.Ccode=""
    self.generic_visit(tree)
    return self.Ccode

  def visit_ClassDef(self, node):
    members = []
    for stmt in node.body:
      if not isinstance(stmt, ast.Assign):
        continue
      # Only simple-name targets produce members.
      for v in stmt.targets:
        if isinstance(v, ast.Name):
          members.append("  BitBlock " + v.id + ";\n")
    self.Ccode += "struct {\n" + "".join(members) + "} " + node.name + ";\n\n"
317 
class StreamFunctionDecl(ast.NodeVisitor):
  """Generate a C prototype for each function definition in a module."""
  def __init__(self):
    pass

  def gen(self, tree):
    """Return the C prototypes for all function definitions found below `tree`."""
    self.Ccode=""
    self.generic_visit(tree)
    return self.Ccode

  def visit_FunctionDef(self, node):
    # A prototype is the function header followed by a semicolon.
    self.Ccode += BitBlock_header_of_fn(node) + ";\n"
333
334
335
#  Translate a function
#
338
class FunctionXlator(ast.NodeVisitor):
  """Translate each function definition of a module into a C function."""
  def xlat(self, node):
    """Return the C text for all function definitions found below `node`."""
    self.Ccode=""
    self.generic_visit(node)
    return self.Ccode

  def visit_FunctionDef(self, fndef):
    # The transformations modify `fndef` in place.  Header and local
    # declarations are taken before carry introduction is applied.
    AugAssignRemoval().xfrm(fndef)
    Bitwise_to_SIMD().xfrm(fndef)
    header = BitBlock_header_of_fn(fndef)
    decls = BitBlock_decls_of_fn(fndef)
    CarryIntro().xfrm_fndef(fndef)
    stmts = Cgen.py2C().gen(fndef.body)
    self.Ccode += header + " {\n" + decls + stmts + "\n}\n"
352
def main(infilename, outfile = sys.stdout):
  """
  Compile the source file `infilename` and write the C code to `outfile`.

  The struct declarations for the classes are written first, followed by
  the translated functions.  `outfile` is any object with a write() method.
  """
  # Close the input promptly instead of leaving the handle to the collector.
  with open(infilename) as infile:
    source = infile.read()
  # Passing the file name makes syntax errors report the right file.
  t = ast.parse(source, infilename)
  outfile.write(StreamStructGen().gen(t))
  outfile.write(FunctionXlator().xlat(t))
357
#
#
#  Routines for compatibility with the old compiler/template.
#  Quick and dirty hacks for now - Dec. 2010.
#
363
class MainLoopTransformer:
  """
  Produce the C declarations, initializations and statements for a module
  whose last top-level statement is the main function definition.
  """
  def __init__(self, main_module, carry_var = "carryQ"):
    main_fn = main_module.body[-1]
    assert (isinstance(main_fn, ast.FunctionDef))
    self.main_module = main_module
    self.main_fn = main_fn
    # Counted once, on the function as given to the constructor.
    self.carry_count = CarryCounter().count(main_fn)
    self.carry_var = carry_var

  def any_carry_expr(self):
    """Return C expression text testing the carry queue."""
    # NOTE(review): with no carries this returns the constant "1" - confirm
    # that is what the consuming template expects.
    if self.carry_count != 0:
      return "CarryTest(%s, 0, %i)\n" % (self.carry_var, self.carry_count)
    return "1"

  def gen_declarations(self):
    """Build self.Cdecls: struct declarations, local BitBlocks, carry queue."""
    parts = [StreamStructGen().gen(self.main_module),
             BitBlock_decls_of_fn(self.main_fn)]
    if self.carry_count > 0:
      parts.append("CarryDeclare(%s, %i);\n" % (self.carry_var, self.carry_count))
    self.Cdecls = "".join(parts)

  def gen_initializations(self):
    """Build self.Cinits: carry queue initialization, then stream initializations."""
    carry_init = ""
    if self.carry_count > 0:
      carry_init = "CarryInit(%s, %i);\n" % (self.carry_var, self.carry_count)
    self.Cinits = carry_init + StreamInitializations().xfrm(self.main_module)

  def xfrm_block_stmts(self):
    """Apply the statement-level transformations to the main function in place."""
    for xfrm in (AugAssignRemoval().xfrm, Bitwise_to_SIMD().xfrm, CarryIntro().xfrm_fndef):
      xfrm(self.main_fn)

  def add_loop_carryQ_adjust(self):
    """Append a CarryQ_Adjust(carry_var, carry_count) call to the main function body."""
    adjust = mkCallStmt('CarryQ_Adjust', [ast.Name(self.carry_var, ast.Load()), ast.Num(self.carry_count)])
    self.main_fn.body.append(adjust)

  def getCdecls(self):
    return self.Cdecls

  def getCinits(self):
    return self.Cinits

  def getCstmts(self):
    """Return the C code for the current body of the main function."""
    return Cgen.py2C().gen(self.main_fn.body)
394
if __name__ == "__main__":
  # Run the doctests embedded in this module's docstrings.
  import doctest
  doctest.testmod()
398
399
Note: See TracBrowser for help on using the repository browser.