source: proto/Compiler/pablo.py @ 2215

Last change on this file since 2215 was 2215, checked in by cameron, 7 years ago

Compiler fix: apply AdvanceCombiner? before CarryCounter?

File size: 36.8 KB
RevLine 
[753]1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
[903]5# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
[753]6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
[1271]11name_substitution_map = {}
[1997]12do_block_inline_decorator = 'IDISA_INLINE '
13do_final_block_inline_decorator = ''
[2099]14error_routine = 'raise_assert'
[1271]15
[2099]16
[1440]17def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
[813]18        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == builtin_fnname
19        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
20                 iscall = fncall.func.value.id == builtin_fnmod_noprefix and fncall.func.attr == builtin_fnname
[1864]21        return iscall and len(fncall.args) == builtin_arg_cnt
[753]22
[1864]23def dump_Call(fncall):
24        if isinstance(fncall.func, ast.Name): print "fn_name = %s\n" % fncall.func.id
25        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
26                print "fn_name = %s.%s\n" % (fncall.func.value.id, fncall.func.attr)
27        print "len(fncall.args) = %s\n" % len(fncall.args)
28
[770]29def is_simd_not(e):
30  return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
31
[753]32def mkQname(obj, field):
33  return ast.Attribute(ast.Name(obj, ast.Load()), field, ast.Load())
34
35def mkCall(fn_name, args):
[1271]36  if isinstance(fn_name, str): 
37        if name_substitution_map.has_key(fn_name): fn_name = name_substitution_map[fn_name]
38        fn_name = ast.Name(fn_name, ast.Load())
[753]39  return ast.Call(fn_name, args, [], None, None)
40
41def mkCallStmt(fn_name, args):
42  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
43  return ast.Expr(ast.Call(fn_name, args, [], None, None))
44
45#
[863]46# Reducing AugAssign, e.g.  x |= y becomes x = x | y
[771]47#
48class AugAssignRemoval(ast.NodeTransformer):
49  def xfrm(self, t):
50    return self.generic_visit(t)
51  def visit_AugAssign(self, e):
52    self.generic_visit(e)
53    return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
[2208]54#
55# pablo.Advance(pablo.Advance(x, n)) => pablo.Advance(x, n+1)
56#
57class AdvanceCombiner(ast.NodeTransformer):
58  def xfrm(self, t):
59    return self.generic_visit(t)
60  def visit_if(self, ifNode):
61    return IfNode
62  def visit_While(self, whileNode):
63    return whileNode
64  def visit_Call(self, callnode):
65    self.generic_visit(callnode)
66    if not isinstance(callnode.args[0], ast.Call): return callnode
67    if is_BuiltIn_Call(callnode,'Advance', 1):
68        if is_BuiltIn_Call(callnode.args[0],'Advance', 1):
69          callnode.args = [callnode.args[0].args[0], ast.Num(2)]
70        elif is_BuiltIn_Call(callnode.args[0], 'Advance', 2):
71          if isinstance(callnode.args[0].args[1], ast.Num):
72            callnode.args = [callnode.args[0].args[0], ast.Num(callnode.args[0].args[1].n + 1)]
73          else:
74            callnode.args = [callnode.args[0].args[0], ast.BinOp(callnode.args[0].args[1], ast.Add(), ast.Num(1))]
75    return callnode
[771]76
[2009]77
[771]78#
[2009]79#  Translating pablo.match(marker, str)
80#  Incremental character match with lookahead
81#
82CharNameMap = {'[' : 'LBrak', ']' : 'RBrak', '{' : 'LBrace', '}' : 'LBrace', '(' : 'LParen', ')' : 'RParen', \
83              '!' : 'Exclam', '"' : 'DQuote', '#' : 'Hash', '$' : 'Dollar', '%' : 'PerCent', '&': 'RefStart', \
84              "'" : 'SQuote', '*': 'Star', '+' : 'Plus', ',' : 'Comma', '-' : 'Hyphen', '.' : 'Dot', '/' : 'Slash', \
85              ':' : 'Colon', ';' : 'Semicolon', '=' : 'Equals', '?' : 'QMark', '@' : 'AtSign', '\\' : 'BackSlash', \
86              '^' : 'Caret', '_' : 'Underscore', '|' : 'VBar', '~' : 'Tilde', ' ' : 'SP', '\t' : 'HT', '\m' : 'CR', '\n' : 'LF'}
87
88def GetCharName(char):
89        if char >= 'a' and char <= 'z' or char >= 'A' and char <= 'Z': return 'letter_' + char
90        elif char >= '0' and char <= '9': return 'digit_' + char
91        else: return CharNameMap[char]
92
93def MkCharStream(char):
94        return mkQname('lex', GetCharName(char))
95
96def MkLookAheadExpr(v, i):
97        return mkCall(mkQname('pablo', 'LookAhead'), [v, ast.Num(i)])
98
99def CompileMatch(match_var, string_to_match):
100        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
101        for i in range(1, len(string_to_match)):
102                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
103        return expr
104
105class StringMatchCompiler(ast.NodeTransformer):
106  def xfrm(self, t):
107    return self.generic_visit(t)
108  def visit_Call(self, callnode):
109    if is_BuiltIn_Call(callnode,'match', 2):
110        ast.dump(callnode)
111        assert isinstance(callnode.args[0], ast.Str)
112        string_to_match = callnode.args[0].s
113        match_var = callnode.args[1]
114        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
115        for i in range(1, len(string_to_match)):
116                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
117        return expr
118    else: return callnode
119
120
121
122#
[753]123# Introducing BitBlock logical operations
124#
125class Bitwise_to_SIMD(ast.NodeTransformer):
126  """
127  Make the following substitutions:
128     x & y => simd_and(x, y)
[770]129     x & ~y => simd_andc(x, y)
[753]130     x | y => simd_or(x, y)
131     x ^ y => simd_xor(x, y)
132     ~x    => simd_not(x)
133     0     => simd_const_1(0)
134     -1    => simd_const_1(1)
[1916]135     if x: => if bitblock::any(x):
136  while x: => while bitblock::any(x):
[753]137  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1")))
138 
139  pfx = simd_and(bit0, bit1)
140  sfx = simd_and(bit0, simd_not(bit1))
141  >>>
142  """
143  def xfrm(self, t):
144    return self.generic_visit(t)
145  def visit_UnaryOp(self, t):
146    self.generic_visit(t)
147    if isinstance(t.op, ast.Invert):
148      return mkCall('simd_not', [t.operand])
149    else: return t
150  def visit_BinOp(self, t):
151    self.generic_visit(t)
152    if isinstance(t.op, ast.BitOr):
153      return mkCall('simd_or', [t.left, t.right])
154    elif isinstance(t.op, ast.BitAnd):
[770]155      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
156      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
157      else: return mkCall('simd_and', [t.left, t.right])
[753]158    elif isinstance(t.op, ast.BitXor):
159      return mkCall('simd_xor', [t.left, t.right])
160    else: return t
161  def visit_Num(self, numnode):
162    n = numnode.n
[1916]163    if n == 0: return mkCall('simd<1>::constant<0>', [])
164    elif n == -1: return mkCall('simd<1>::constant<1>', [])
[753]165    else: return numnode
166  def visit_If(self, ifNode):
167    self.generic_visit(ifNode)
[1916]168    ifNode.test = mkCall('bitblock::any', [ifNode.test])
[753]169    return ifNode
170  def visit_While(self, whileNode):
171    self.generic_visit(whileNode)
[1916]172    whileNode.test = mkCall('bitblock::any', [whileNode.test])
[753]173    return whileNode
174  def visit_Subscript(self, numnode):
175    return numnode  # no recursive modifications of index expressions
176
177#
178#  Generating BitBlock declarations for Local Variables
179#
[880]180class FunctionVars(ast.NodeVisitor):
181  def __init__(self,node):
182        self.params = []
183        self.stores = []
184        self.generic_visit(node)
[753]185  def visit_Name(self, nm):
186    if isinstance(nm.ctx, ast.Param):
187      self.params.append(nm.id)
188    if isinstance(nm.ctx, ast.Store):
189      if nm.id not in self.stores: self.stores.append(nm.id)
[880]190  def getLocals(self): 
[753]191    return [v for v in self.stores if not v in self.params]
192
193MAX_LINE_LENGTH = 80
194
195def BitBlock_decls_from_vars(varlist):
196  global MAX_LINE_LENGTH
[880]197  decls =  ""
198  if not len(varlist) == 0:
199          decls = "             BitBlock"
200          pending = ""
201          linelgth = 10
202          for v in varlist:
203            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
204              decls += pending + " " + v
205              linelgth += len(pending + v) + 1
206            else:
207              decls += ";\n             BitBlock " + v
208              linelgth = 11 + len(v)
209            pending = ","
210          decls += ";"
[753]211  return decls
212
213def BitBlock_decls_of_fn(fndef):
[880]214  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
[753]215
216def BitBlock_header_of_fn(fndef):
217  Ccode = "static inline void " + fndef.name + "("
218  pending = ""
219  for arg in fndef.args.args:
220    if isinstance(arg, ast.Name):
221      Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
222      pending = ", "
[863]223  if CarryCounter().count(fndef) > 0:
224    Ccode += pending + " CarryQtype & carryQ"
[753]225  Ccode += ")"
226  return Ccode
227
228
[760]229
230#
231#  Stream Initialization Statement Extraction
232#
233#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
234class StreamInitializations(ast.NodeTransformer):
235  def xfrm(self, node):
236    self.stream_stmts = []
237    self.loop_post_inits = []
238    self.generic_visit(node)
239    return Cgen.py2C().gen(self.stream_stmts)
240  def visit_Assign(self, node):
241    if isinstance(node.value, ast.Num):
242      if node.value.n == 0: return node
243      elif node.value.n == -1: return node
244      else: 
245        stream_init = copy.deepcopy(node)
246        stream_init.value = mkCall('sisd_from_int', [node.value])
247        loop_init = copy.deepcopy(node)
248        loop_init.value.n = 0
249        self.stream_stmts.append(stream_init)
250        self.loop_post_inits.append(loop_init)
251        return None
252    else: return node
253  def visit_FunctionDef(self, node):
254    self.generic_visit(node)
255    node.body = node.body + self.loop_post_inits
256    return node
257
[2041]258
259
[753]260#
261# Carry Introduction Transformation
262#
263class CarryCounter(ast.NodeVisitor):
264  def visit_Call(self, callnode):
265    self.generic_visit(callnode)
[2041]266    if is_BuiltIn_Call(callnode,'Advance', 1) or is_BuiltIn_Call(callnode,'ScanThru', 2) or is_BuiltIn_Call(callnode,'ScanTo', 2) or is_BuiltIn_Call(callnode,'ScanToFirst', 1) or is_BuiltIn_Call(callnode,'AdvanceThenScanThru', 2) or is_BuiltIn_Call(callnode,'AdvanceThenScanTo', 2) or is_BuiltIn_Call(callnode,'SpanUpTo', 2) or  is_BuiltIn_Call(callnode,'InclusiveSpan', 2) or is_BuiltIn_Call(callnode,'ExclusiveSpan', 2):         
[753]267      self.carry_count += 1
[769]268  def visit_BinOp(self, exprnode):
[753]269    self.generic_visit(exprnode)
270    if isinstance(exprnode.op, ast.Sub):
271      self.carry_count += 1
[769]272    if isinstance(exprnode.op, ast.Add):
273      self.carry_count += 1
[753]274  def count(self, nodeToVisit):
275    self.carry_count = 0
276    self.generic_visit(nodeToVisit)
277    return self.carry_count
278
[2206]279class adv_nCounter(ast.NodeVisitor):
[1211]280  def visit_Call(self, callnode):
281    self.generic_visit(callnode)
282    if is_BuiltIn_Call(callnode,'Advance32', 1):       
[2206]283      self.adv_n_count += 1
284    if is_BuiltIn_Call(callnode,'Advance', 2):         
285      self.adv_n_count += 1
[1211]286  def count(self, nodeToVisit):
[2206]287    self.adv_n_count = 0
[1211]288    self.generic_visit(nodeToVisit)
[2206]289    return self.adv_n_count
[1211]290
[753]291class CarryIntro(ast.NodeTransformer):
[921]292  def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
[753]293    self.carryvar = ast.Name(carryvar, ast.Load())
[921]294    self.carryin = carryin
295    self.carryout = carryout
[753]296  def xfrm_fndef(self, fndef):
297    self.current_carry = 0
[2206]298    self.current_adv_n = 0
[753]299    carry_count = CarryCounter().count(fndef)
300    if carry_count == 0: return fndef
301    self.generic_visit(fndef)
[762]302#   
303#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
[753]304    return fndef
305  def generic_xfrm(self, node):
306    self.current_carry = 0
[2206]307    self.current_adv_n = 0
[753]308    carry_count = CarryCounter().count(node)
[2206]309    adv_n_count = adv_nCounter().count(node)
310    if carry_count == 0 and adv_n_count == 0: return node
[753]311    self.generic_visit(node)
312    return node
313  def visit_Call(self, callnode):
314    self.generic_visit(callnode)
[1571]315    #CARRYSET
[1997]316    #carry_args = [ast.Num(self.current_carry)]
[2206]317    #adv_n_args = [ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())]
318    adv_n_pending = ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())
[1997]319    if self.carryin == "_ci":
320        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
[2206]321        adv_n_args = [adv_n_pending, ast.Num(self.current_adv_n)]
[1997]322    else: 
323        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
[2206]324        adv_n_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_adv_n)]
[1997]325
[2206]326    if is_BuiltIn_Call(callnode, 'Advance', 2):         
327      #CARRYSET
328      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<%i>" % callnode.args[1].n
329      c = mkCall(rtn, [callnode.args[0]] + adv_n_args)
330      self.current_adv_n += 1
331      return c
[813]332    if is_BuiltIn_Call(callnode, 'Advance', 1):         
[1571]333      #CARRYSET
[1997]334      rtn = self.carryvar.id + "." + "BitBlock_advance_ci_co"
[924]335      c = mkCall(rtn, callnode.args + carry_args)
[753]336      self.current_carry += 1
337      return c
[1211]338    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
[1571]339      #CARRYSET
[2206]340      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<32>"
341      c = mkCall(rtn, callnode.args + adv_n_args)
342      self.current_adv_n += 1
[1211]343      return c
[813]344    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
[1571]345      #CARRYSET
[1997]346      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
[924]347      c = mkCall(rtn, callnode.args + carry_args)
[753]348      self.current_carry += 1
349      return c
[2041]350    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanThru', 2):
351      #CARRYSET
352      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
353      c = mkCall(rtn, callnode.args + carry_args)
354      self.current_carry += 1
355      return c
356    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanTo', 2):
357      #CARRYSET
358      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru" 
359      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
360      else: scanclass = mkCall('simd_not', [callnode.args[1]])
[2049]361      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
[2041]362      self.current_carry += 1
363      return c
364    elif is_BuiltIn_Call(callnode, 'SpanUpTo', 2):
365      #CARRYSET
366      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
367      c = mkCall(rtn, callnode.args + carry_args)
368      self.current_carry += 1
369      return c
370    elif is_BuiltIn_Call(callnode, 'InclusiveSpan', 2):
371      #CARRYSET
[2049]372#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
373#      c = mkCall('simd_or', [mkCall(rtn, callnode.args + carry_args), callnode.args[1]])
374      rtn = self.carryvar.id + "." + "BitBlock_inclusive_span"
375      c = mkCall(rtn, callnode.args + carry_args)
[2041]376      self.current_carry += 1
377      return c
378    elif is_BuiltIn_Call(callnode, 'ExclusiveSpan', 2):
379      #CARRYSET
[2049]380#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
381#      c = mkCall('simd_andc', [mkCall(rtn, callnode.args + carry_args), callnode.args[0]])
382      rtn = self.carryvar.id + "." + "BitBlock_exclusive_span"
383      c = mkCall(rtn, callnode.args + carry_args)
[2041]384      self.current_carry += 1
385      return c
[902]386    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
[1520]387      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
388      # in having a separate BitBlock_scanto routine.
[1571]389      #CARRYSET
[1997]390      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co" 
[1520]391      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
392      else: scanclass = mkCall('simd_not', [callnode.args[1]])
393      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
[902]394      self.current_carry += 1
395      return c
[1074]396    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
[1571]397      #CARRYSET
398      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
[1074]399      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
400      c = mkCall(rtn, callnode.args + carry_args)
401      self.current_carry += 1
402      return c
[1439]403    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
404      if self.carryout != "": 
405        # Non final block: atEOF(x) = 0.
[1916]406        return mkCall('simd<1>::constant<0>', [])
[1439]407      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
408    elif is_BuiltIn_Call(callnode, 'inFile', 1):
409      if self.carryout != "": 
410        # Non final block: inFile(x) = x.
411        return callnode.args[0]
412      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
[822]413    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
414      rtn = "StreamScan"           
415      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
416                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
417                                           ast.Name(callnode.args[1].id, ast.Load())])
418      return c
[1864]419    else:
[1865]420      #dump_Call(callnode)
[1864]421      return callnode
[765]422  def visit_BinOp(self, exprnode):
[753]423    self.generic_visit(exprnode)
[1571]424    carry_args = [ast.Num(self.current_carry)]
[2003]425    if self.carryin == "_ci":
426        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
427    else: 
428        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
[753]429    if isinstance(exprnode.op, ast.Sub):
[1571]430      #CARRYSET
[2003]431      rtn = self.carryvar.id + "." + "BitBlock_sub_ci_co"
[924]432      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
[753]433      self.current_carry += 1
434      return c
435    elif isinstance(exprnode.op, ast.Add):
[1571]436      #CARRYSET
[2004]437      rtn = self.carryvar.id + "." + "BitBlock_add_ci_co"
[924]438      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
[753]439      self.current_carry += 1
440      return c
441    else: return exprnode
442  def visit_If(self, ifNode):
443    carry_base = self.current_carry
444    carries = CarryCounter().count(ifNode)
[2206]445    assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
[753]446    self.generic_visit(ifNode)
[921]447    if carries == 0 or self.carryin == "": return ifNode
[1571]448    #CARRYSET
449    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
450    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
451    new_else_part = ifNode.orelse + [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]
[753]452    return ast.If(new_test, ifNode.body, new_else_part)
453  def visit_While(self, whileNode):
[941]454    if self.carryout == '':
455      whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
[753]456    carry_base = self.current_carry
[2206]457    assert adv_nCounter().count(whileNode) == 0, "Advance(x,n) within while: illegal\n"
[753]458    carries = CarryCounter().count(whileNode)
[1571]459    #CARRYSET
[753]460    if carries == 0: return whileNode
[1571]461    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
[1196]462    local_carryvar = 'subcarryQ'
[921]463    inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
[753]464    self.generic_visit(whileNode)
[1571]465    local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
466    inner_while.body.insert(0, local_carry_decl)
[2213]467    final_combine = mkCallStmt('carryQ.CarryCombine', [ast.Attribute(ast.Name(local_carryvar, ast.Load()), 'cq', ast.Load()),ast.Num(carry_base), ast.Num(carries)])
[1526]468    inner_while.body.append(final_combine)
[1571]469    #CARRYSET
[921]470    if self.carryin == '': new_test = whileNode.test
[1571]471    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
472    else_part = [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]   
[753]473    return ast.If(new_test, whileNode.body + [inner_while], else_part)
474
475class StreamStructGen(ast.NodeVisitor):
476  """
477  Given a BitStreamSet subclass, generate the equivalent C struct.
478  >>> obj = ast.parse(r'''
479  ... class S1(BitStreamSet):
480  ...   a1 = 0
481  ...   a2 = 0
482  ...   a3 = 0
483  ...
484  ... class S2(BitStreamSet):
485  ...   x1 = 0
486  ...   x2 = 0
487  ... ''')
488  >>> print StreamStructGen().gen(obj)
489  struct S1 {
490    BitBlock a1;
491    BitBlock a2;
492    BitBlock a3;
[2206]493  }    self.current_adv_n = 0
[1211]494
[753]495 
496  struct S2 {
497    BitBlock x1;
498    BitBlock x2;
499  }
500  """
[857]501  def __init__(self, asType=False):
502    self.asType = asType
[753]503  def gen(self, tree):
504    self.Ccode=""
505    self.generic_visit(tree)
506    return self.Ccode
[865]507  def gen_struct_types(self, tree):
508    self.asType = True
509    self.Ccode=""
510    self.generic_visit(tree)
511    return self.Ccode
512  def gen_struct_vars(self, tree):
513    self.asType = False
514    self.Ccode=""
515    self.generic_visit(tree)
516    return self.Ccode
[753]517  def visit_ClassDef(self, node):
[857]518    class_name = node.name[0].upper() + node.name[1:]
519    instance_name = node.name[0].lower() + node.name[1:]
[880]520    self.Ccode += "  struct " + class_name
[865]521    if self.asType:
522            self.Ccode += " {\n"
523            for stmt in node.body:
524              if isinstance(stmt, ast.Assign):
525                for v in stmt.targets:
526                  if isinstance(v, ast.Name):
527                    self.Ccode += "  BitBlock " + v.id + ";\n"
528            self.Ccode += "}" 
529    else: self.Ccode += " " + instance_name
[857]530    self.Ccode += ";\n\n"
[753]531 
532class StreamFunctionDecl(ast.NodeVisitor):
533  def __init__(self):
534    pass
535  def gen(self, tree):
536    self.Ccode=""
537    self.generic_visit(tree)
538    return self.Ccode
539  def visit_FunctionDef(self, node):
540    self.Ccode += "static inline void " + node.name + "("
541    pending = ""
542    for arg in node.args.args:
543      if isinstance(arg, ast.Name):
544        self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
545        pending = ", "
546    self.Ccode += ");\n"
547
[2099]548class AssertCompiler(ast.NodeTransformer):
549  def __init__(self):
550    self.assert_routine = ast.parse(error_routine).body[0].value
551  def xfrm(self, t):
552    return self.generic_visit(t)
553  def visit_Expr(self, node):
554    if isinstance(node.value, ast.Call):
555        if is_BuiltIn_Call(node.value, "assert_0", 2):
556                err_stream = node.value.args[0]
557                err_code = node.value.args[1]
558                return ast.If(err_stream,
[2100]559                              [ast.Expr(mkCall(self.assert_routine, [err_code, err_stream]))],
[2099]560                              [])
561        else: return node
562    else: return node
563       
[810]564#
565# Adding Debugging Statements
566#
567class Add_SIMD_Register_Dump(ast.NodeTransformer):
568  def xfrm(self, t):
569    return self.generic_visit(t)
570  def visit_Assign(self, t):
571    self.generic_visit(t)
572    v = t.targets[0]
[1843]573    dump_stmt = mkCallStmt(' print_register<BitBlock>', [ast.Str(Cgen.py2C().gen(v)), v])
[810]574    return [t, dump_stmt]
[880]575
[1901]576#
577# Adding ASSERT_BITBLOCK_ALIGN macros
578#
579class Add_Assert_BitBlock_Align(ast.NodeTransformer):
580    def xfrm(self, t):
581      return self.generic_visit(t)
582    def visit_Assign(self, t):
583      self.generic_visit(t)
584      v = t.targets[0]
585      dump_stmt = mkCallStmt(' ASSERT_BITBLOCK_ALIGN', [v])
586      return [t, dump_stmt]
587
[880]588class StreamFunctionCarryCounter(ast.NodeVisitor):
589  def __init__(self):
590        self.carry_count = {}
591       
592  def count(self, node):
593        self.generic_visit(node)
594        return self.carry_count
595                                   
596  def visit_FunctionDef(self, node):   
597        type_name = node.name[0].upper() + node.name[1:]                       
598        self.carry_count[type_name] = CarryCounter().count(node)
599     
600class StreamFunctionCallXlator(ast.NodeTransformer):
[937]601  def __init__(self, xlate_type="normal"):
[880]602        self.stream_function_type_names = []
[937]603        self.xlate_type = xlate_type
604
[1195]605  def xfrm(self, node, stream_function_type_names, C_syntax):
[880]606        self.stream_function_type_names = stream_function_type_names
[1195]607        self.C_syntax = C_syntax
[880]608        self.generic_visit(node)
609       
610  def visit_Call(self, node):   
611        self.generic_visit(node)
612
[1957]613        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):# and node.func.id in self.stream_function_type_names:
[1195]614             name = lower1(node.func.id)
615             node.func.id = name + ("_" if self.C_syntax else ".") + ("do_final_block" if self.xlate_type == "final" else "do_block")
616             if self.C_syntax:
[1196]617                     node.args = [ast.Name(lower1(name), ast.Load())] + node.args
[937]618             if self.xlate_type == "final":
619                   node.args = node.args + [ast.Name("EOF_mask", ast.Load())]
[1195]620                     
[880]621        return node     
622               
623class StreamFunctionVisitor(ast.NodeVisitor):
624        def __init__(self,node):
625                self.stream_function_node = {}
626                self.generic_visit(node)
627                                                             
628        def visit_FunctionDef(self, node):                     
629                key = node.name[0].upper() + node.name[1:]
630                self.stream_function_node[key] = node
631               
632class StreamFunction():
633        def __init__(self):
634                self.carry_count = 0 
[2206]635                self.adv_n_count = 0 
[880]636                self.type_name = ""
637                self.instance_name = "" 
638                self.parameters = ""
639                self.declarations = "" 
640                self.initializations = "" 
[813]641#
[903]642# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
643# TODO Implement 'pretty print' indentation.   Low priority.
644# TODO Migrate Emiter() class to Emitter module.  Medium priority.
[908]645
646def lower1(name):
647    return name[0].lower() + name[1:]
648def upper1(name):
649    return name[0].upper() + name[1:]
650
[1196]651def escape_newlines(str):
652  return str.replace('\n', '\\\n')
653
[880]654class Emitter():
[1195]655        def __init__(self, use_C_syntax):
656                self.use_C_syntax = use_C_syntax
[753]657
[880]658        def definition(self, stream_function, icount=0):
659               
660                constructor = ""
661                carry_declaration = ""
[1195]662                self.type_name = stream_function.type_name
[880]663               
[2206]664                if stream_function.carry_count > 0 or stream_function.adv_n_count > 0:
665                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.adv_n_count)
666                        carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv_n_count)
[753]667
[880]668                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), 
669                                                stream_function.declarations, 
670                                                stream_function.initializations, 
671                                                stream_function.statements)             
[906]672
[922]673                do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), 
674                                                stream_function.declarations, 
675                                                stream_function.initializations, 
676                                                stream_function.final_block_statements)                 
677
[906]678                do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), 
679                                                self.do_segment_args(stream_function.parameters))       
680
[1195]681                if self.use_C_syntax:
[1196]682                        return self.indent(icount) + "struct " + stream_function.type_name + " {" \
[1195]683                               + "\n" + self.indent(icount) + carry_declaration \
684                               + "\n" + self.indent(icount) + "};\n" \
685                               + "\n" + self.indent(icount) + do_block_function \
686                               + "\n" + self.indent(icount) + do_final_block_function \
687                               + "\n" + self.indent(icount) + do_segment_function + "\n\n"
688                               
[880]689                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
690                + "\n" + self.indent(icount) + constructor \
691                + "\n" + self.indent(icount) + do_block_function \
[922]692                + "\n" + self.indent(icount) + do_final_block_function \
[906]693                + "\n" + self.indent(icount) + do_segment_function \
[880]694                + "\n" + self.indent(icount) + carry_declaration \
695                + "\n" + self.indent(icount) + "};\n\n"
696
[2206]697        def constructor(self, type_name, carry_count, adv_n_count, icount=0):
698                adv_n_decl = ""
699                #for i in range(adv_n_count): adv_n_decl += self.indent(icount+2) + "pending64[%s] = simd<1>::constant<0>();\n" % i     
700                return self.indent(icount) + "%s() { ""\n" % (type_name) + adv_n_decl + self.carry_init(carry_count) + " }" 
[880]701                       
702        def do_block(self, parameters, declarations, initializations, statements, icount=0):
[1195]703                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
[1196]704                if self.use_C_syntax:
705                        return "#define " + pfx + "do_block(" + parameters + ")\\\n do {" \
706                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
707                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
708                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
709                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
[1997]710                return self.indent(icount) + do_block_inline_decorator + "void " + pfx + "do_block(" + parameters + ") {" \
[1195]711                + "\n" + self.indent(icount) + declarations \
[880]712                + "\n" + self.indent(icount) + initializations \
713                + "\n" + self.indent(icount) + statements \
714                + "\n" + self.indent(icount + 2) + "}" 
715
[1196]716
717
718
[922]719        def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
[1195]720                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
[1196]721                if self.use_C_syntax:
722                        return "#define " + pfx + "do_final_block(" + parameters + ")\\\n do {" \
723                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
724                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
725                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
726                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
[1997]727                return self.indent(icount) + do_final_block_inline_decorator + "void " + pfx + "do_final_block(" + parameters + ") {" \
[1195]728                + "\n" + self.indent(icount) + declarations \
[922]729                + "\n" + self.indent(icount) + initializations \
730                + "\n" + self.indent(icount) + statements \
731                + "\n" + self.indent(icount + 2) + "}" 
732
[906]733        def do_segment(self, parameters, do_block_call_args, icount=0):
[1195]734                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
[1196]735                if self.use_C_syntax:
736                        return "#define " + pfx + "do_segment(" + parameters + ")\\\n do {" \
737                        + "\\\n" + self.indent(icount) + "  int i;" \
[1882]738                        + "\\\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
[1196]739                        + "\\\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
740                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
[1195]741                return self.indent(icount) + "void " + pfx + "do_segment(" + parameters + ") {" \
[1196]742                + "\n" + self.indent(icount) + "  int i;" \
[1882]743                + "\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
[1195]744                + "\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
[906]745                + "\n" + self.indent(icount + 2) + "}" 
746
[1198]747        def declaration(self, type_name, instance_name, icount=0):
748                if self.use_C_syntax: return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
[880]749                return self.indent(icount) + type_name + " " + instance_name + ";\n"
750               
751        def carry_init(self, carry_count, icount=0):   
[1571]752                #CARRY SET
753                return ""
754               
[2206]755        def carry_declare(self, carry_variable, carry_count, adv_n_count=0, icount=0):
756                adv_n_decl = ""
757                #if adv_n_count > 0:
758                #       adv_n_decl = "\n" + self.indent(icount) + "BitBlock pending64[%s];" % adv_n_count
[1571]759                #CARRY SET
[2206]760                return self.indent(icount) + "CarryArray<%i, %i> %s;" % (carry_count, adv_n_count, carry_variable)
[880]761
762        def carry_test(self, carry_variable, carry_count, icount=0):
[1571]763                #CARRY SET
764                return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)         
[880]765               
766        def indent(self, icount):
767                s = ""
768                for i in range(0,icount): s += " "
769                return s       
770               
771        def do_block_parameters(self, parameters):
[1195]772                if self.use_C_syntax:
[1196]773                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters])
774                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
[1195]775                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])
[880]776               
[922]777        def do_final_block_parameters(self, parameters):
[1195]778                if self.use_C_syntax:
[1196]779                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
780                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters]+ ["EOF_mask"])
[1195]781                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
[922]782               
[906]783        def do_segment_parameters(self, parameters):
[1195]784                if self.use_C_syntax:
[1196]785                        #return ", ".join([self.type_name + " * " + + self.instance_name] + [upper1(p) + " " + lower1(p) + "[]" for p in parameters])
[1882]786                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["int segment_blocks"])
787                else: return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters] + ["int segment_blocks"])
[906]788
789        def do_segment_args(self, parameters):
[1195]790                if self.use_C_syntax:
791                        return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
792                else: return ", ".join([lower1(p) + "[i]" for p in parameters])
[908]793
[753]794def main(infilename, outfile = sys.stdout):
795  t = ast.parse(file(infilename).read())
[865]796  outfile.write(StreamStructGen(True).gen(t))
[753]797  outfile.write(FunctionXlator().xlat(t))
798
799#
800#  Routines for compatibility with the old compiler/template.
801#  Quick and dirty hacks for now - Dec. 2010.
802#
803
[777]804class MainLoopTransformer:
[1901]805  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, main_node_id='Main'):
[889]806       
[777]807    self.main_module = main_module
[880]808    self.main_node_id = main_node_id
[1195]809    self.use_C_syntax = C_syntax
[889]810    self.add_dump_stmts = add_dump_stmts
[1901]811    self.add_assert_bitblock_align = add_assert_bitblock_align
[880]812   
813        # Gather and partition function definition nodes.
814    stream_function_visitor = StreamFunctionVisitor(self.main_module)
815    self.stream_function_node = stream_function_visitor.stream_function_node
[2215]816    for key, node in self.stream_function_node.iteritems():
817                AdvanceCombiner().xfrm(node)
[880]818    self.main_node = self.stream_function_node[main_node_id]
819    self.main_carry_count = CarryCounter().count(self.main_node)
[2206]820    self.main_adv_n_count = adv_nCounter().count(self.main_node)
821    assert self.main_adv_n_count == 0, "Advance32() in main not supported.\n"
[880]822    del self.stream_function_node[self.main_node_id]
823   
824    self.stream_functions = {}
825    for key, node in self.stream_function_node.iteritems():
826                stream_function = StreamFunction()
827                stream_function.carry_count = CarryCounter().count(node)
[2206]828                stream_function.adv_n_count = adv_nCounter().count(node)
[880]829                stream_function.type_name = node.name[0].upper() + node.name[1:]
830                stream_function.instance_name = node.name[0].lower() + node.name[1:]
831                stream_function.parameters = FunctionVars(node).params
832                stream_function.declarations = BitBlock_decls_of_fn(node)
833                stream_function.initializations = StreamInitializations().xfrm(node) 
834               
[2009]835                StringMatchCompiler().xfrm(node)
[2099]836                AssertCompiler().xfrm(node)
[880]837                AugAssignRemoval().xfrm(node)
838                Bitwise_to_SIMD().xfrm(node)
[922]839                final_block_node = copy.deepcopy(node)
[1195]840                if self.use_C_syntax:
[1196]841                        carryQname = stream_function.instance_name + ".carryQ"
[1195]842                else: carryQname = "carryQ"
843                CarryIntro(carryQname).xfrm_fndef(node)
844                CarryIntro(carryQname, "_ci", "").xfrm_fndef(final_block_node)
[889]845
846                if self.add_dump_stmts: 
847                        Add_SIMD_Register_Dump().xfrm(node)
[1849]848                        Add_SIMD_Register_Dump().xfrm(final_block_node)
[1901]849
850                if self.add_assert_bitblock_align:
851                        Add_Assert_BitBlock_Align().xfrm(node)
852                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
853
[889]854                if stream_function.carry_count > 0:
[1571]855                        node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
[889]856               
[880]857                stream_function.statements = Cgen.py2C(4).gen(node.body)
[922]858                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
[880]859                self.stream_functions[stream_function.type_name] = stream_function
860       
[1195]861    self.emitter = Emitter(self.use_C_syntax)
[880]862   
[777]863  def any_carry_expr(self):
[880]864       
865        carry_test = []
866       
867        if self.main_carry_count > 0:
[1571]868                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
[880]869                        carry_test.append(" || ")
870
871        for key in self.stream_functions.keys():               
872                if self.stream_functions[key].carry_count > 0:
[1571]873                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
[880]874                        carry_test.append(" || ")
875
876        if len(carry_test) > 0:
877                carry_test.pop()
[1195]878                return "".join(carry_test)
[880]879        return "1"
880
881  def gen_globals(self):
[865]882    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
[880]883    for key in self.stream_functions.keys():
[1195]884                self.Cglobals += Emitter(self.use_C_syntax).definition(self.stream_functions[key],2)       
[880]885                       
886  def gen_declarations(self): 
[865]887    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
[880]888    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
889    if self.main_carry_count > 0: 
[1571]890        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
[880]891               
[777]892  def gen_initializations(self):
893    self.Cinits = ""
[880]894    if self.main_carry_count > 0: 
895        self.Cinits += self.emitter.carry_init(self.main_carry_count)
[777]896    self.Cinits += StreamInitializations().xfrm(self.main_module)
[1195]897    if self.use_C_syntax:
898                for key in self.stream_functions.keys():
[1196]899                        if self.stream_functions[key].carry_count == 0: continue
900                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
901                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
902    else:
903                for key in self.stream_functions.keys():
904                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
[1195]905
[880]906                       
[889]907  def xfrm_block_stmts(self):
[2009]908    StringMatchCompiler().xfrm(self.main_node)
[880]909    AugAssignRemoval().xfrm(self.main_node)
910    Bitwise_to_SIMD().xfrm(self.main_node)
[2099]911    Bitwise_to_SIMD().xfrm(self.main_node)
912    AssertCompiler().xfrm(self.main_node)
[937]913    final_block_main = copy.deepcopy(self.main_node)
[880]914    CarryIntro().xfrm_fndef(self.main_node)
[937]915    CarryIntro('carryQ', '_ci', '').xfrm_fndef(final_block_main)
[889]916    if self.add_dump_stmts: 
[880]917        Add_SIMD_Register_Dump().xfrm(self.main_node)
[939]918        Add_SIMD_Register_Dump().xfrm(final_block_main)
[880]919               
[1901]920    if self.add_assert_bitblock_align:
[1917]921        print "add_assert_bitblock_align"
[1901]922        Add_Assert_BitBlock_Align().xfrm(self.main_node)
923        Add_Assert_BitBlock_Align().xfrm(final_block_main)
924
[1195]925    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
926    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
[889]927   
928    if self.main_carry_count > 0:
[1571]929                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
930                self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
[889]931   
932   
933       
[880]934    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
[937]935    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
[880]936   
[753]937if __name__ == "__main__":
938                import doctest
[902]939                doctest.testmod()
Note: See TracBrowser for help on using the repository browser.