source: proto/Compiler/pablo.py @ 765

Last change on this file since 765 was 765, checked in by cameron, 8 years ago

Fix visit_BinOp method name.

File size: 11.9 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, Robert D. Cameron
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11
12# HELPER functions for AST recognition/construction
13
def is_Advance_Call(fncall):   # extracted from old py2bitexpr.py
  """
  Return True if the ast.Call node fncall is `Advance(x)` or
  `bitutil.Advance(x)`: exactly one positional argument and no
  *args / **kwargs.

  Any other callee shape (e.g. `a.b.Advance(x)` or `f()(x)`) gives False.
  """
  func = fncall.func
  if isinstance(func, ast.Name):
    iscall = func.id == 'Advance'
  elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
    iscall = func.value.id == 'bitutil' and func.attr == 'Advance'
  else:
    # Callee is neither a bare name nor `module.name`; without this branch
    # `iscall` would be unbound and the return below would raise.
    return False
  # Python 2 ast stores *args / **kwargs in the starargs / kwargs fields;
  # newer ast versions have no such fields, so read them defensively.
  has_varargs = (getattr(fncall, 'starargs', None) is not None or
                 getattr(fncall, 'kwargs', None) is not None)
  return iscall and len(fncall.args) == 1 and not has_varargs
19
def is_ScanThru_Call(fncall):  # extracted from old py2bitexpr.py
  """
  Return True if the ast.Call node fncall is `ScanThru(m, c)` or
  `bitutil.ScanThru(m, c)`: exactly two positional arguments and no
  *args / **kwargs.

  Any other callee shape (e.g. `a.b.ScanThru(m, c)`) gives False.
  """
  func = fncall.func
  if isinstance(func, ast.Name):
    iscall = func.id == 'ScanThru'
  elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
    iscall = func.value.id == 'bitutil' and func.attr == 'ScanThru'
  else:
    # Callee is neither a bare name nor `module.name`; without this branch
    # `iscall` would be unbound and the return below would raise.
    return False
  # Python 2 ast stores *args / **kwargs in the starargs / kwargs fields;
  # newer ast versions have no such fields, so read them defensively.
  has_varargs = (getattr(fncall, 'starargs', None) is not None or
                 getattr(fncall, 'kwargs', None) is not None)
  return iscall and len(fncall.args) == 2 and not has_varargs
25
def mkQname(obj, field):
  """Build the load-context AST for the qualified name `obj.field`."""
  base = ast.Name(id=obj, ctx=ast.Load())
  return ast.Attribute(value=base, attr=field, ctx=ast.Load())
28
def mkCall(fn_name, args):
  """
  Build an ast.Call of fn_name applied to the positional argument nodes args.

  fn_name may be a string (wrapped as a load-context ast.Name) or an existing
  AST expression. The call has no keyword, *args or **kwargs arguments.
  """
  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
  fields = dict(func=fn_name, args=args, keywords=[])
  # Python 2 (and 3.0-3.4) ast.Call also has starargs/kwargs fields, which
  # the rest of this module reads; newer ast versions dropped them and
  # reject the old five-argument positional constructor.
  if 'starargs' in ast.Call._fields:
    fields.update(starargs=None, kwargs=None)
  return ast.Call(**fields)
32
def mkCallStmt(fn_name, args):
  """
  Build an expression statement that calls fn_name on args.

  Takes the same arguments as mkCall and delegates to it, so the ast.Call
  construction lives in one place.
  """
  return ast.Expr(mkCall(fn_name, args))
