source: proto/Compiler/pablo.py @ 2049

Last change on this file since 2049 was 2049, checked in by cameron, 7 years ago

Use library implementations for inclusive/exclusive span

File size: 34.7 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11name_substitution_map = {}
12do_block_inline_decorator = 'IDISA_INLINE '
13do_final_block_inline_decorator = ''
14
15def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
16        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == builtin_fnname
17        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
18                 iscall = fncall.func.value.id == builtin_fnmod_noprefix and fncall.func.attr == builtin_fnname
19        return iscall and len(fncall.args) == builtin_arg_cnt
20
21def dump_Call(fncall):
22        if isinstance(fncall.func, ast.Name): print "fn_name = %s\n" % fncall.func.id
23        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
24                print "fn_name = %s.%s\n" % (fncall.func.value.id, fncall.func.attr)
25        print "len(fncall.args) = %s\n" % len(fncall.args)
26
27def is_simd_not(e):
28  return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
29
30def mkQname(obj, field):
31  return ast.Attribute(ast.Name(obj, ast.Load()), field, ast.Load())
32
33def mkCall(fn_name, args):
34  if isinstance(fn_name, str): 
35        if name_substitution_map.has_key(fn_name): fn_name = name_substitution_map[fn_name]
36        fn_name = ast.Name(fn_name, ast.Load())
37  return ast.Call(fn_name, args, [], None, None)
38
39def mkCallStmt(fn_name, args):
40  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
41  return ast.Expr(ast.Call(fn_name, args, [], None, None))
42
43#
44# Reducing AugAssign, e.g.  x |= y becomes x = x | y
45#
46class AugAssignRemoval(ast.NodeTransformer):
47  def xfrm(self, t):
48    return self.generic_visit(t)
49  def visit_AugAssign(self, e):
50    self.generic_visit(e)
51    return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
52
53
54#
55#  Translating pablo.match(marker, str)
56#  Incremental character match with lookahead
57#
58CharNameMap = {'[' : 'LBrak', ']' : 'RBrak', '{' : 'LBrace', '}' : 'LBrace', '(' : 'LParen', ')' : 'RParen', \
59              '!' : 'Exclam', '"' : 'DQuote', '#' : 'Hash', '$' : 'Dollar', '%' : 'PerCent', '&': 'RefStart', \
60              "'" : 'SQuote', '*': 'Star', '+' : 'Plus', ',' : 'Comma', '-' : 'Hyphen', '.' : 'Dot', '/' : 'Slash', \
61              ':' : 'Colon', ';' : 'Semicolon', '=' : 'Equals', '?' : 'QMark', '@' : 'AtSign', '\\' : 'BackSlash', \
62              '^' : 'Caret', '_' : 'Underscore', '|' : 'VBar', '~' : 'Tilde', ' ' : 'SP', '\t' : 'HT', '\m' : 'CR', '\n' : 'LF'}
63
64def GetCharName(char):
65        if char >= 'a' and char <= 'z' or char >= 'A' and char <= 'Z': return 'letter_' + char
66        elif char >= '0' and char <= '9': return 'digit_' + char
67        else: return CharNameMap[char]
68
69def MkCharStream(char):
70        return mkQname('lex', GetCharName(char))
71
72def MkLookAheadExpr(v, i):
73        return mkCall(mkQname('pablo', 'LookAhead'), [v, ast.Num(i)])
74
75def CompileMatch(match_var, string_to_match):
76        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
77        for i in range(1, len(string_to_match)):
78                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
79        return expr
80
81class StringMatchCompiler(ast.NodeTransformer):
82  def xfrm(self, t):
83    return self.generic_visit(t)
84  def visit_Call(self, callnode):
85    if is_BuiltIn_Call(callnode,'match', 2):
86        ast.dump(callnode)
87        assert isinstance(callnode.args[0], ast.Str)
88        string_to_match = callnode.args[0].s
89        match_var = callnode.args[1]
90        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
91        for i in range(1, len(string_to_match)):
92                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
93        return expr
94    else: return callnode
95
96
97
98#
99# Introducing BitBlock logical operations
100#
101class Bitwise_to_SIMD(ast.NodeTransformer):
102  """
103  Make the following substitutions:
104     x & y => simd_and(x, y)
105     x & ~y => simd_andc(x, y)
106     x | y => simd_or(x, y)
107     x ^ y => simd_xor(x, y)
108     ~x    => simd_not(x)
109     0     => simd_const_1(0)
110     -1    => simd_const_1(1)
111     if x: => if bitblock::any(x):
112  while x: => while bitblock::any(x):
113  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1")))
114 
115  pfx = simd_and(bit0, bit1)
116  sfx = simd_and(bit0, simd_not(bit1))
117  >>>
118  """
119  def xfrm(self, t):
120    return self.generic_visit(t)
121  def visit_UnaryOp(self, t):
122    self.generic_visit(t)
123    if isinstance(t.op, ast.Invert):
124      return mkCall('simd_not', [t.operand])
125    else: return t
126  def visit_BinOp(self, t):
127    self.generic_visit(t)
128    if isinstance(t.op, ast.BitOr):
129      return mkCall('simd_or', [t.left, t.right])
130    elif isinstance(t.op, ast.BitAnd):
131      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
132      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
133      else: return mkCall('simd_and', [t.left, t.right])
134    elif isinstance(t.op, ast.BitXor):
135      return mkCall('simd_xor', [t.left, t.right])
136    else: return t
137  def visit_Num(self, numnode):
138    n = numnode.n
139    if n == 0: return mkCall('simd<1>::constant<0>', [])
140    elif n == -1: return mkCall('simd<1>::constant<1>', [])
141    else: return numnode
142  def visit_If(self, ifNode):
143    self.generic_visit(ifNode)
144    ifNode.test = mkCall('bitblock::any', [ifNode.test])
145    return ifNode
146  def visit_While(self, whileNode):
147    self.generic_visit(whileNode)
148    whileNode.test = mkCall('bitblock::any', [whileNode.test])
149    return whileNode
150  def visit_Subscript(self, numnode):
151    return numnode  # no recursive modifications of index expressions
152
153#
154#  Generating BitBlock declarations for Local Variables
155#
156class FunctionVars(ast.NodeVisitor):
157  def __init__(self,node):
158        self.params = []
159        self.stores = []
160        self.generic_visit(node)
161  def visit_Name(self, nm):
162    if isinstance(nm.ctx, ast.Param):
163      self.params.append(nm.id)
164    if isinstance(nm.ctx, ast.Store):
165      if nm.id not in self.stores: self.stores.append(nm.id)
166  def getLocals(self): 
167    return [v for v in self.stores if not v in self.params]
168
169MAX_LINE_LENGTH = 80
170
171def BitBlock_decls_from_vars(varlist):
172  global MAX_LINE_LENGTH
173  decls =  ""
174  if not len(varlist) == 0:
175          decls = "             BitBlock"
176          pending = ""
177          linelgth = 10
178          for v in varlist:
179            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
180              decls += pending + " " + v
181              linelgth += len(pending + v) + 1
182            else:
183              decls += ";\n             BitBlock " + v
184              linelgth = 11 + len(v)
185            pending = ","
186          decls += ";"
187  return decls
188
189def BitBlock_decls_of_fn(fndef):
190  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
191
192def BitBlock_header_of_fn(fndef):
193  Ccode = "static inline void " + fndef.name + "("
194  pending = ""
195  for arg in fndef.args.args:
196    if isinstance(arg, ast.Name):
197      Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
198      pending = ", "
199  if CarryCounter().count(fndef) > 0:
200    Ccode += pending + " CarryQtype & carryQ"
201  Ccode += ")"
202  return Ccode
203
204
205
206#
207#  Stream Initialization Statement Extraction
208#
209#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
210class StreamInitializations(ast.NodeTransformer):
211  def xfrm(self, node):
212    self.stream_stmts = []
213    self.loop_post_inits = []
214    self.generic_visit(node)
215    return Cgen.py2C().gen(self.stream_stmts)
216  def visit_Assign(self, node):
217    if isinstance(node.value, ast.Num):
218      if node.value.n == 0: return node
219      elif node.value.n == -1: return node
220      else: 
221        stream_init = copy.deepcopy(node)
222        stream_init.value = mkCall('sisd_from_int', [node.value])
223        loop_init = copy.deepcopy(node)
224        loop_init.value.n = 0
225        self.stream_stmts.append(stream_init)
226        self.loop_post_inits.append(loop_init)
227        return None
228    else: return node
229  def visit_FunctionDef(self, node):
230    self.generic_visit(node)
231    node.body = node.body + self.loop_post_inits
232    return node
233
234
235
236#
237# Carry Introduction Transformation
238#
239class CarryCounter(ast.NodeVisitor):
240  def visit_Call(self, callnode):
241    self.generic_visit(callnode)
242    if is_BuiltIn_Call(callnode,'Advance', 1) or is_BuiltIn_Call(callnode,'ScanThru', 2) or is_BuiltIn_Call(callnode,'ScanTo', 2) or is_BuiltIn_Call(callnode,'ScanToFirst', 1) or is_BuiltIn_Call(callnode,'AdvanceThenScanThru', 2) or is_BuiltIn_Call(callnode,'AdvanceThenScanTo', 2) or is_BuiltIn_Call(callnode,'SpanUpTo', 2) or  is_BuiltIn_Call(callnode,'InclusiveSpan', 2) or is_BuiltIn_Call(callnode,'ExclusiveSpan', 2):         
243      self.carry_count += 1
244  def visit_BinOp(self, exprnode):
245    self.generic_visit(exprnode)
246    if isinstance(exprnode.op, ast.Sub):
247      self.carry_count += 1
248    if isinstance(exprnode.op, ast.Add):
249      self.carry_count += 1
250  def count(self, nodeToVisit):
251    self.carry_count = 0
252    self.generic_visit(nodeToVisit)
253    return self.carry_count
254
255class Adv32Counter(ast.NodeVisitor):
256  def visit_Call(self, callnode):
257    self.generic_visit(callnode)
258    if is_BuiltIn_Call(callnode,'Advance32', 1):       
259      self.adv32_count += 1
260  def count(self, nodeToVisit):
261    self.adv32_count = 0
262    self.generic_visit(nodeToVisit)
263    return self.adv32_count
264
265class CarryIntro(ast.NodeTransformer):
266  def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
267    self.carryvar = ast.Name(carryvar, ast.Load())
268    self.carryin = carryin
269    self.carryout = carryout
270  def xfrm_fndef(self, fndef):
271    self.current_carry = 0
272    self.current_adv32 = 0
273    carry_count = CarryCounter().count(fndef)
274    if carry_count == 0: return fndef
275    self.generic_visit(fndef)
276#   
277#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
278    return fndef
279  def generic_xfrm(self, node):
280    self.current_carry = 0
281    self.current_adv32 = 0
282    carry_count = CarryCounter().count(node)
283    adv32_count = Adv32Counter().count(node)
284    if carry_count == 0 and adv32_count == 0: return node
285    self.generic_visit(node)
286    return node
287  def visit_Call(self, callnode):
288    self.generic_visit(callnode)
289    #CARRYSET
290    #carry_args = [ast.Num(self.current_carry)]
291    #adv32_args = [ast.Subscript(ast.Name('pending32', ast.Load()), ast.Num(self.current_adv32), ast.Load())]
292    adv32_pending = ast.Subscript(ast.Name('pending32', ast.Load()), ast.Num(self.current_adv32), ast.Load())
293    if self.carryin == "_ci":
294        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
295        adv32_args = [adv32_pending, adv32_pending]
296    else: 
297        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
298        adv32_args = [ast.Num(0), adv32_pending]
299
300    if is_BuiltIn_Call(callnode, 'Advance', 1):         
301      #CARRYSET
302      rtn = self.carryvar.id + "." + "BitBlock_advance_ci_co"
303      c = mkCall(rtn, callnode.args + carry_args)
304      self.current_carry += 1
305      return c
306    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
307      #CARRYSET
308      rtn = self.carryvar.id + "." + "BitBlock_advance32_ci_co"
309      c = mkCall(rtn, callnode.args + adv32_args)
310      self.current_adv32 += 1
311      return c
312    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
313      #CARRYSET
314      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
315      c = mkCall(rtn, callnode.args + carry_args)
316      self.current_carry += 1
317      return c
318    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanThru', 2):
319      #CARRYSET
320      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
321      c = mkCall(rtn, callnode.args + carry_args)
322      self.current_carry += 1
323      return c
324    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanTo', 2):
325      #CARRYSET
326      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru" 
327      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
328      else: scanclass = mkCall('simd_not', [callnode.args[1]])
329      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
330      self.current_carry += 1
331      return c
332    elif is_BuiltIn_Call(callnode, 'SpanUpTo', 2):
333      #CARRYSET
334      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
335      c = mkCall(rtn, callnode.args + carry_args)
336      self.current_carry += 1
337      return c
338    elif is_BuiltIn_Call(callnode, 'InclusiveSpan', 2):
339      #CARRYSET
340#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
341#      c = mkCall('simd_or', [mkCall(rtn, callnode.args + carry_args), callnode.args[1]])
342      rtn = self.carryvar.id + "." + "BitBlock_inclusive_span"
343      c = mkCall(rtn, callnode.args + carry_args)
344      self.current_carry += 1
345      return c
346    elif is_BuiltIn_Call(callnode, 'ExclusiveSpan', 2):
347      #CARRYSET
348#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
349#      c = mkCall('simd_andc', [mkCall(rtn, callnode.args + carry_args), callnode.args[0]])
350      rtn = self.carryvar.id + "." + "BitBlock_exclusive_span"
351      c = mkCall(rtn, callnode.args + carry_args)
352      self.current_carry += 1
353      return c
354    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
355      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
356      # in having a separate BitBlock_scanto routine.
357      #CARRYSET
358      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co" 
359      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
360      else: scanclass = mkCall('simd_not', [callnode.args[1]])
361      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
362      self.current_carry += 1
363      return c
364    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
365      #CARRYSET
366      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
367      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
368      c = mkCall(rtn, callnode.args + carry_args)
369      self.current_carry += 1
370      return c
371    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
372      if self.carryout != "": 
373        # Non final block: atEOF(x) = 0.
374        return mkCall('simd<1>::constant<0>', [])
375      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
376    elif is_BuiltIn_Call(callnode, 'inFile', 1):
377      if self.carryout != "": 
378        # Non final block: inFile(x) = x.
379        return callnode.args[0]
380      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
381    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
382      rtn = "StreamScan"           
383      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
384                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
385                                           ast.Name(callnode.args[1].id, ast.Load())])
386      return c
387    else:
388      #dump_Call(callnode)
389      return callnode
390  def visit_BinOp(self, exprnode):
391    self.generic_visit(exprnode)
392    carry_args = [ast.Num(self.current_carry)]
393    if self.carryin == "_ci":
394        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
395    else: 
396        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
397    if isinstance(exprnode.op, ast.Sub):
398      #CARRYSET
399      rtn = self.carryvar.id + "." + "BitBlock_sub_ci_co"
400      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
401      self.current_carry += 1
402      return c
403    elif isinstance(exprnode.op, ast.Add):
404      #CARRYSET
405      rtn = self.carryvar.id + "." + "BitBlock_add_ci_co"
406      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
407      self.current_carry += 1
408      return c
409    else: return exprnode
410  def visit_If(self, ifNode):
411    carry_base = self.current_carry
412    carries = CarryCounter().count(ifNode)
413    assert Adv32Counter().count(ifNode) == 0, "Advance32() within if: illegal\n"
414    self.generic_visit(ifNode)
415    if carries == 0 or self.carryin == "": return ifNode
416    #CARRYSET
417    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
418    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
419    new_else_part = ifNode.orelse + [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]
420    return ast.If(new_test, ifNode.body, new_else_part)
421  def visit_While(self, whileNode):
422    if self.carryout == '':
423      whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
424    carry_base = self.current_carry
425    assert Adv32Counter().count(whileNode) == 0, "Advance32() within while: illegal\n"
426    carries = CarryCounter().count(whileNode)
427    #CARRYSET
428    if carries == 0: return whileNode
429    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
430    local_carryvar = 'subcarryQ'
431    inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
432    self.generic_visit(whileNode)
433    local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
434    inner_while.body.insert(0, local_carry_decl)
435    final_combine = mkCallStmt('carryQ.CarryCombine', [ast.Name( '(ICarryQueue *) &' + local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
436    inner_while.body.append(final_combine)
437    #CARRYSET
438    if self.carryin == '': new_test = whileNode.test
439    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
440    else_part = [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]   
441    return ast.If(new_test, whileNode.body + [inner_while], else_part)
442
443class StreamStructGen(ast.NodeVisitor):
444  """
445  Given a BitStreamSet subclass, generate the equivalent C struct.
446  >>> obj = ast.parse(r'''
447  ... class S1(BitStreamSet):
448  ...   a1 = 0
449  ...   a2 = 0
450  ...   a3 = 0
451  ...
452  ... class S2(BitStreamSet):
453  ...   x1 = 0
454  ...   x2 = 0
455  ... ''')
456  >>> print StreamStructGen().gen(obj)
457  struct S1 {
458    BitBlock a1;
459    BitBlock a2;
460    BitBlock a3;
461  }    self.current_adv32 = 0
462
463 
464  struct S2 {
465    BitBlock x1;
466    BitBlock x2;
467  }
468  """
469  def __init__(self, asType=False):
470    self.asType = asType
471  def gen(self, tree):
472    self.Ccode=""
473    self.generic_visit(tree)
474    return self.Ccode
475  def gen_struct_types(self, tree):
476    self.asType = True
477    self.Ccode=""
478    self.generic_visit(tree)
479    return self.Ccode
480  def gen_struct_vars(self, tree):
481    self.asType = False
482    self.Ccode=""
483    self.generic_visit(tree)
484    return self.Ccode
485  def visit_ClassDef(self, node):
486    class_name = node.name[0].upper() + node.name[1:]
487    instance_name = node.name[0].lower() + node.name[1:]
488    self.Ccode += "  struct " + class_name
489    if self.asType:
490            self.Ccode += " {\n"
491            for stmt in node.body:
492              if isinstance(stmt, ast.Assign):
493                for v in stmt.targets:
494                  if isinstance(v, ast.Name):
495                    self.Ccode += "  BitBlock " + v.id + ";\n"
496            self.Ccode += "}" 
497    else: self.Ccode += " " + instance_name
498    self.Ccode += ";\n\n"
499 
500class StreamFunctionDecl(ast.NodeVisitor):
501  def __init__(self):
502    pass
503  def gen(self, tree):
504    self.Ccode=""
505    self.generic_visit(tree)
506    return self.Ccode
507  def visit_FunctionDef(self, node):
508    self.Ccode += "static inline void " + node.name + "("
509    pending = ""
510    for arg in node.args.args:
511      if isinstance(arg, ast.Name):
512        self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
513        pending = ", "
514    self.Ccode += ");\n"
515
516#
517# Adding Debugging Statements
518#
519class Add_SIMD_Register_Dump(ast.NodeTransformer):
520  def xfrm(self, t):
521    return self.generic_visit(t)
522  def visit_Assign(self, t):
523    self.generic_visit(t)
524    v = t.targets[0]
525    dump_stmt = mkCallStmt(' print_register<BitBlock>', [ast.Str(Cgen.py2C().gen(v)), v])
526    return [t, dump_stmt]
527
528#
529# Adding ASSERT_BITBLOCK_ALIGN macros
530#
531class Add_Assert_BitBlock_Align(ast.NodeTransformer):
532    def xfrm(self, t):
533      return self.generic_visit(t)
534    def visit_Assign(self, t):
535      self.generic_visit(t)
536      v = t.targets[0]
537      dump_stmt = mkCallStmt(' ASSERT_BITBLOCK_ALIGN', [v])
538      return [t, dump_stmt]
539
540class StreamFunctionCarryCounter(ast.NodeVisitor):
541  def __init__(self):
542        self.carry_count = {}
543       
544  def count(self, node):
545        self.generic_visit(node)
546        return self.carry_count
547                                   
548  def visit_FunctionDef(self, node):   
549        type_name = node.name[0].upper() + node.name[1:]                       
550        self.carry_count[type_name] = CarryCounter().count(node)
551     
552class StreamFunctionCallXlator(ast.NodeTransformer):
553  def __init__(self, xlate_type="normal"):
554        self.stream_function_type_names = []
555        self.xlate_type = xlate_type
556
557  def xfrm(self, node, stream_function_type_names, C_syntax):
558        self.stream_function_type_names = stream_function_type_names
559        self.C_syntax = C_syntax
560        self.generic_visit(node)
561       
562  def visit_Call(self, node):   
563        self.generic_visit(node)
564
565        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):# and node.func.id in self.stream_function_type_names:
566             name = lower1(node.func.id)
567             node.func.id = name + ("_" if self.C_syntax else ".") + ("do_final_block" if self.xlate_type == "final" else "do_block")
568             if self.C_syntax:
569                     node.args = [ast.Name(lower1(name), ast.Load())] + node.args
570             if self.xlate_type == "final":
571                   node.args = node.args + [ast.Name("EOF_mask", ast.Load())]
572                     
573        return node     
574               
575class StreamFunctionVisitor(ast.NodeVisitor):
576        def __init__(self,node):
577                self.stream_function_node = {}
578                self.generic_visit(node)
579                                                             
580        def visit_FunctionDef(self, node):                     
581                key = node.name[0].upper() + node.name[1:]
582                self.stream_function_node[key] = node
583               
584class StreamFunction():
585        def __init__(self):
586                self.carry_count = 0 
587                self.adv32_count = 0 
588                self.type_name = ""
589                self.instance_name = "" 
590                self.parameters = ""
591                self.declarations = "" 
592                self.initializations = "" 
593#
594# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
595# TODO Implement 'pretty print' indentation.   Low priority.
596# TODO Migrate Emiter() class to Emitter module.  Medium priority.
597
598def lower1(name):
599    return name[0].lower() + name[1:]
600def upper1(name):
601    return name[0].upper() + name[1:]
602
603def escape_newlines(str):
604  return str.replace('\n', '\\\n')
605
606class Emitter():
607        def __init__(self, use_C_syntax):
608                self.use_C_syntax = use_C_syntax
609
610        def definition(self, stream_function, icount=0):
611               
612                constructor = ""
613                carry_declaration = ""
614                self.type_name = stream_function.type_name
615               
616                if stream_function.carry_count > 0 or stream_function.adv32_count > 0:
617                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.adv32_count)
618                        carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv32_count)
619
620                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), 
621                                                stream_function.declarations, 
622                                                stream_function.initializations, 
623                                                stream_function.statements)             
624
625                do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), 
626                                                stream_function.declarations, 
627                                                stream_function.initializations, 
628                                                stream_function.final_block_statements)                 
629
630                do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), 
631                                                self.do_segment_args(stream_function.parameters))       
632
633                if self.use_C_syntax:
634                        return self.indent(icount) + "struct " + stream_function.type_name + " {" \
635                               + "\n" + self.indent(icount) + carry_declaration \
636                               + "\n" + self.indent(icount) + "};\n" \
637                               + "\n" + self.indent(icount) + do_block_function \
638                               + "\n" + self.indent(icount) + do_final_block_function \
639                               + "\n" + self.indent(icount) + do_segment_function + "\n\n"
640                               
641                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
642                + "\n" + self.indent(icount) + constructor \
643                + "\n" + self.indent(icount) + do_block_function \
644                + "\n" + self.indent(icount) + do_final_block_function \
645                + "\n" + self.indent(icount) + do_segment_function \
646                + "\n" + self.indent(icount) + carry_declaration \
647                + "\n" + self.indent(icount) + "};\n\n"
648
649        def constructor(self, type_name, carry_count, adv32_count, icount=0):
650                adv32_decl = ""
651                for i in range(adv32_count): adv32_decl += self.indent(icount+2) + "pending32[%s] = 0;\n" % i   
652                return self.indent(icount) + "%s() { ""\n" % (type_name) + adv32_decl + self.carry_init(carry_count) + " }" 
653                       
654        def do_block(self, parameters, declarations, initializations, statements, icount=0):
655                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
656                if self.use_C_syntax:
657                        return "#define " + pfx + "do_block(" + parameters + ")\\\n do {" \
658                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
659                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
660                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
661                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
662                return self.indent(icount) + do_block_inline_decorator + "void " + pfx + "do_block(" + parameters + ") {" \
663                + "\n" + self.indent(icount) + declarations \
664                + "\n" + self.indent(icount) + initializations \
665                + "\n" + self.indent(icount) + statements \
666                + "\n" + self.indent(icount + 2) + "}" 
667
668
669
670
671        def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
672                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
673                if self.use_C_syntax:
674                        return "#define " + pfx + "do_final_block(" + parameters + ")\\\n do {" \
675                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
676                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
677                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
678                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
679                return self.indent(icount) + do_final_block_inline_decorator + "void " + pfx + "do_final_block(" + parameters + ") {" \
680                + "\n" + self.indent(icount) + declarations \
681                + "\n" + self.indent(icount) + initializations \
682                + "\n" + self.indent(icount) + statements \
683                + "\n" + self.indent(icount + 2) + "}" 
684
685        def do_segment(self, parameters, do_block_call_args, icount=0):
686                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
687                if self.use_C_syntax:
688                        return "#define " + pfx + "do_segment(" + parameters + ")\\\n do {" \
689                        + "\\\n" + self.indent(icount) + "  int i;" \
690                        + "\\\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
691                        + "\\\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
692                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
693                return self.indent(icount) + "void " + pfx + "do_segment(" + parameters + ") {" \
694                + "\n" + self.indent(icount) + "  int i;" \
695                + "\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
696                + "\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
697                + "\n" + self.indent(icount + 2) + "}" 
698
699        def declaration(self, type_name, instance_name, icount=0):
700                if self.use_C_syntax: return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
701                return self.indent(icount) + type_name + " " + instance_name + ";\n"
702               
703        def carry_init(self, carry_count, icount=0):   
704                #CARRY SET
705                return ""
706               
707        def carry_declare(self, carry_variable, carry_count, adv32_count=0, icount=0):
708                adv32_decl = ""
709                if adv32_count > 0:
710                        adv32_decl = "\n" + self.indent(icount) + "uint32_t pending32[%s];" % adv32_count
711                #CARRY SET
712                return self.indent(icount) + "CarryArray<%i> %s;" % (carry_count, carry_variable) + adv32_decl
713
714        def carry_test(self, carry_variable, carry_count, icount=0):
715                #CARRY SET
716                return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)         
717               
718        def indent(self, icount):
719                s = ""
720                for i in range(0,icount): s += " "
721                return s       
722               
723        def do_block_parameters(self, parameters):
724                if self.use_C_syntax:
725                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters])
726                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
727                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])
728               
729        def do_final_block_parameters(self, parameters):
730                if self.use_C_syntax:
731                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
732                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters]+ ["EOF_mask"])
733                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
734               
735        def do_segment_parameters(self, parameters):
736                if self.use_C_syntax:
737                        #return ", ".join([self.type_name + " * " + + self.instance_name] + [upper1(p) + " " + lower1(p) + "[]" for p in parameters])
738                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["int segment_blocks"])
739                else: return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters] + ["int segment_blocks"])
740
741        def do_segment_args(self, parameters):
742                if self.use_C_syntax:
743                        return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
744                else: return ", ".join([lower1(p) + "[i]" for p in parameters])
745
746def main(infilename, outfile = sys.stdout):
747  t = ast.parse(file(infilename).read())
748  outfile.write(StreamStructGen(True).gen(t))
749  outfile.write(FunctionXlator().xlat(t))
750
751#
752#  Routines for compatibility with the old compiler/template.
753#  Quick and dirty hacks for now - Dec. 2010.
754#
755
756class MainLoopTransformer:
757  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, main_node_id='Main'):
758       
759    self.main_module = main_module
760    self.main_node_id = main_node_id
761    self.use_C_syntax = C_syntax
762    self.add_dump_stmts = add_dump_stmts
763    self.add_assert_bitblock_align = add_assert_bitblock_align
764   
765        # Gather and partition function definition nodes.
766    stream_function_visitor = StreamFunctionVisitor(self.main_module)
767    self.stream_function_node = stream_function_visitor.stream_function_node
768    self.main_node = self.stream_function_node[main_node_id]
769    self.main_carry_count = CarryCounter().count(self.main_node)
770    self.main_adv32_count = Adv32Counter().count(self.main_node)
771    assert self.main_adv32_count == 0, "Advance32() in main not supported.\n"
772    del self.stream_function_node[self.main_node_id]
773   
774    self.stream_functions = {}
775    for key, node in self.stream_function_node.iteritems():
776                stream_function = StreamFunction()
777                stream_function.carry_count = CarryCounter().count(node)
778                stream_function.adv32_count = Adv32Counter().count(node)
779                stream_function.type_name = node.name[0].upper() + node.name[1:]
780                stream_function.instance_name = node.name[0].lower() + node.name[1:]
781                stream_function.parameters = FunctionVars(node).params
782                stream_function.declarations = BitBlock_decls_of_fn(node)
783                stream_function.initializations = StreamInitializations().xfrm(node) 
784               
785                StringMatchCompiler().xfrm(node)
786                AugAssignRemoval().xfrm(node)
787                Bitwise_to_SIMD().xfrm(node)
788                final_block_node = copy.deepcopy(node)
789                if self.use_C_syntax:
790                        carryQname = stream_function.instance_name + ".carryQ"
791                else: carryQname = "carryQ"
792                CarryIntro(carryQname).xfrm_fndef(node)
793                CarryIntro(carryQname, "_ci", "").xfrm_fndef(final_block_node)
794
795                if self.add_dump_stmts: 
796                        Add_SIMD_Register_Dump().xfrm(node)
797                        Add_SIMD_Register_Dump().xfrm(final_block_node)
798
799                if self.add_assert_bitblock_align:
800                        Add_Assert_BitBlock_Align().xfrm(node)
801                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
802
803                if stream_function.carry_count > 0:
804                        node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
805               
806                stream_function.statements = Cgen.py2C(4).gen(node.body)
807                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
808                self.stream_functions[stream_function.type_name] = stream_function
809       
810    self.emitter = Emitter(self.use_C_syntax)
811   
812  def any_carry_expr(self):
813       
814        carry_test = []
815       
816        if self.main_carry_count > 0:
817                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
818                        carry_test.append(" || ")
819
820        for key in self.stream_functions.keys():               
821                if self.stream_functions[key].carry_count > 0:
822                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
823                        carry_test.append(" || ")
824
825        if len(carry_test) > 0:
826                carry_test.pop()
827                return "".join(carry_test)
828        return "1"
829
830  def gen_globals(self):
831    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
832    for key in self.stream_functions.keys():
833                self.Cglobals += Emitter(self.use_C_syntax).definition(self.stream_functions[key],2)       
834                       
835  def gen_declarations(self): 
836    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
837    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
838    if self.main_carry_count > 0: 
839        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
840               
841  def gen_initializations(self):
842    self.Cinits = ""
843    if self.main_carry_count > 0: 
844        self.Cinits += self.emitter.carry_init(self.main_carry_count)
845    self.Cinits += StreamInitializations().xfrm(self.main_module)
846    if self.use_C_syntax:
847                for key in self.stream_functions.keys():
848                        if self.stream_functions[key].carry_count == 0: continue
849                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
850                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
851    else:
852                for key in self.stream_functions.keys():
853                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
854
855                       
856  def xfrm_block_stmts(self):
857    StringMatchCompiler().xfrm(self.main_node)
858    AugAssignRemoval().xfrm(self.main_node)
859    Bitwise_to_SIMD().xfrm(self.main_node)
860    final_block_main = copy.deepcopy(self.main_node)
861    CarryIntro().xfrm_fndef(self.main_node)
862    CarryIntro('carryQ', '_ci', '').xfrm_fndef(final_block_main)
863    if self.add_dump_stmts: 
864        Add_SIMD_Register_Dump().xfrm(self.main_node)
865        Add_SIMD_Register_Dump().xfrm(final_block_main)
866               
867    if self.add_assert_bitblock_align:
868        print "add_assert_bitblock_align"
869        Add_Assert_BitBlock_Align().xfrm(self.main_node)
870        Add_Assert_BitBlock_Align().xfrm(final_block_main)
871
872    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
873    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
874   
875    if self.main_carry_count > 0:
876                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
877                self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
878   
879   
880       
881    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
882    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
883   
884if __name__ == "__main__":
885                import doctest
886                doctest.testmod()
Note: See TracBrowser for help on using the repository browser.