source: proto/Compiler/pablo.py @ 2606

Last change on this file since 2606 was 2606, checked in by cameron, 7 years ago

Make temp assignments for carry-generating built-ins

File size: 44.0 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11name_substitution_map = {}
12do_block_inline_decorator = 'IDISA_INLINE '
13do_final_block_inline_decorator = ''
14error_routine = 'raise_assert'
15
16
17def isCarryGenerating(builtin_fn):
18   return builtin_fn in ['Advance', 'ScanThru', 'ScanTo', 'AdvanceThenScanThru', 'AdvanceThenScanTo', 'SpanUpTo', 'InclusiveSpan', 'ExclusiveSpan', 'ScanToFirst']
19def usesCarryInit1(builtin_fn):
20   return builtin_fn in ['ScanToFirst']
21def usesCarryCountArgument(builtin_fn):
22   return builtin_fn in ['Advance']
23
24def CarryCountOfFn(fncall, builtin_fnmod_noprefix='pablo'):
25  if not isinstance(fncall, ast.Call): return 0
26  if isinstance(fncall.func, ast.Name): fn_name = fncall.func.id
27  elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
28    if fncall.func.value.id == builtin_fnmod_noprefix: fn_name = fncall.func.attr
29    else: return 0
30  else: return 0
31  if usesCarryCountArgument(fn_name):
32    if len(fncall.args) == 1: return 1
33    else: return 0 #  return fncall.args[1].n  # Possibly count Advance(m, n) as generating n carries.
34  elif isCarryGenerating(fn_name): return 1
35  else: return 0
36   
37def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
38        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == builtin_fnname
39        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
40                 iscall = fncall.func.value.id == builtin_fnmod_noprefix and fncall.func.attr == builtin_fnname
41        return iscall and len(fncall.args) == builtin_arg_cnt
42
43def dump_Call(fncall):
44        if isinstance(fncall.func, ast.Name): print "fn_name = %s\n" % fncall.func.id
45        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
46                print "fn_name = %s.%s\n" % (fncall.func.value.id, fncall.func.attr)
47        print "len(fncall.args) = %s\n" % len(fncall.args)
48
49def is_simd_not(e):
50  return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
51
52def mkQname(obj, field):
53  return ast.Attribute(ast.Name(obj, ast.Load()), field, ast.Load())
54
55def mkCall(fn_name, args):
56  if isinstance(fn_name, str): 
57        if name_substitution_map.has_key(fn_name): fn_name = name_substitution_map[fn_name]
58        fn_name = ast.Name(fn_name, ast.Load())
59  return ast.Call(fn_name, args, [], None, None)
60
61def mkCallStmt(fn_name, args):
62  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
63  return ast.Expr(ast.Call(fn_name, args, [], None, None))
64
65#
66# Reducing AugAssign, e.g.  x |= y becomes x = x | y
67#
68class AugAssignRemoval(ast.NodeTransformer):
69  def xfrm(self, t):
70    return self.generic_visit(t)
71  def visit_AugAssign(self, e):
72    self.generic_visit(e)
73    return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
74#
75# pablo.Advance(pablo.Advance(x, n)) => pablo.Advance(x, n+1)
76#
77class AdvanceCombiner(ast.NodeTransformer):
78  def xfrm(self, t):
79    return self.generic_visit(t)
80  def visit_if(self, ifNode):
81    return IfNode
82  def visit_While(self, whileNode):
83    return whileNode
84  def visit_Call(self, callnode):
85    self.generic_visit(callnode)
86    if len(callnode.args) == 0: return callnode
87    if not isinstance(callnode.args[0], ast.Call): return callnode
88    if is_BuiltIn_Call(callnode,'Advance', 1):
89        if is_BuiltIn_Call(callnode.args[0],'Advance', 1):
90          callnode.args = [callnode.args[0].args[0], ast.Num(2)]
91        elif is_BuiltIn_Call(callnode.args[0], 'Advance', 2):
92          if isinstance(callnode.args[0].args[1], ast.Num):
93            callnode.args = [callnode.args[0].args[0], ast.Num(callnode.args[0].args[1].n + 1)]
94          else:
95            callnode.args = [callnode.args[0].args[0], ast.BinOp(callnode.args[0].args[1], ast.Add(), ast.Num(1))]
96    return callnode
97
98
99#
100#  Translating pablo.match(marker, str)
101#  Incremental character match with lookahead
102#
103CharNameMap = {'[' : 'LBrak', ']' : 'RBrak', '{' : 'LBrace', '}' : 'LBrace', '(' : 'LParen', ')' : 'RParen', \
104              '!' : 'Exclam', '"' : 'DQuote', '#' : 'Hash', '$' : 'Dollar', '%' : 'PerCent', '&': 'RefStart', \
105              "'" : 'SQuote', '*': 'Star', '+' : 'Plus', ',' : 'Comma', '-' : 'Hyphen', '.' : 'Dot', '/' : 'Slash', \
106              ':' : 'Colon', ';' : 'Semicolon', '=' : 'Equals', '?' : 'QMark', '@' : 'AtSign', '\\' : 'BackSlash', \
107              '^' : 'Caret', '_' : 'Underscore', '|' : 'VBar', '~' : 'Tilde', ' ' : 'SP', '\t' : 'HT', '\m' : 'CR', '\n' : 'LF'}
108
109def GetCharName(char):
110        if char >= 'a' and char <= 'z' or char >= 'A' and char <= 'Z': return 'letter_' + char
111        elif char >= '0' and char <= '9': return 'digit_' + char
112        else: return CharNameMap[char]
113
114def MkCharStream(char):
115        return mkQname('lex', GetCharName(char))
116
117def MkLookAheadExpr(v, i):
118        return mkCall(mkQname('pablo', 'LookAhead'), [v, ast.Num(i)])
119
120def CompileMatch(match_var, string_to_match):
121        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
122        for i in range(1, len(string_to_match)):
123                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
124        return expr
125
126class StringMatchCompiler(ast.NodeTransformer):
127  def xfrm(self, t):
128    return self.generic_visit(t)
129  def visit_Call(self, callnode):
130    if is_BuiltIn_Call(callnode,'match', 2):
131        ast.dump(callnode)
132        assert isinstance(callnode.args[0], ast.Str)
133        string_to_match = callnode.args[0].s
134        match_var = callnode.args[1]
135        expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
136        for i in range(1, len(string_to_match)):
137                expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
138        return expr
139    else: return callnode
140
141
142
143#
144#  Converting expressions involving built-ins to compiled form. 
145#
146class TempifyBuiltins(ast.NodeTransformer):
147    def __init__(self, tempVarpfx = "tempvar"):
148      self.tempVarCount = 0
149      self.newTempList = []
150      self.tempVarPrefix = tempVarpfx
151    def genVar(self):
152      newTemp = self.tempVarPrefix + repr(self.tempVarCount)
153      self.newTempList.append(newTemp)
154      self.tempVarCount += 1
155      return newTemp
156    def tempVars(self):
157      return self.newTempList
158    def xfrm(self, t):
159      self.setUpStmts = []
160      self.assigNode = None
161      return self.generic_visit(t)
162    def is_Assign_value(self, node):
163      return self.assigNode != None and self.assigNode.value == node
164    def visit_If(self, ifNode):
165      outerSetUpStmts = self.setUpStmts
166      self.setUpStmts = []
167      self.generic_visit(ifNode)
168      ifSetUpStmts = self.setUpStmts
169      self.setUpStmts = outerSetUpStmts
170      if ifSetUpStmts == []: return ifNode
171      else: return ifSetUpStmts + [ifNode]
172    def visit_While(self, whileNode):
173      outerSetUpStmts = self.setUpStmts
174      self.setUpStmts = []
175      self.generic_visit(whileNode)
176      whileSetUpStmts = self.setUpStmts
177      self.setUpStmts = outerSetUpStmts
178      whileNode.body = whileNode.body + whileSetUpStmts
179      return whileSetUpStmts + [whileNode]
180    def visit_Assign(self, node):
181      self.assigNode = node
182      outerSetUpStmts = self.setUpStmts
183      self.setUpStmts = []
184      self.generic_visit(node)
185      innerSetUpStmts = self.setUpStmts
186      self.setUpStmts = outerSetUpStmts
187      return innerSetUpStmts + [node]
188    def visit_Call(self, callnode):     
189        self.generic_visit(callnode)
190        if CarryCountOfFn(callnode) > 0 and not self.is_Assign_value(callnode):
191        #if not self.is_Assign_value(callnode):
192            tempVar = ast.Name(self.genVar(), ast.Load())
193            self.setUpStmts.append(ast.Assign([tempVar], callnode))
194            return tempVar
195        else: return callnode
196
197
198
199
200
201#
202# Introducing BitBlock logical operations
203#
204class Bitwise_to_SIMD(ast.NodeTransformer):
205  """
206  Make the following substitutions:
207     x & y => simd_and(x, y)
208     x & ~y => simd_andc(x, y)
209     x | y => simd_or(x, y)
210     x ^ y => simd_xor(x, y)
211     ~x    => simd_not(x)
212     0     => simd_const_1(0)
213     -1    => simd_const_1(1)
214     if x: => if bitblock::any(x):
215  while x: => while bitblock::any(x):
216  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1")))
217 
218  pfx = simd_and(bit0, bit1)
219  sfx = simd_and(bit0, simd_not(bit1))
220  >>>
221  """
222  def xfrm(self, t):
223    return self.generic_visit(t)
224  def visit_UnaryOp(self, t):
225    self.generic_visit(t)
226    if isinstance(t.op, ast.Invert):
227      return mkCall('simd_not', [t.operand])
228    else: return t
229  def visit_BinOp(self, t):
230    self.generic_visit(t)
231    if isinstance(t.op, ast.BitOr):
232      return mkCall('simd_or', [t.left, t.right])
233    elif isinstance(t.op, ast.BitAnd):
234      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
235      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
236      else: return mkCall('simd_and', [t.left, t.right])
237    elif isinstance(t.op, ast.BitXor):
238      return mkCall('simd_xor', [t.left, t.right])
239    else: return t
240  def visit_Num(self, numnode):
241    n = numnode.n
242    if n == 0: return mkCall('simd<1>::constant<0>', [])
243    elif n == -1: return mkCall('simd<1>::constant<1>', [])
244    else: return numnode
245  def visit_If(self, ifNode):
246    self.generic_visit(ifNode)
247    ifNode.test = mkCall('bitblock::any', [ifNode.test])
248    return ifNode
249  def visit_While(self, whileNode):
250    self.generic_visit(whileNode)
251    whileNode.test = mkCall('bitblock::any', [whileNode.test])
252    return whileNode
253  def visit_Subscript(self, numnode):
254    return numnode  # no recursive modifications of index expressions
255
256#
257#  Generating BitBlock declarations for Local Variables
258#
259class FunctionVars(ast.NodeVisitor):
260  def __init__(self,node):
261        self.params = []
262        self.stores = []
263        self.generic_visit(node)
264  def visit_Name(self, nm):
265    if isinstance(nm.ctx, ast.Param):
266      self.params.append(nm.id)
267    if isinstance(nm.ctx, ast.Store):
268      if nm.id not in self.stores: self.stores.append(nm.id)
269  def getLocals(self): 
270    return [v for v in self.stores if not v in self.params]
271
272MAX_LINE_LENGTH = 80
273
274def BitBlock_decls_from_vars(varlist):
275  global MAX_LINE_LENGTH
276  decls =  ""
277  if not len(varlist) == 0:
278          decls = "             BitBlock"
279          pending = ""
280          linelgth = 10
281          for v in varlist:
282            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
283              decls += pending + " " + v
284              linelgth += len(pending + v) + 1
285            else:
286              decls += ";\n             BitBlock " + v
287              linelgth = 11 + len(v)
288            pending = ","
289          decls += ";"
290  return decls
291
292def BitBlock_decls_of_fn(fndef):
293  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
294
295def BitBlock_header_of_fn(fndef):
296  Ccode = "static inline void " + fndef.name + "("
297  pending = ""
298  for arg in fndef.args.args:
299    if isinstance(arg, ast.Name):
300      Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
301      pending = ", "
302  if CarryCounter().count(fndef) > 0:
303    Ccode += pending + " CarryQtype & carryQ"
304  Ccode += ")"
305  return Ccode
306
307
308
309#
310#  Stream Initialization Statement Extraction
311#
312#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
313class StreamInitializations(ast.NodeTransformer):
314  def xfrm(self, node):
315    self.stream_stmts = []
316    self.loop_post_inits = []
317    self.generic_visit(node)
318    return Cgen.py2C().gen(self.stream_stmts)
319  def visit_Assign(self, node):
320    if isinstance(node.value, ast.Num):
321      if node.value.n == 0: return node
322      elif node.value.n == -1: return node
323      else: 
324        stream_init = copy.deepcopy(node)
325        stream_init.value = mkCall('sisd_from_int', [node.value])
326        loop_init = copy.deepcopy(node)
327        loop_init.value.n = 0
328        self.stream_stmts.append(stream_init)
329        self.loop_post_inits.append(loop_init)
330        return None
331    else: return node
332  def visit_FunctionDef(self, node):
333    self.generic_visit(node)
334    node.body = node.body + self.loop_post_inits
335    return node
336
337
338
339#
340# Carry Introduction Transformation
341#
342class CarryCounter(ast.NodeVisitor):
343  def visit_Call(self, callnode):
344    self.generic_visit(callnode)
345    self.carry_count += CarryCountOfFn(callnode)
346  def visit_BinOp(self, exprnode):
347    self.generic_visit(exprnode)
348    if isinstance(exprnode.op, ast.Sub):
349      self.carry_count += 1
350    if isinstance(exprnode.op, ast.Add):
351      self.carry_count += 1
352  def count(self, nodeToVisit):
353    self.carry_count = 0
354    self.generic_visit(nodeToVisit)
355    return self.carry_count
356
357#
358# Carry Initialization:  Aug. 4, 2012
359# - Carry variables are initialized to 0 by default
360# - However, the scan_to_first routine should ideally use
361#   initialization with 1.
362
363#
364class CarryInitToOneList(ast.NodeVisitor):
365  def visit_Call(self, callnode):
366    self.generic_visit(callnode)
367    if is_BuiltIn_Call(callnode,'Advance', 1) or is_BuiltIn_Call(callnode,'ScanThru', 2) or is_BuiltIn_Call(callnode,'ScanTo', 2) or is_BuiltIn_Call(callnode,'AdvanceThenScanThru', 2) or is_BuiltIn_Call(callnode,'AdvanceThenScanTo', 2) or is_BuiltIn_Call(callnode,'SpanUpTo', 2) or  is_BuiltIn_Call(callnode,'InclusiveSpan', 2) or is_BuiltIn_Call(callnode,'ExclusiveSpan', 2):       
368      self.carry_count += 1
369    elif is_BuiltIn_Call(callnode,'ScanToFirst', 1):
370      self.init_to_one_list.append(self.carry_count)
371      self.carry_count += 1
372  def visit_BinOp(self, exprnode):
373    self.generic_visit(exprnode)
374    if isinstance(exprnode.op, ast.Sub):
375      self.carry_count += 1
376    if isinstance(exprnode.op, ast.Add):
377      self.carry_count += 1
378  def count(self, nodeToVisit):
379    self.carry_count = 0
380    self.init_to_one_list = []
381    self.generic_visit(nodeToVisit)
382    return self.init_to_one_list
383
384class adv_nCounter(ast.NodeVisitor):
385  def visit_Call(self, callnode):
386    self.generic_visit(callnode)
387    if is_BuiltIn_Call(callnode,'Advance32', 1):       
388      self.adv_n_count += 1
389    if is_BuiltIn_Call(callnode,'Advance', 2):         
390      self.adv_n_count += 1
391  def count(self, nodeToVisit):
392    self.adv_n_count = 0
393    self.generic_visit(nodeToVisit)
394    return self.adv_n_count
395
396class CarryIntro(ast.NodeTransformer):
397  def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
398    self.carryvar = ast.Name(carryvar, ast.Load())
399    self.carryin = carryin
400    self.carryout = carryout
401  def xfrm_fndef(self, fndef):
402    self.current_carry = 0
403    self.current_adv_n = 0
404    carry_count = CarryCounter().count(fndef)
405    if carry_count == 0: return fndef
406    self.generic_visit(fndef)
407#   
408#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
409    return fndef
410  def generic_xfrm(self, node):
411    self.current_carry = 0
412    self.current_adv_n = 0
413    self.last_stmt = None
414    self.last_stmt_carries = 0
415    carry_count = CarryCounter().count(node)
416    adv_n_count = adv_nCounter().count(node)
417    if carry_count == 0 and adv_n_count == 0: return node
418    self.generic_visit(node)
419    return node
420  def visit_Call(self, callnode):
421    self.generic_visit(callnode)
422    #CARRYSET
423    #carry_args = [ast.Num(self.current_carry)]
424    #adv_n_args = [ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())]
425    #adv_n_pending = ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())
426    if self.carryin == "_ci":
427        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
428        adv_n_args = [mkCall(self.carryvar.id + "." + 'get_pending64', [ast.Num(self.current_adv_n)]), ast.Num(self.current_adv_n)]
429    else: 
430        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
431        adv_n_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_adv_n)]
432
433    if is_BuiltIn_Call(callnode, 'Advance', 2):         
434      #CARRYSET
435      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<%i>" % callnode.args[1].n
436      c = mkCall(rtn, [callnode.args[0]] + adv_n_args)
437      self.current_adv_n += 1
438      return c
439    if is_BuiltIn_Call(callnode, 'Advance', 1):         
440      #CARRYSET
441      rtn = self.carryvar.id + "." + "BitBlock_advance_ci_co"
442      c = mkCall(rtn, callnode.args + carry_args)
443      self.current_carry += 1
444      return c
445    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
446      #CARRYSET
447      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<32>"
448      c = mkCall(rtn, callnode.args + adv_n_args)
449      self.current_adv_n += 1
450      return c
451    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
452      #CARRYSET
453      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
454      c = mkCall(rtn, callnode.args + carry_args)
455      self.current_carry += 1
456      return c
457    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanThru', 2):
458      #CARRYSET
459      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
460      c = mkCall(rtn, callnode.args + carry_args)
461      self.current_carry += 1
462      return c
463    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanTo', 2):
464      #CARRYSET
465      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru" 
466      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
467      else: scanclass = mkCall('simd_not', [callnode.args[1]])
468      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
469      self.current_carry += 1
470      return c
471    elif is_BuiltIn_Call(callnode, 'SpanUpTo', 2):
472      #CARRYSET
473      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
474      c = mkCall(rtn, callnode.args + carry_args)
475      self.current_carry += 1
476      return c
477    elif is_BuiltIn_Call(callnode, 'InclusiveSpan', 2):
478      #CARRYSET
479#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
480#      c = mkCall('simd_or', [mkCall(rtn, callnode.args + carry_args), callnode.args[1]])
481      rtn = self.carryvar.id + "." + "BitBlock_inclusive_span"
482      c = mkCall(rtn, callnode.args + carry_args)
483      self.current_carry += 1
484      return c
485    elif is_BuiltIn_Call(callnode, 'ExclusiveSpan', 2):
486      #CARRYSET
487#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
488#      c = mkCall('simd_andc', [mkCall(rtn, callnode.args + carry_args), callnode.args[0]])
489      rtn = self.carryvar.id + "." + "BitBlock_exclusive_span"
490      c = mkCall(rtn, callnode.args + carry_args)
491      self.current_carry += 1
492      return c
493    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
494      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
495      # in having a separate BitBlock_scanto routine.
496      #CARRYSET
497      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co" 
498      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
499      else: scanclass = mkCall('simd_not', [callnode.args[1]])
500      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
501      self.current_carry += 1
502      return c
503    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
504      #CARRYSET
505      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
506      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
507      c = mkCall(rtn, callnode.args + carry_args)
508      self.current_carry += 1
509      return c
510    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
511      if self.carryout != "": 
512        # Non final block: atEOF(x) = 0.
513        return mkCall('simd<1>::constant<0>', [])
514      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
515    elif is_BuiltIn_Call(callnode, 'inFile', 1):
516      if self.carryout != "": 
517        # Non final block: inFile(x) = x.
518        return callnode.args[0]
519      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
520    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
521      rtn = "StreamScan"           
522      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
523                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
524                                           ast.Name(callnode.args[1].id, ast.Load())])
525      return c
526    else:
527      #dump_Call(callnode)
528      return callnode
529  def visit_BinOp(self, exprnode):
530    self.generic_visit(exprnode)
531    carry_args = [ast.Num(self.current_carry)]
532    if self.carryin == "_ci":
533        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
534    else: 
535        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
536    if isinstance(exprnode.op, ast.Sub):
537      #CARRYSET
538      rtn = self.carryvar.id + "." + "BitBlock_sub_ci_co"
539      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
540      self.current_carry += 1
541      return c
542    elif isinstance(exprnode.op, ast.Add):
543      #CARRYSET
544      rtn = self.carryvar.id + "." + "BitBlock_add_ci_co"
545      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
546      self.current_carry += 1
547      return c
548    else: return exprnode
549  def visit_Assign(self, assignode):
550      self.last_stmt_carries = CarryCounter().count(assignode)
551      self.generic_visit(assignode)
552      self.last_stmt = assignode
553      return assignode
554
555  def visit_If(self, ifNode):
556    carry_base = self.current_carry
557    carries = CarryCounter().count(ifNode)
558    assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
559    self.generic_visit(ifNode)
560    if carries == 0 or self.carryin == "": 
561      self.last_stmt = ifNode
562      return ifNode
563    #CARRYSET
564    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
565    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
566    new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
567    newIf = ast.If(new_test, ifNode.body, new_else_part)
568    self.last_stmt = newIf
569    self.last_stmt_carries = carries
570    return newIf
571
572
573  def is_while_special_case(self, whileNode):
574    #
575    # Special case optimization for pattern:
576    #   m=pablo.scan...()
577    #   while m:
578    #      S
579    #      m=pablo.scan...()
580    #
581    # Determine the original test expression, now encloded in bitblock::any()
582    original_test_expr = whileNode.test.args[0]
583    if not isinstance(original_test_expr, ast.Name): return False
584    test_var = original_test_expr.id
585    if not isinstance(self.last_stmt, ast.Assign): return False
586    if not isinstance(whileNode.body[-1], ast.Assign): return False
587    if len(self.last_stmt.targets) != 1: return False
588    if len(whileNode.body[-1].targets) != 1: return False
589    if not isinstance(self.last_stmt.targets[0], ast.Name): return False
590    if not isinstance(whileNode.body[-1].targets[0], ast.Name): return False
591    if self.last_stmt.targets[0].id != test_var: return False
592    if whileNode.body[-1].targets[0].id != test_var: return False
593    if self.last_stmt_carries != 1: return False
594    if CarryCounter().count(whileNode.body[-1]) != 1: return False
595    return True
596
597  def visit_While(self, whileNode):
598    # Determine the original test expression, now encloded in bitblock::any()
599    original_test_expr = whileNode.test.args[0]
600    if self.carryout == '':
601      whileNode.test.args[0] = mkCall("simd_and", [original_test_expr, ast.Name('EOF_mask', ast.Load())])
602    carry_base = self.current_carry
603    assert adv_nCounter().count(whileNode) == 0, "Advance(x,n) within while: illegal\n"
604    carries = CarryCounter().count(whileNode)
605#   Special Case Recognition
606    is_special = self.is_while_special_case(whileNode)
607    #CARRYSET
608    if carries == 0: return whileNode
609    local_carryvar = 'sub' + self.carryvar.id
610    inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
611    self.generic_visit(whileNode)
612    local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
613    inner_while.body.insert(0, local_carry_decl)
614    final_combine = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine', ast.Load()), [ast.Attribute(ast.Name(local_carryvar, ast.Load()), 'cq', ast.Load()),ast.Num(carry_base), ast.Num(carries)])
615    inner_while.body.append(final_combine)
616    #CARRYSET
617
618#   Special Case Optimization
619    if is_special:
620      # We combine the final carry into the one preceeding the loop.
621      combine1 = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine1', ast.Load()), [ast.Num(carry_base-1), ast.Num(carry_base+carries-1)])
622      while_body_extend = [inner_while, combine1]
623      # The carry test can skip the final case.
624      carry_test_arglist = [ast.Num(carry_base), ast.Num(carries-1)]
625    else: 
626      carry_test_arglist = [ast.Num(carry_base), ast.Num(carries)]
627      while_body_extend = [inner_while]
628
629    if self.carryin == '': new_test = whileNode.test
630    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_test_arglist)])
631    else_part = [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), [ast.Num(carry_base), ast.Num(carries)])]   
632    newIf = ast.If(new_test, whileNode.body + while_body_extend, else_part)
633    self.last_stmt = newIf
634    self.last_stmt_carries = carries
635    return newIf
636
637class StreamStructGen(ast.NodeVisitor):
638  """
639  Given a BitStreamSet subclass, generate the equivalent C struct.
640  >>> obj = ast.parse(r'''
641  ... class S1(BitStreamSet):
642  ...   a1 = 0
643  ...   a2 = 0
644  ...   a3 = 0
645  ...
646  ... class S2(BitStreamSet):
647  ...   x1 = 0
648  ...   x2 = 0
649  ... ''')
650  >>> print StreamStructGen().gen(obj)
651  struct S1 {
652    BitBlock a1;
653    BitBlock a2;
654    BitBlock a3;
655  }    self.current_adv_n = 0
656
657 
658  struct S2 {
659    BitBlock x1;
660    BitBlock x2;
661  }
662  """
663  def __init__(self, asType=False):
664    self.asType = asType
665  def gen(self, tree):
666    self.Ccode=""
667    self.generic_visit(tree)
668    return self.Ccode
669  def gen_struct_types(self, tree):
670    self.asType = True
671    self.Ccode=""
672    self.generic_visit(tree)
673    return self.Ccode
674  def gen_struct_vars(self, tree):
675    self.asType = False
676    self.Ccode=""
677    self.generic_visit(tree)
678    return self.Ccode
679  def visit_ClassDef(self, node):
680    class_name = node.name[0].upper() + node.name[1:]
681    instance_name = node.name[0].lower() + node.name[1:]
682    self.Ccode += "  struct " + class_name
683    if self.asType:
684            self.Ccode += " {\n"
685            for stmt in node.body:
686              if isinstance(stmt, ast.Assign):
687                for v in stmt.targets:
688                  if isinstance(v, ast.Name):
689                    self.Ccode += "  BitBlock " + v.id + ";\n"
690            self.Ccode += "}" 
691    else: self.Ccode += " " + instance_name
692    self.Ccode += ";\n\n"
693 
694class StreamFunctionDecl(ast.NodeVisitor):
695  def __init__(self):
696    pass
697  def gen(self, tree):
698    self.Ccode=""
699    self.generic_visit(tree)
700    return self.Ccode
701  def visit_FunctionDef(self, node):
702    self.Ccode += "static inline void " + node.name + "("
703    pending = ""
704    for arg in node.args.args:
705      if isinstance(arg, ast.Name):
706        self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
707        pending = ", "
708    self.Ccode += ");\n"
709
710class AssertCompiler(ast.NodeTransformer):
711  def __init__(self):
712    self.assert_routine = ast.parse(error_routine).body[0].value
713  def xfrm(self, t):
714    return self.generic_visit(t)
715  def visit_Expr(self, node):
716    if isinstance(node.value, ast.Call):
717        if is_BuiltIn_Call(node.value, "assert_0", 2):
718                err_stream = node.value.args[0]
719                err_code = node.value.args[1]
720                return ast.If(err_stream,
721                              [ast.Expr(mkCall(self.assert_routine, [err_code, err_stream]))],
722                              [])
723        else: return node
724    else: return node
725       
726#
727# Adding Debugging Statements
728#
729class Add_SIMD_Register_Dump(ast.NodeTransformer):
730  def xfrm(self, t):
731    return self.generic_visit(t)
732  def visit_Assign(self, t):
733    self.generic_visit(t)
734    v = t.targets[0]
735    dump_stmt = mkCallStmt(' print_register<BitBlock>', [ast.Str(Cgen.py2C().gen(v)), v])
736    return [t, dump_stmt]
737
738#
739# Adding ASSERT_BITBLOCK_ALIGN macros
740#
741class Add_Assert_BitBlock_Align(ast.NodeTransformer):
742    def xfrm(self, t):
743      return self.generic_visit(t)
744    def visit_Assign(self, t):
745      self.generic_visit(t)
746      v = t.targets[0]
747      dump_stmt = mkCallStmt(' ASSERT_BITBLOCK_ALIGN', [v])
748      return [t, dump_stmt]
749
750class StreamFunctionCarryCounter(ast.NodeVisitor):
751  def __init__(self):
752        self.carry_count = {}
753       
754  def count(self, node):
755        self.generic_visit(node)
756        return self.carry_count
757                                   
758  def visit_FunctionDef(self, node):   
759        type_name = node.name[0].upper() + node.name[1:]                       
760        self.carry_count[type_name] = CarryCounter().count(node)
761     
762class StreamFunctionCallXlator(ast.NodeTransformer):
763  def __init__(self, xlate_type="normal"):
764        self.stream_function_type_names = []
765        self.xlate_type = xlate_type
766
767  def xfrm(self, node, stream_function_type_names, C_syntax):
768        self.stream_function_type_names = stream_function_type_names
769        self.C_syntax = C_syntax
770        self.generic_visit(node)
771       
772  def visit_Call(self, node):   
773        self.generic_visit(node)
774
775        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):# and node.func.id in self.stream_function_type_names:
776             name = lower1(node.func.id)
777             node.func.id = name + ("_" if self.C_syntax else ".") + ("do_final_block" if self.xlate_type == "final" else "do_block")
778             if self.C_syntax:
779                     node.args = [ast.Name(lower1(name), ast.Load())] + node.args
780             if self.xlate_type == "final":
781                   node.args = node.args + [ast.Name("EOF_mask", ast.Load())]
782                     
783        return node     
784               
785class StreamFunctionVisitor(ast.NodeVisitor):
786        def __init__(self,node):
787                self.stream_function_node = {}
788                self.generic_visit(node)
789                                                             
790        def visit_FunctionDef(self, node):                     
791                key = node.name[0].upper() + node.name[1:]
792                self.stream_function_node[key] = node
793
794
795               
796class StreamFunction():
797        def __init__(self):
798                self.carry_count = 0 
799                self.init_to_one_list = [] 
800                self.adv_n_count = 0 
801                self.type_name = ""
802                self.instance_name = "" 
803                self.parameters = ""
804                self.declarations = "" 
805                self.initializations = "" 
806       
807        def dump(self):
808                print "%s" % (self.type_name)
809                print "%s=%s" % ("Carry Count", str(self.carry_count))
810                print "%s=[%s]" % ("Init to One List" , ','.join(map(str,self.init_to_one_list)))
811                print "%s=%s" % ("Adv n Count", str(self.adv_n_count)) 
812        #print "Instance Name:"     #+ self.instance_name = ""
813        #print "Parameters:"        #+ self.parameters = ""
814        #print "Declarations:"           #+ self.declarations = ""
815        #print "Initializations:"   # + self.initializations = ""
816   
817#
818# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
819# TODO Implement 'pretty print' indentation.   Low priority.
820# TODO Migrate Emiter() class to Emitter module.  Medium priority.
821
822def lower1(name):
823    return name[0].lower() + name[1:]
824def upper1(name):
825    return name[0].upper() + name[1:]
826
827def escape_newlines(str):
828  return str.replace('\n', '\\\n')
829
830class Emitter():
831        def __init__(self, use_C_syntax):
832                self.use_C_syntax = use_C_syntax
833
834        def definition(self, stream_function, icount=0):
835               
836                constructor = ""
837                carry_declaration = ""
838                self.type_name = stream_function.type_name
839               
840                if stream_function.carry_count > 0 or stream_function.adv_n_count > 0:
841                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.init_to_one_list, stream_function.adv_n_count)
842                        carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv_n_count)
843
844                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), 
845                                                stream_function.declarations, 
846                                                stream_function.initializations, 
847                                                stream_function.statements)             
848
849                do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), 
850                                                stream_function.declarations, 
851                                                stream_function.initializations, 
852                                                stream_function.final_block_statements)                 
853
854                do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), 
855                                                self.do_segment_args(stream_function.parameters))       
856
857                if self.use_C_syntax:
858                        return self.indent(icount) + "struct " + stream_function.type_name + " {" \
859                               + "\n" + self.indent(icount) + carry_declaration \
860                               + "\n" + self.indent(icount) + "};\n" \
861                               + "\n" + self.indent(icount) + do_block_function \
862                               + "\n" + self.indent(icount) + do_final_block_function \
863                               + "\n" + self.indent(icount) + do_segment_function + "\n\n"
864                               
865                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
866                + "\n" + self.indent(icount) + constructor \
867                + "\n" + self.indent(icount) + do_block_function \
868                + "\n" + self.indent(icount) + do_final_block_function \
869                + "\n" + self.indent(icount) + do_segment_function \
870                + "\n" + self.indent(icount) + carry_declaration \
871                + "\n" + self.indent(icount) + "};\n\n"
872
873        def constructor(self, type_name, carry_count, init_to_one_list, adv_n_count, icount=0):
874                one_inits = ""
875                for v in init_to_one_list:
876                        one_inits += "  carryQ.cq[%s] = carryQ.carry_flip(carryQ.cq[%s]);\n" % (v, v)
877                adv_n_decl = ""
878                #for i in range(adv_n_count): adv_n_decl += self.indent(icount+2) + "pending64[%s] = simd<1>::constant<0>();\n" % i     
879                return self.indent(icount) + "%s() { ""\n" % (type_name) + adv_n_decl + self.carry_init(carry_count) + one_inits + " }" 
880                       
881        def do_block(self, parameters, declarations, initializations, statements, icount=0):
882                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
883                if self.use_C_syntax:
884                        return "#define " + pfx + "do_block(" + parameters + ")\\\n do {" \
885                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
886                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
887                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
888                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
889                return self.indent(icount) + do_block_inline_decorator + "void " + pfx + "do_block(" + parameters + ") {" \
890                + "\n" + self.indent(icount) + declarations \
891                + "\n" + self.indent(icount) + initializations \
892                + "\n" + self.indent(icount) + statements \
893                + "\n" + self.indent(icount + 2) + "}" 
894
895
896
897
898        def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
899                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
900                if self.use_C_syntax:
901                        return "#define " + pfx + "do_final_block(" + parameters + ")\\\n do {" \
902                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
903                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
904                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
905                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
906                return self.indent(icount) + do_final_block_inline_decorator + "void " + pfx + "do_final_block(" + parameters + ") {" \
907                + "\n" + self.indent(icount) + declarations \
908                + "\n" + self.indent(icount) + initializations \
909                + "\n" + self.indent(icount) + statements \
910                + "\n" + self.indent(icount + 2) + "}" 
911
912        def do_segment(self, parameters, do_block_call_args, icount=0):
913                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
914                if self.use_C_syntax:
915                        return "#define " + pfx + "do_segment(" + parameters + ")\\\n do {" \
916                        + "\\\n" + self.indent(icount) + "  int i;" \
917                        + "\\\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
918                        + "\\\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
919                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
920                return self.indent(icount) + "void " + pfx + "do_segment(" + parameters + ") {" \
921                + "\n" + self.indent(icount) + "  int i;" \
922                + "\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
923                + "\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
924                + "\n" + self.indent(icount + 2) + "}" 
925
926        def declaration(self, type_name, instance_name, icount=0):
927                if self.use_C_syntax: return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
928                return self.indent(icount) + type_name + " " + instance_name + ";\n"
929               
930        def carry_init(self, carry_count, icount=0):   
931                #CARRY SET
932                return ""
933               
934        def carry_declare(self, carry_variable, carry_count, adv_n_count=0, icount=0):
935                adv_n_decl = ""
936                #if adv_n_count > 0:
937                #       adv_n_decl = "\n" + self.indent(icount) + "BitBlock pending64[%s];" % adv_n_count
938                #CARRY SET
939                return self.indent(icount) + "CarryArray<%i, %i> %s;" % (carry_count, adv_n_count, carry_variable)
940
941        def carry_test(self, carry_variable, carry_count, icount=0):
942                #CARRY SET
943                return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)         
944               
945        def indent(self, icount):
946                s = ""
947                for i in range(0,icount): s += " "
948                return s       
949               
950        def do_block_parameters(self, parameters):
951                if self.use_C_syntax:
952                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters])
953                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
954                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])
955               
956        def do_final_block_parameters(self, parameters):
957                if self.use_C_syntax:
958                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
959                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters]+ ["EOF_mask"])
960                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
961               
962        def do_segment_parameters(self, parameters):
963                if self.use_C_syntax:
964                        #return ", ".join([self.type_name + " * " + + self.instance_name] + [upper1(p) + " " + lower1(p) + "[]" for p in parameters])
965                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["segment_blocks"])
966                else: return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters] + ["int segment_blocks"])
967
968        def do_segment_args(self, parameters):
969                if self.use_C_syntax:
970                        return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
971                else: return ", ".join([lower1(p) + "[i]" for p in parameters])
972
973def main(infilename, outfile = sys.stdout):
974  t = ast.parse(file(infilename).read())
975  outfile.write(StreamStructGen(True).gen(t))
976  outfile.write(FunctionXlator().xlat(t))
977
978#
979#  Routines for compatibility with the old compiler/template.
980#  Quick and dirty hacks for now - Dec. 2010.
981#
982
983class MainLoopTransformer:
984  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, dump_func_data=False, main_node_id='Main'):
985       
986    self.main_module = main_module
987    self.main_node_id = main_node_id
988    self.use_C_syntax = C_syntax
989    self.add_dump_stmts = add_dump_stmts
990    self.add_assert_bitblock_align = add_assert_bitblock_align
991    self.dump_func_data = dump_func_data
992   
993        # Gather and partition function definition nodes.
994    stream_function_visitor = StreamFunctionVisitor(self.main_module)
995    self.stream_function_node = stream_function_visitor.stream_function_node
996    for key, node in self.stream_function_node.iteritems():
997                AdvanceCombiner().xfrm(node)
998    self.main_node = self.stream_function_node[main_node_id]
999    self.main_carry_count = CarryCounter().count(self.main_node)
1000    self.main_adv_n_count = adv_nCounter().count(self.main_node)
1001    assert self.main_adv_n_count == 0, "Advance32() in main not supported.\n"
1002    del self.stream_function_node[self.main_node_id]
1003   
1004    self.stream_functions = {}
1005    for key, node in self.stream_function_node.iteritems():
1006                stream_function = StreamFunction()
1007                stream_function.carry_count = CarryCounter().count(node)
1008                stream_function.init_to_one_list = CarryInitToOneList().count(node)
1009                stream_function.adv_n_count = adv_nCounter().count(node)
1010                stream_function.type_name = node.name[0].upper() + node.name[1:]
1011                stream_function.instance_name = node.name[0].lower() + node.name[1:]
1012                stream_function.parameters = FunctionVars(node).params
1013                stream_function.declarations = BitBlock_decls_of_fn(node)
1014                stream_function.initializations = StreamInitializations().xfrm(node) 
1015               
1016                t = TempifyBuiltins()
1017                t.xfrm(node)
1018                stream_function.declarations += "\n" + BitBlock_decls_from_vars(t.tempVars())
1019               
1020                StringMatchCompiler().xfrm(node)
1021                AssertCompiler().xfrm(node)
1022                AugAssignRemoval().xfrm(node)
1023                Bitwise_to_SIMD().xfrm(node)
1024                final_block_node = copy.deepcopy(node)
1025                if self.use_C_syntax:
1026                        carryQname = stream_function.instance_name + ".carryQ"
1027                else: carryQname = "carryQ"
1028                CarryIntro(carryQname).xfrm_fndef(node)
1029                CarryIntro(carryQname, "_ci", "").xfrm_fndef(final_block_node)
1030
1031                if self.add_dump_stmts: 
1032                        Add_SIMD_Register_Dump().xfrm(node)
1033                        Add_SIMD_Register_Dump().xfrm(final_block_node)
1034
1035                if self.add_assert_bitblock_align:
1036                        Add_Assert_BitBlock_Align().xfrm(node)
1037                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
1038
1039                if stream_function.carry_count > 0:
1040                        node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
1041               
1042                stream_function.statements = Cgen.py2C(4).gen(node.body)
1043                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
1044                self.stream_functions[stream_function.type_name] = stream_function
1045               
1046    if self.dump_func_data:     
1047        for key, value in self.stream_functions.iteritems():
1048                        value.dump()
1049        sys.exit()
1050               
1051    self.emitter = Emitter(self.use_C_syntax)
1052   
1053  def any_carry_expr(self):
1054       
1055        carry_test = []
1056       
1057        if self.main_carry_count > 0:
1058                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
1059                        carry_test.append(" || ")
1060
1061        for key in self.stream_functions.keys():               
1062                if self.stream_functions[key].carry_count > 0:
1063                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
1064                        carry_test.append(" || ")
1065
1066        if len(carry_test) > 0:
1067                carry_test.pop()
1068                return "".join(carry_test)
1069        return "1"
1070
1071  def gen_globals(self):
1072    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
1073    for key in self.stream_functions.keys():
1074                self.Cglobals += Emitter(self.use_C_syntax).definition(self.stream_functions[key],2)       
1075                       
1076  def gen_declarations(self): 
1077    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
1078    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
1079    if self.main_carry_count > 0: 
1080        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
1081               
1082  def gen_initializations(self):
1083    self.Cinits = ""
1084    if self.main_carry_count > 0: 
1085        self.Cinits += self.emitter.carry_init(self.main_carry_count)
1086    self.Cinits += StreamInitializations().xfrm(self.main_module)
1087    if self.use_C_syntax:
1088                for key in self.stream_functions.keys():
1089                        if self.stream_functions[key].carry_count == 0: continue
1090                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
1091                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
1092    else:
1093                for key in self.stream_functions.keys():
1094                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
1095
1096                       
1097  def xfrm_block_stmts(self):
1098    StringMatchCompiler().xfrm(self.main_node)
1099    AugAssignRemoval().xfrm(self.main_node)
1100    Bitwise_to_SIMD().xfrm(self.main_node)
1101    Bitwise_to_SIMD().xfrm(self.main_node)
1102    AssertCompiler().xfrm(self.main_node)
1103    final_block_main = copy.deepcopy(self.main_node)
1104    CarryIntro().xfrm_fndef(self.main_node)
1105    CarryIntro('carryQ', '_ci', '').xfrm_fndef(final_block_main)
1106    if self.add_dump_stmts: 
1107        Add_SIMD_Register_Dump().xfrm(self.main_node)
1108        Add_SIMD_Register_Dump().xfrm(final_block_main)
1109               
1110    if self.add_assert_bitblock_align:
1111        print "add_assert_bitblock_align"
1112        Add_Assert_BitBlock_Align().xfrm(self.main_node)
1113        Add_Assert_BitBlock_Align().xfrm(final_block_main)
1114
1115    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
1116    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
1117   
1118    if self.main_carry_count > 0:
1119                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
1120                self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
1121   
1122   
1123       
1124    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
1125    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
1126   
1127if __name__ == "__main__":
1128                import doctest
1129                doctest.testmod()
Note: See TracBrowser for help on using the repository browser.