36
37#
38# Introducing BitBlock logical operations
39#
class Bitwise_to_SIMD(ast.NodeTransformer):
  """
  Rewrite bitwise stream operations into BitBlock SIMD primitive calls.

  Substitutions made:
     x & y     => simd_and(x, y)
     x | y     => simd_or(x, y)
     x ^ y     => simd_xor(x, y)
     ~x        => simd_not(x)
     0         => simd_const_1(0)
     -1        => simd_const_1(1)
     if x:     => if bitblock_has_bit(x):
     while x:  => while bitblock_has_bit(x):

  Subscript expressions (both the subscripted value and the index) are left
  untouched.

  Example: the statements
      pfx = bit0 & bit1; sfx = bit0 &~ bit1
  are rewritten to the equivalent of
      pfx = simd_and(bit0, bit1)
      sfx = simd_and(bit0, simd_not(bit1))

  NOTE(review): the -1 case matches a numeric node whose value is -1, which
  assumes the parser folds the literal `-1` into one node (as Python 2's
  does); Python 3 parses it as UnaryOp(USub, 1) instead.
  """
  def xfrm(self, t):
    """Transform the children of t in place and return t."""
    return self.generic_visit(t)

  def visit_UnaryOp(self, t):
    self.generic_visit(t)
    if isinstance(t.op, ast.Invert):
      return mkCall('simd_not', [t.operand])
    return t

  def visit_BinOp(self, t):
    self.generic_visit(t)
    if isinstance(t.op, ast.BitOr): fn = 'simd_or'
    elif isinstance(t.op, ast.BitAnd): fn = 'simd_and'
    elif isinstance(t.op, ast.BitXor): fn = 'simd_xor'
    else: return t
    return mkCall(fn, [t.left, t.right])

  def visit_Num(self, numnode):
    n = numnode.n
    if n == 0: return mkCall('simd_const_1', [numnode])
    if n == -1: return mkCall('simd_const_1', [ast.Num(1)])
    return numnode

  def visit_If(self, ifNode):
    self.generic_visit(ifNode)
    ifNode.test = mkCall('bitblock_has_bit', [ifNode.test])
    return ifNode

  def visit_While(self, whileNode):
    self.generic_visit(whileNode)
    whileNode.test = mkCall('bitblock_has_bit', [whileNode.test])
    return whileNode

  def visit_Subscript(self, node):
    # No recursive modification of index expressions (or of the value).
    return node
88
89
90#
91#  Generating BitBlock declarations for Local Variables
92#
93#
class LocalVars(ast.NodeVisitor):
  """
  Collect the local variable names assigned within a function.

  get(node) walks the children of node and returns, in first-assignment
  order and without duplicates, every name seen in ast.Store context,
  excluding any name that also appears in ast.Param context.
  """
  def visit_Name(self, nm):
    name, ctx = nm.id, nm.ctx
    if isinstance(ctx, ast.Param):
      self.params.append(name)
    if isinstance(ctx, ast.Store) and name not in self.stores:
      self.stores.append(name)

  def get(self, node):
    self.params, self.stores = [], []
    self.generic_visit(node)
    excluded = set(self.params)
    return [name for name in self.stores if name not in excluded]
105
MAX_LINE_LENGTH = 80   # target width of generated BitBlock declaration lines

def BitBlock_decls_from_vars(varlist):
  """
  Return C declarations that declare every name in varlist as a BitBlock.

  Names are packed into comma-separated "  BitBlock a, b, c;" lines. A new
  declaration line starts whenever appending ", name" would take the current
  line past MAX_LINE_LENGTH characters (the terminating ';' is not counted).
  A name too long to share a line still gets a line of its own. An empty
  varlist yields "" rather than the empty declaration "  BitBlock;".
  """
  lines = []
  line = None
  for v in varlist:
    if line is None:
      # The first name always opens the first declaration, whatever its length.
      line = "  BitBlock " + v
    elif len(line) + len(", " + v) <= MAX_LINE_LENGTH:
      line += ", " + v
    else:
      lines.append(line + ";\n")
      line = "  BitBlock " + v
  if line is not None:
    lines.append(line + ";\n")
  return "".join(lines)
123
def BitBlock_decls_of_fn(fndef):
  """Return BitBlock declarations for the non-parameter names fndef assigns."""
  assigned_names = LocalVars().get(fndef)
  return BitBlock_decls_from_vars(assigned_names)
126
def BitBlock_header_of_fn(fndef):
  """
  Return the C signature "static inline void name(Type & param, ...)".

  Each ast.Name parameter becomes a reference parameter whose type name is
  the parameter name with its first character upper-cased
  (basis_bits -> Basis_bits & basis_bits). Other parameter nodes are skipped.
  """
  params = [a.id.upper()[0] + a.id[1:] + " & " + a.id
            for a in fndef.args.args if isinstance(a, ast.Name)]
  return "static inline void " + fndef.name + "(" + ", ".join(params) + ")"
136
137
138
139#
140#  Stream Initialization Statement Extraction
141#
142#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
class StreamInitializations(ast.NodeTransformer):
  """
  Split nonzero stream-literal initializations out of a stream function.

  Each assignment `v = k`, where k is a numeric literal other than 0 or -1,
  is removed from the function body (at any nesting depth). Instead:
    * `v = sisd_from_int(k)` is collected as a one-time stream initialization
      and returned by xfrm as C code, and
    * `v = 0` is appended to the end of the enclosing function's body,
      presumably so every block after the first starts from the literal's
      remaining zero bits.
  Assignments of 0 and -1 are left in place.
  """
  def xfrm(self, node):
    """
    Extract the stream initializations from node (modified in place) and
    return them as C code.

    visit() is used rather than generic_visit() so that when node is itself
    a FunctionDef (as passed by stream_stmts), visit_FunctionDef runs on it
    and the `v = 0` resets are appended to its body instead of discarded.
    """
    self.stream_stmts = []
    self.loop_post_inits = []
    self.visit(node)
    return Cgen.py2C().gen(self.stream_stmts)

  def visit_Assign(self, node):
    if not isinstance(node.value, ast.Num) or node.value.n in (0, -1):
      return node
    stream_init = copy.deepcopy(node)
    stream_init.value = mkCall('sisd_from_int', [node.value])
    loop_init = copy.deepcopy(node)
    loop_init.value.n = 0
    self.stream_stmts.append(stream_init)
    self.loop_post_inits.append(loop_init)
    return None   # drop the original assignment from the body

  def visit_FunctionDef(self, node):
    # NOTE(review): every reset collected so far is appended, including
    # those from previously visited functions; confirm xfrm is only ever
    # given a single stream function.
    self.generic_visit(node)
    node.body = node.body + self.loop_post_inits
    return node
166
167
168#
169# Carry Introduction Transformation
170#
171
class CarryCounter(ast.NodeVisitor):
  """
  Count the carry-generating operations below an AST node.

  The operations counted are the ones CarryIntro rewrites and assigns carry
  slots to: Advance and ScanThru calls, and binary + and - expressions.
  count(node) inspects node's children, not node itself.
  """
  def visit_Call(self, callnode):
    self.generic_visit(callnode)
    if is_Advance_Call(callnode) or is_ScanThru_Call(callnode):
      self.carry_count += 1

  def visit_BinOp(self, exprnode):
    # NodeVisitor dispatches on the exact AST class name, so this must be
    # spelled visit_BinOp for subtraction/addition carries to be counted.
    self.generic_visit(exprnode)
    if isinstance(exprnode.op, (ast.Add, ast.Sub)):
      self.carry_count += 1

  def count(self, nodeToVisit):
    self.carry_count = 0
    self.generic_visit(nodeToVisit)
    return self.carry_count
185
186
def Carry_decl_of_fn(fndef):
  """
  Return a CarryDeclare statement sizing carryQ for fndef's carry operations,
  or "" when fndef has none.
  """
  n = CarryCounter().count(fndef)
  return "  CarryDeclare(carryQ, %i);\n" % n if n else ""
191
192                   
class CarryIntro(ast.NodeTransformer):
  """
  Introduce explicit carry-queue arguments into carry-generating operations.

  Each Advance(x), ScanThru(x, y), x - y and x + y is rewritten to
  BitBlock_<op>_<mode>(operands..., carryvar, i), where i is a carry slot
  index handed out sequentially from current_carry in visiting order.
  In this class, mode (default "ci_co") is used only as that name suffix
  and to select the "co" behaviour below.

  A carry-bearing if statement (outside "co" mode) becomes
      if (test or CarryTest(carryvar, base, n)): body
      else: CarryDequeueEnqueue(carryvar, base, n)
  where base and n give the slot range of the carries inside it.
  A carry-bearing while loop becomes an if of the same shape (in "co" mode
  its test is the loop test alone). The if body is the transformed loop body
  followed by an inner while loop. That inner loop is transformed in "co"
  mode against a local queue named 'sub' + carryvar: CarryDeclare is
  inserted at the top of its body and CarryCombine with the outer queue at
  its end.
  """
  def __init__(self, carryvar="carryQ", mode = "ci_co"):
    self.carryvar = ast.Name(carryvar, ast.Load())
    self.mode = mode

  def any_carry_expr(self, fndef):
    """Return CarryTest(carryvar, 0, n) for the n carries counted in fndef."""
    carry_count = CarryCounter().count(fndef)
    return mkCall('CarryTest', [self.carryvar, ast.Num(0), ast.Num(carry_count)])

  def generic_xfrm(self, node):
    """
    Rewrite the carry operations in node's children in place, numbering
    carry slots from 0. Return node, untouched if it has no carries.
    """
    self.current_carry = 0
    if CarryCounter().count(node) == 0: return node
    self.generic_visit(node)
    return node

  def xfrm_fndef(self, fndef):
    """Rewrite the carry operations of a function definition (see generic_xfrm)."""
    return self.generic_xfrm(fndef)

  def _carry_op(self, op_name, operands):
    """
    Return BitBlock_<op_name>_<mode>(operands..., carryvar, slot) using the
    next carry slot, and advance current_carry past it.
    """
    call = mkCall("BitBlock_%s_%s" % (op_name, self.mode),
                  operands + [self.carryvar, ast.Num(self.current_carry)])
    self.current_carry += 1
    return call

  def visit_Call(self, callnode):
    self.generic_visit(callnode)   # nested operations take earlier slots
    if is_Advance_Call(callnode):
      return self._carry_op('advance', [callnode.args[0]])
    elif is_ScanThru_Call(callnode):
      return self._carry_op('scanthru', [callnode.args[0], callnode.args[1]])
    return callnode

  def visit_BinOp(self, exprnode):
    self.generic_visit(exprnode)   # nested operations take earlier slots
    if isinstance(exprnode.op, ast.Sub):
      return self._carry_op('sub', [exprnode.left, exprnode.right])
    elif isinstance(exprnode.op, ast.Add):
      return self._carry_op('add', [exprnode.left, exprnode.right])
    return exprnode

  def visit_If(self, ifNode):
    carry_base = self.current_carry
    carries = CarryCounter().count(ifNode)
    self.generic_visit(ifNode)
    if carries == 0 or self.mode == "co": return ifNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('CarryTest', carry_arglist)])
    new_else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    # NOTE(review): ifNode.orelse is discarded here even though its carries
    # were counted and rewritten above; confirm carry-bearing ifs never have
    # an else branch in Pablo sources. (visit_While likewise drops a
    # while/else clause from the outer if.)
    return ast.If(new_test, ifNode.body, new_else_part)

  def visit_While(self, whileNode):
    carry_base = self.current_carry
    carries = CarryCounter().count(whileNode)
    if carries == 0: return whileNode
    carry_arglist = [self.carryvar, ast.Num(carry_base), ast.Num(carries)]
    local_carryvar = 'sub'+self.carryvar.id
    # Transform the inner loop from an untouched copy, against its own carry
    # queue in "co" mode. This must happen before generic_visit below
    # rewrites whileNode in place (the rewritten calls would not be
    # recognized as Advance/ScanThru any more).
    inner_while = CarryIntro(local_carryvar, 'co').generic_xfrm(copy.deepcopy(whileNode))
    self.generic_visit(whileNode)
    local_decl = mkCallStmt('CarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
    inner_while.body.insert(0, local_decl)
    final_combine = mkCallStmt('CarryCombine', [self.carryvar, ast.Name(local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
    inner_while.body.append(final_combine)
    if self.mode == "co": new_test = whileNode.test
    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('CarryTest', carry_arglist)])
    else_part = [mkCallStmt('CarryDequeueEnqueue', carry_arglist)]
    # The first iteration runs inline under the if; any further iterations
    # run in the inner while loop.
    return ast.If(new_test, whileNode.body + [inner_while], else_part)
266
class StreamStructGen(ast.NodeVisitor):
  """
  Given a module of BitStreamSet subclasses, generate the equivalent C structs.

  Each class definition visited becomes an anonymous C struct with one
  BitBlock member per simple name assigned in the class body. The struct
  declarator is a variable named after the class. The base class is not
  checked, and classes nested inside another class are not visited.

  >>> obj = ast.parse(r'''
  ... class S1(BitStreamSet):
  ...   a1 = 0
  ...   a2 = 0
  ...   a3 = 0
  ...
  ... class S2(BitStreamSet):
  ...   x1 = 0
  ...   x2 = 0
  ... ''')
  >>> print(StreamStructGen().gen(obj))
  struct {
    BitBlock a1;
    BitBlock a2;
    BitBlock a3;
  } S1;
  <BLANKLINE>
  struct {
    BitBlock x1;
    BitBlock x2;
  } S2;
  <BLANKLINE>
  <BLANKLINE>
  """
  def gen(self, tree):
    """Return the C struct declarations for the classes found in tree."""
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode

  def visit_ClassDef(self, node):
    members = [target.id
               for stmt in node.body if isinstance(stmt, ast.Assign)
               for target in stmt.targets if isinstance(target, ast.Name)]
    # Anonymous struct type; the declarator names a variable after the class.
    self.Ccode += "struct {\n"
    self.Ccode += "".join("  BitBlock " + m + ";\n" for m in members)
    self.Ccode += "} " + node.name + ";\n\n"
307 
class StreamFunctionDecl(ast.NodeVisitor):
  """
  Generate C prototypes for the function definitions in a tree.

  gen(tree) returns one "static inline void name(...);" line per function
  definition visited (functions nested inside another function are not
  visited). It uses BitBlock_header_of_fn, so prototypes always match the
  generated definitions.
  """
  def gen(self, tree):
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode

  def visit_FunctionDef(self, node):
    self.Ccode += BitBlock_header_of_fn(node) + ";\n"
323
324
325
326#  Translate a function
327#
328
class FunctionXlator(ast.NodeVisitor):
  """
  Translate each stream function in a tree into a C function definition.

  For every function definition visited (functions nested inside another
  function are not visited), the function AST is rewritten in place by
  Bitwise_to_SIMD and then CarryIntro. The emitted C is the header, the
  BitBlock declarations of its locals, and the translated body.
  """
  def xlat(self, node):
    self.Ccode = ""
    self.generic_visit(node)
    return self.Ccode

  def visit_FunctionDef(self, fndef):
    Bitwise_to_SIMD().xfrm(fndef)
    header = BitBlock_header_of_fn(fndef) + " {\n"
    decls = BitBlock_decls_of_fn(fndef)
    CarryIntro().xfrm_fndef(fndef)
    body = Cgen.py2C().gen(fndef.body)
    # NOTE(review): no CarryDeclare is emitted here (gen_declarations adds
    # Carry_decl_of_fn for the old template); confirm carryQ is declared
    # elsewhere for functions produced by this path.
    self.Ccode += header + decls + body + "\n}\n"
341
def main(infilename, outfile = None):
  """
  Compile the Pablo source file infilename and write the generated C to
  outfile: first the stream structs, then the translated functions.

  outfile defaults to whatever sys.stdout is at call time (not at import
  time), so redirections of sys.stdout are honoured.
  """
  if outfile is None: outfile = sys.stdout
  # Read via a context manager so the input file handle is closed promptly.
  with open(infilename) as src:
    t = ast.parse(src.read())
  outfile.write(StreamStructGen().gen(t))
  outfile.write(FunctionXlator().xlat(t))
346
347#
348#
349#  Routines for compatibility with the old compiler/template.
350#  Quick and dirty hacks for now - Dec. 2010.
351#
def gen_declarations(main_module, main_fn):
  """
  Return the declarations the old template expects: the stream structs of
  main_module, then main_fn's BitBlock locals and its carry-queue declaration.
  """
  return "".join([StreamStructGen().gen(main_module),
                  BitBlock_decls_of_fn(main_fn),
                  Carry_decl_of_fn(main_fn)])
357
def stream_stmts(main_fn):
  """Return C code for main_fn's stream initializations; main_fn is modified."""
  extractor = StreamInitializations()
  return extractor.xfrm(main_fn)
360
def block_stmts(main_fn):
  """
  Rewrite main_fn in place (SIMD bitwise operations, then carries) and
  return its body translated to C.
  """
  simd_pass = Bitwise_to_SIMD()
  simd_pass.xfrm(main_fn)
  carry_pass = CarryIntro()
  carry_pass.xfrm_fndef(main_fn)
  return Cgen.py2C().gen(main_fn.body)
365
def any_carry_expr(main_fn):
  """Return C code for CarryTest(carryQ, 0, n) over all n carries in main_fn."""
  carry_test = CarryIntro().any_carry_expr(main_fn)
  return Cgen.py2C().gen(carry_test)
368
if __name__ == "__main__":
  # Run this module's docstring examples when executed as a script.
  from doctest import testmod
  testmod()
372
373
Note: See TracBrowser for help on using the repository browser.