source: proto/Compiler/pablo.py @ 2691

Last change on this file since 2691 was 2691, checked in by cameron, 7 years ago

Record whether blocks are if/while in carryInfoSet

File size: 56.2 KB
RevLine 
#
# Pablo.py - parallel bitstream to bitblock
#  2nd generation compiler
#
# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
# All rights reserved.
#
8import ast, copy, sys
9import Cgen
10
# Module-level configuration.
# name_substitution_map: consulted by mkCall() to rename string function
# names before a call node is built.
name_substitution_map = {}
# The remaining settings are not read anywhere in this part of the file;
# presumably they are consumed by the code generator -- TODO confirm.
do_block_inline_decorator = 'IDISA_INLINE '
do_final_block_inline_decorator = ''
error_routine = 'raise_assert'
experimentalMode = False
[1271]16
def isCarryGenerating(builtin_fn):
    """Return True if builtin_fn names one of the scan/span builtins below.

    'Advance' is deliberately not in this list; it is recognised by
    isAdvance() instead.
    """
    return builtin_fn in ('ScanThru', 'ScanTo', 'AdvanceThenScanThru',
                          'AdvanceThenScanTo', 'SpanUpTo', 'InclusiveSpan',
                          'ExclusiveSpan', 'ScanToFirst')
def usesCarryInit1(builtin_fn):
    """Return True if builtin_fn is 'ScanToFirst', the only builtin listed here."""
    return builtin_fn in ('ScanToFirst',)
def isAdvance(builtin_fn):
    """Return True if builtin_fn is the 'Advance' builtin."""
    return builtin_fn in ('Advance',)
23
[2636]24
def CheckForBuiltin(fncall, builtin_fnmod_noprefix='pablo'):
    """Return the builtin name invoked by fncall, or None.

    fncall -- any AST node; anything that is not an ast.Call yields None.
    builtin_fnmod_noprefix -- module qualifier accepted in the dotted form
        (``pablo.Advance(x)``); the bare form (``Advance(x)``) is also accepted.

    A name is returned only when isCarryGenerating() or isAdvance() accepts it.
    """
    if not isinstance(fncall, ast.Call):
        return None
    func = fncall.func
    if isinstance(func, ast.Name):
        fn_name = func.id
    elif (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name)
          and func.value.id == builtin_fnmod_noprefix):
        fn_name = func.attr
    else:
        # Calls through any other qualifier or expression are not builtins.
        return None
    if isCarryGenerating(fn_name) or isAdvance(fn_name):
        return fn_name
    return None
34
def CarryCountOfFn(fncall, builtin_fnmod_noprefix='pablo'):
    """Return how many carries the call node fncall generates (0 or 1).

    Non-call nodes, calls through a qualifier other than
    builtin_fnmod_noprefix, and unrecognised names all count as 0.
    """
    if not isinstance(fncall, ast.Call):
        return 0
    func = fncall.func
    if isinstance(func, ast.Name):
        fn_name = func.id
    elif (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name)
          and func.value.id == builtin_fnmod_noprefix):
        fn_name = func.attr
    else:
        return 0
    if isAdvance(fn_name):
        # Only the one-argument Advance counts as a carry.  Advance(m, n)
        # is counted as 0 here (it could instead be counted as n carries).
        return 1 if len(fncall.args) == 1 else 0
    return 1 if isCarryGenerating(fn_name) else 0
47   
def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
    """Return True if fncall calls builtin_fnname with builtin_arg_cnt positional args.

    Both the bare form (``Advance(x)``) and the qualified form
    (``pablo.Advance(x)``, qualifier given by builtin_fnmod_noprefix) match.

    Any other shape of callee (``a.b.c(x)``, ``f()(x)``) or a node that is
    not an ast.Call now yields False.  Previously the former raised
    UnboundLocalError, because ``iscall`` was never assigned, and the latter
    raised AttributeError.
    """
    if not isinstance(fncall, ast.Call):
        return False
    func = fncall.func
    if isinstance(func, ast.Name):
        iscall = func.id == builtin_fnname
    elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        iscall = func.value.id == builtin_fnmod_noprefix and func.attr == builtin_fnname
    else:
        iscall = False
    return iscall and len(fncall.args) == builtin_arg_cnt
[753]53
def dump_Call(fncall):
    """Debug aid: print the callee name and positional-argument count of fncall.

    The callee line is printed only for bare names and ``name.attr`` callees.
    Each message already ends in a newline, so print emits a blank line after it.
    """
    func = fncall.func
    # print(<single string>) produces identical output under Python 2 and 3.
    if isinstance(func, ast.Name):
        print("fn_name = %s\n" % func.id)
    elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        print("fn_name = %s.%s\n" % (func.value.id, func.attr))
    print("len(fncall.args) = %s\n" % len(fncall.args))
59
def is_simd_not(e):
    """Return True if e is a call whose callee is the bare name ``simd_not``."""
    if not isinstance(e, ast.Call):
        return False
    return isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
62
def mkQname(obj, field):
    """Build the AST for the qualified name ``obj.field`` in load context.

    obj and field are both plain strings.
    """
    owner = ast.Name(obj, ast.Load())
    return ast.Attribute(owner, field, ast.Load())
65
def mkCall(fn_name, args):
    """Build an ast.Call expression node.

    fn_name -- either a string or an already-built AST callee.  A string is
        first renamed through name_substitution_map (when it has an entry)
        and then wrapped in an ast.Name.
    args    -- list of positional-argument AST nodes.
    """
    if isinstance(fn_name, str):
        if fn_name in name_substitution_map:
            fn_name = name_substitution_map[fn_name]
        fn_name = ast.Name(fn_name, ast.Load())
    try:
        # Python 2 layout: func, args, keywords, starargs, kwargs.
        return ast.Call(fn_name, args, [], None, None)
    except TypeError:
        # Interpreters whose ast.Call only has func/args/keywords reject
        # the five-positional form.
        return ast.Call(fn_name, args, [])
71
def mkCallStmt(fn_name, args):
    """Build an expression statement (ast.Expr) that calls fn_name(*args).

    fn_name may be a string or an AST callee.  Unlike mkCall, a string name
    is NOT passed through name_substitution_map.
    """
    if isinstance(fn_name, str):
        fn_name = ast.Name(fn_name, ast.Load())
    try:
        # Python 2 layout: func, args, keywords, starargs, kwargs.
        call = ast.Call(fn_name, args, [], None, None)
    except TypeError:
        # Interpreters whose ast.Call only has func/args/keywords reject
        # the five-positional form.
        call = ast.Call(fn_name, args, [])
    return ast.Expr(call)
[2636]75 
76 
#
# Carry Info Set
#
class CarryInfoSetVisitor(ast.NodeVisitor):
    """Number the carry operations and if/while blocks of a stream function.

    Constructing the visitor walks streamFunctionNode immediately.  Afterwards:

      operation_count -- number of carry-generating/Advance calls seen
      block_count     -- number of blocks (block 0 is the function body)
      block_first_op  -- block number -> first operation number in the block
      block_op_count  -- block number -> operations in the block, nested
                         blocks included
      advance_amount  -- operation number -> shift amount, Advance calls only
      init_one_list   -- operation numbers for which usesCarryInit1() holds
      parent_block    -- block number -> enclosing block number
      children        -- block number -> list of directly nested blocks
      whileblock      -- block number -> True for `while`, False for `if`
      carry_count, adv_1_count, adv_n_count, total_advance, max_advance
                      -- running totals over the whole function
    """

    def __init__(self, streamFunctionNode):
        self.operation_count = 0
        self.block_count = 0
        self.block_first_op = {}
        self.block_op_count = {}
        self.advance_amount = {}
        self.init_one_list = []
        self.parent_block = {}
        self.children = {}
        self.whileblock = {}
        self.carry_count = 0
        self.adv_n_count = 0
        self.adv_1_count = 0
        self.total_advance = 0
        self.max_advance = 0

        # Initialize for the main block
        self.block_no = 0
        self.block_count += 1
        self.children[0] = []
        self.block_first_op[0] = 0
        # Recursively process all blocks
        self.generic_visit(streamFunctionNode)
        self.block_op_count[0] = self.operation_count

    def visit_Call(self, callnode):
        """Assign the next operation number to a builtin call.

        Arguments are visited first, so nested calls are numbered before
        the call that contains them.
        """
        self.generic_visit(callnode)
        builtin = CheckForBuiltin(callnode)
        if builtin is None:
            return
        if isCarryGenerating(builtin):
            if usesCarryInit1(builtin):
                self.init_one_list.append(self.operation_count)
            self.operation_count += 1
            self.carry_count += 1
        elif isAdvance(builtin):
            if len(callnode.args) > 1:
                # The shift amount must be a numeric literal.
                adv_amount = callnode.args[1].n
                self.adv_n_count += 1
            else:
                adv_amount = 1
                self.adv_1_count += 1
            self.advance_amount[self.operation_count] = adv_amount
            self.total_advance += adv_amount
            if adv_amount > self.max_advance:
                self.max_advance = adv_amount
            self.operation_count += 1

    def block_visit(self, blkNode):
        """Open a new block for an if/while node, walk it, then close it."""
        prnt = self.block_no
        this_block_no = self.block_count
        self.block_count += 1
        self.parent_block[this_block_no] = prnt
        self.children[prnt].append(this_block_no)
        # Record the kind per block.  The previous code assigned the bool to
        # self.whileblock itself, discarding the dict created in __init__.
        self.whileblock[this_block_no] = isinstance(blkNode, ast.While)
        self.block_no = this_block_no
        self.block_first_op[this_block_no] = self.operation_count
        self.children[this_block_no] = []
        self.generic_visit(blkNode)
        self.block_op_count[this_block_no] = (
            self.operation_count - self.block_first_op[this_block_no])
        # reset for processing remainder of parent
        self.block_no = self.parent_block[this_block_no]

    def visit_If(self, ifNode):
        self.block_visit(ifNode)

    def visit_While(self, whileNode):
        self.block_visit(whileNode)

    def countBlockCarrysWithAdv1(self, blk):
        """Count the operations of block blk that are not Advance-by-n (n != 1)."""
        first = self.block_first_op[blk]
        carries = 0
        for op in range(first, first + self.block_op_count[blk]):
            # Operations absent from advance_amount are ordinary carries.
            if self.advance_amount.get(op, 1) == 1:
                carries += 1
        return carries
149
#
# Reducing AugAssign, e.g.  x |= y becomes x = x | y
#
class AugAssignRemoval(ast.NodeTransformer):
    """Rewrite every augmented assignment as a plain assignment."""

    def xfrm(self, t):
        """Transform the children of t in place and return t."""
        return self.generic_visit(t)

    def visit_AugAssign(self, e):
        self.generic_visit(e)
        # The target node is reused as the left operand of the new BinOp.
        expanded = ast.BinOp(e.target, e.op, e.value)
        return ast.Assign([e.target], expanded)
#
# pablo.Advance(pablo.Advance(x, n)) => pablo.Advance(x, n+1)
#
class AdvanceCombiner(ast.NodeTransformer):
    """Fold a one-argument Advance applied to another Advance into one call.

    Nothing inside an `if` or `while` statement is visited, so nested
    Advance calls there are left as written.
    """

    def xfrm(self, t):
        """Transform the children of t in place and return t."""
        return self.generic_visit(t)

    def visit_If(self, ifNode):
        # This method used to be spelled `visit_if`, which NodeTransformer
        # never dispatches to, and it returned the undefined name `IfNode`.
        # NOTE(review): with the spelling fixed, `if` bodies are now skipped
        # exactly like `while` bodies -- confirm this is the intended policy.
        return ifNode

    def visit_While(self, whileNode):
        return whileNode

    def visit_Call(self, callnode):
        self.generic_visit(callnode)
        if len(callnode.args) == 0:
            return callnode
        inner = callnode.args[0]
        if not isinstance(inner, ast.Call):
            return callnode
        if not is_BuiltIn_Call(callnode, 'Advance', 1):
            return callnode
        if is_BuiltIn_Call(inner, 'Advance', 1):
            # Advance(Advance(x)) => Advance(x, 2)
            callnode.args = [inner.args[0], ast.Num(2)]
        elif is_BuiltIn_Call(inner, 'Advance', 2):
            amount = inner.args[1]
            if isinstance(amount, ast.Num):
                # Literal amount: fold the increment at compile time.
                callnode.args = [inner.args[0], ast.Num(amount.n + 1)]
            else:
                # Symbolic amount: emit the expression `amount + 1`.
                callnode.args = [inner.args[0],
                                 ast.BinOp(amount, ast.Add(), ast.Num(1))]
        return callnode
[771]182
[2009]183
#
#  Translating pablo.match(str, marker)
#  Incremental character match with lookahead
#
#  (StringMatchCompiler takes the string literal as the first argument and
#  the marker stream as the second.)
#
# Punctuation and whitespace characters -> name suffix used for the
# corresponding `lex.<name>` stream.
CharNameMap = {
    '[': 'LBrak', ']': 'RBrak',
    # '}' was previously mapped to 'LBrace', the same name as '{'.
    '{': 'LBrace', '}': 'RBrace',
    '(': 'LParen', ')': 'RParen',
    '!': 'Exclam', '"': 'DQuote', '#': 'Hash', '$': 'Dollar',
    '%': 'PerCent', '&': 'RefStart', "'": 'SQuote', '*': 'Star',
    '+': 'Plus', ',': 'Comma', '-': 'Hyphen', '.': 'Dot', '/': 'Slash',
    ':': 'Colon', ';': 'Semicolon', '=': 'Equals', '?': 'QMark',
    '@': 'AtSign', '\\': 'BackSlash', '^': 'Caret', '_': 'Underscore',
    '|': 'VBar', '~': 'Tilde',
    # The CR key was previously written '\m', which is the two-character
    # string backslash + 'm', so a carriage return never matched.
    ' ': 'SP', '\t': 'HT', '\r': 'CR', '\n': 'LF',
}


def GetCharName(char):
    """Return the stream-name suffix for char.

    Letters give 'letter_<c>', digits give 'digit_<c>', anything else is
    looked up in CharNameMap.

    Raises KeyError for a character that has no CharNameMap entry.
    """
    if 'a' <= char <= 'z' or 'A' <= char <= 'Z':
        return 'letter_' + char
    if '0' <= char <= '9':
        return 'digit_' + char
    return CharNameMap[char]
198
def MkCharStream(char):
    """Return the AST for ``lex.<name>``, where <name> is GetCharName(char)."""
    return mkQname('lex', GetCharName(char))
201
def MkLookAheadExpr(v, i):
    """Return the AST for ``pablo.LookAhead(v, i)``.

    v is an AST expression; i is a Python number that becomes a literal.
    """
    return mkCall(mkQname('pablo', 'LookAhead'), [v, ast.Num(i)])
204
def CompileMatch(match_var, string_to_match):
    """Build the expression that matches string_to_match at match_var.

    The first character's stream is and-ed with match_var; the stream of
    each later character, at index i, is and-ed in through a LookAhead of i.

    string_to_match must be non-empty (an empty string raises IndexError).
    """
    expr = mkCall('simd_and', [match_var, MkCharStream(string_to_match[0])])
    for i, ch in enumerate(string_to_match[1:], 1):
        expr = mkCall('simd_and', [expr, MkLookAheadExpr(MkCharStream(ch), i)])
    return expr
210
class StringMatchCompiler(ast.NodeTransformer):
    """Expand two-argument ``match(<string literal>, marker)`` calls.

    Each such call is replaced by the simd_and/LookAhead expression built
    by CompileMatch.  Other calls are returned untouched, and their
    arguments are not visited.
    """

    def xfrm(self, t):
        """Transform the children of t in place and return t."""
        return self.generic_visit(t)

    def visit_Call(self, callnode):
        if not is_BuiltIn_Call(callnode, 'match', 2):
            return callnode
        # The first argument must be a string literal; the second is the
        # marker stream.  (A discarded ast.dump(callnode) call was removed.)
        assert isinstance(callnode.args[0], ast.Str)
        return CompileMatch(callnode.args[1], callnode.args[0].s)
225
226
227
#
#  Converting expressions involving built-ins to compiled form.
#  Apply before carry variable insertion.
#
class TempifyBuiltins(ast.NodeTransformer):
    """Hoist carry-generating builtin calls into temporary assignments.

    A call for which CarryCountOfFn() is positive, and which is not itself
    the whole right-hand side of the most recently visited assignment, is
    replaced by a fresh temporary name.  The assignment that computes the
    temporary is queued in setUpStmts and emitted ahead of the statement.
    """

    def __init__(self, tempVarpfx="tempvar"):
        self.tempVarCount = 0
        self.newTempList = []
        self.tempVarPrefix = tempVarpfx
        # Previously these two only existed after xfrm() had been called.
        self.setUpStmts = []
        self.assigNode = None

    def genVar(self):
        """Create, record and return the next temporary variable name."""
        newTemp = self.tempVarPrefix + repr(self.tempVarCount)
        self.newTempList.append(newTemp)
        self.tempVarCount += 1
        return newTemp

    def tempVars(self):
        """Return the list of temporary names generated so far."""
        return self.newTempList

    def xfrm(self, t):
        """Reset per-run state, transform the children of t, and return t."""
        self.setUpStmts = []
        self.assigNode = None
        return self.generic_visit(t)

    def is_Assign_value(self, node):
        """True if node is the value of the most recently visited Assign."""
        return self.assigNode is not None and self.assigNode.value is node

    def visit_If(self, ifNode):
        self.setUpStmts = []
        self.generic_visit(ifNode.test)
        # ifSetUpStmts aliases the live list.  The visit below reaches the
        # test first, so anything hoisted from it still lands in this list
        # before a body statement installs a fresh one.
        ifSetUpStmts = self.setUpStmts
        self.generic_visit(ifNode)
        if not ifSetUpStmts:
            return ifNode
        return ifSetUpStmts + [ifNode]

    def visit_While(self, whileNode):
        self.setUpStmts = []
        self.generic_visit(whileNode.test)
        # Same aliasing as in visit_If.
        whileSetUpStmts = self.setUpStmts
        self.generic_visit(whileNode)
        # The set-up statements run before the loop and again at the end of
        # every iteration, ahead of the next evaluation of the test.
        whileNode.body = whileNode.body + whileSetUpStmts
        return whileSetUpStmts + [whileNode]

    def visit_Assign(self, node):
        self.assigNode = node
        self.setUpStmts = []
        self.generic_visit(node)
        return self.setUpStmts + [node]

    def visit_AugAssign(self, node):
        self.setUpStmts = []
        self.generic_visit(node)
        return self.setUpStmts + [node]

    def visit_Call(self, callnode):
        self.generic_visit(callnode)
        if CarryCountOfFn(callnode) > 0 and not self.is_Assign_value(callnode):
            tempVar = ast.Name(self.genVar(), ast.Load())
            self.setUpStmts.append(ast.Assign([tempVar], callnode))
            return tempVar
        return callnode
281
282
283
284
285
#
# Introducing BitBlock logical operations
#
class Bitwise_to_SIMD(ast.NodeTransformer):
    """
    Make the following substitutions:
       x & y    => simd_and(x, y)
       x & ~y   => simd_andc(x, y)
       ~x & y   => simd_andc(y, x)
       x | y    => simd_or(x, y)
       x ^ y    => simd_xor(x, y)
       ~x       => simd_not(x)
       0        => simd<1>::constant<0>()
       -1       => simd<1>::constant<1>()
       if x:    => if bitblock::any(x):
       while x: => while bitblock::any(x):

    Example: "pfx = bit0 & bit1; sfx = bit0 &~ bit1" becomes

       pfx = simd_and(bit0, bit1)
       sfx = simd_andc(bit0, bit1)

    Subscript expressions are returned unchanged and never descended into.
    """

    def xfrm(self, t):
        """Transform the children of t in place and return t."""
        return self.generic_visit(t)

    def visit_UnaryOp(self, t):
        self.generic_visit(t)
        if isinstance(t.op, ast.Invert):
            return mkCall('simd_not', [t.operand])
        return t

    def visit_BinOp(self, t):
        # Operands are rewritten first, so an operand that was `~y` is
        # already a simd_not(...) call when the and-complement check runs.
        self.generic_visit(t)
        if isinstance(t.op, ast.BitOr):
            return mkCall('simd_or', [t.left, t.right])
        if isinstance(t.op, ast.BitAnd):
            if is_simd_not(t.right):
                return mkCall('simd_andc', [t.left, t.right.args[0]])
            if is_simd_not(t.left):
                return mkCall('simd_andc', [t.right, t.left.args[0]])
            return mkCall('simd_and', [t.left, t.right])
        if isinstance(t.op, ast.BitXor):
            return mkCall('simd_xor', [t.left, t.right])
        return t

    def visit_Num(self, numnode):
        n = numnode.n
        if n == 0:
            return mkCall('simd<1>::constant<0>', [])
        if n == -1:
            return mkCall('simd<1>::constant<1>', [])
        return numnode

    def visit_If(self, ifNode):
        self.generic_visit(ifNode)
        ifNode.test = mkCall('bitblock::any', [ifNode.test])
        return ifNode

    def visit_While(self, whileNode):
        self.generic_visit(whileNode)
        whileNode.test = mkCall('bitblock::any', [whileNode.test])
        return whileNode

    def visit_Subscript(self, numnode):
        return numnode  # no recursive modifications of index expressions
340
#
#  Generating BitBlock declarations for Local Variables
#
class FunctionVars(ast.NodeVisitor):
    """Collect the parameter names and assigned names found under node.

    params -- ids of Name nodes in Param context, in visiting order
    stores -- ids of Name nodes in Store context, first occurrence only
    """

    def __init__(self, node):
        self.params = []
        self.stores = []
        self.generic_visit(node)

    def visit_Name(self, nm):
        if isinstance(nm.ctx, ast.Param):
            self.params.append(nm.id)
        if isinstance(nm.ctx, ast.Store) and nm.id not in self.stores:
            self.stores.append(nm.id)

    def getLocals(self):
        """Return the assigned names that are not parameters, in first-store order."""
        return [v for v in self.stores if v not in self.params]
356
MAX_LINE_LENGTH = 80


def BitBlock_decls_from_vars(varlist):
    """Return C declarations ``BitBlock a, b, ...;`` for the names in varlist.

    An empty varlist gives "".  A new ``BitBlock`` declaration line is
    started when the running length estimate would exceed MAX_LINE_LENGTH.
    The estimate starts at 10 (11 after a wrap), which is less than the
    width of the indent plus keyword actually emitted, so real lines can be
    somewhat longer than MAX_LINE_LENGTH.

    The first name always goes on the first line.  Previously a first name
    too long for the estimate produced an empty ``BitBlock;`` declaration.
    """
    if len(varlist) == 0:
        return ""
    pieces = ["             BitBlock"]
    pending = ""
    linelgth = 10
    for v in varlist:
        if pending == "" or linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
            pieces.append(pending + " " + v)
            linelgth += len(pending) + len(v) + 1
        else:
            pieces.append(";\n             BitBlock " + v)
            linelgth = 11 + len(v)
        pending = ","
    pieces.append(";")
    return "".join(pieces)
376
def BitBlock_decls_of_fn(fndef):
    """Return the BitBlock declarations for the local variables of fndef."""
    local_names = FunctionVars(fndef).getLocals()
    return BitBlock_decls_from_vars(local_names)
[753]379
def BitBlock_header_of_fn(fndef):
    """Return the C header ``static inline void name(...)`` for fndef.

    Each argument that is an ast.Name becomes ``<Type> & <name>``, where the
    type is the argument name with its first letter upper-cased.  When
    CarryCounter finds at least one carry in fndef, a trailing
    `` CarryQtype & carryQ`` parameter is added.
    """
    params = []
    for arg in fndef.args.args:
        if isinstance(arg, ast.Name):
            params.append(arg.id.upper()[0] + arg.id[1:] + " & " + arg.id)
    Ccode = "static inline void " + fndef.name + "(" + ", ".join(params)
    if CarryCounter().count(fndef) > 0:
        separator = ", " if params else ""
        Ccode += separator + " CarryQtype & carryQ"
    return Ccode + ")"
391
392
[760]393
#
#  Stream Initialization Statement Extraction
#
#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
class StreamInitializations(ast.NodeTransformer):
    """Pull ``var = <nonzero numeric literal>`` assignments out of the tree.

    Assignments of 0 or -1, and of anything that is not a numeric literal,
    stay where they are.  Every other numeric assignment is removed and
    recorded twice:

      stream_stmts    -- ``var = sisd_from_int(<literal>)``
      loop_post_inits -- ``var = 0``, appended to the body of each
                         FunctionDef visited
    """

    def xfrm(self, node):
        """Transform node in place; return the C text of the stream inits."""
        self.stream_stmts = []
        self.loop_post_inits = []
        self.generic_visit(node)
        return Cgen.py2C().gen(self.stream_stmts)

    def visit_Assign(self, node):
        value = node.value
        if not isinstance(value, ast.Num) or value.n in (0, -1):
            return node
        stream_init = copy.deepcopy(node)
        stream_init.value = mkCall('sisd_from_int', [value])
        loop_init = copy.deepcopy(node)
        loop_init.value.n = 0
        self.stream_stmts.append(stream_init)
        self.loop_post_inits.append(loop_init)
        # Returning None removes the original assignment from the tree.
        return None

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        node.body = node.body + self.loop_post_inits
        return node
421
[2041]422
423
#
# Carry Introduction Transformation
#
class CarryCounter(ast.NodeVisitor):
    """Count the carries needed by the code below a node.

    Every `+` or `-` binary operation counts one carry, and every call
    contributes CarryCountOfFn(call).
    """

    def __init__(self):
        # Lets the visit_* methods run even if count() has not been called.
        self.carry_count = 0

    def visit_Call(self, callnode):
        self.generic_visit(callnode)
        self.carry_count += CarryCountOfFn(callnode)

    def visit_BinOp(self, exprnode):
        self.generic_visit(exprnode)
        if isinstance(exprnode.op, (ast.Sub, ast.Add)):
            self.carry_count += 1

    def count(self, nodeToVisit):
        """Return the carry count for the children of nodeToVisit.

        The counter is reset first; nodeToVisit itself is not counted.
        """
        self.carry_count = 0
        self.generic_visit(nodeToVisit)
        return self.carry_count
441
#
# Carry Initialization:  Aug. 4, 2012
# - Carry variables are initialized to 0 by default
# - However, the scan_to_first routine should ideally use
#   initialization with 1.
#
class CarryInitToOneList(ast.NodeVisitor):
    """List the carry numbers that belong to one-argument ScanToFirst calls.

    Carries are numbered in visiting order: one for each builtin call in
    _COUNTED_BUILTINS, one for each ScanToFirst, and one for each `+` or `-`.
    """

    # (builtin name, positional-argument count) pairs that consume a carry
    # number without being added to the result list.
    _COUNTED_BUILTINS = (
        ('Advance', 1), ('ScanThru', 2), ('ScanTo', 2),
        ('AdvanceThenScanThru', 2), ('AdvanceThenScanTo', 2),
        ('SpanUpTo', 2), ('InclusiveSpan', 2), ('ExclusiveSpan', 2),
    )

    def __init__(self):
        # Lets the visit_* methods run even if count() has not been called.
        self.carry_count = 0
        self.init_to_one_list = []

    def visit_Call(self, callnode):
        self.generic_visit(callnode)
        for fn_name, arg_cnt in self._COUNTED_BUILTINS:
            if is_BuiltIn_Call(callnode, fn_name, arg_cnt):
                self.carry_count += 1
                return
        if is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
            self.init_to_one_list.append(self.carry_count)
            self.carry_count += 1

    def visit_BinOp(self, exprnode):
        self.generic_visit(exprnode)
        if isinstance(exprnode.op, (ast.Sub, ast.Add)):
            self.carry_count += 1

    def count(self, nodeToVisit):
        """Return the ScanToFirst carry numbers below nodeToVisit (state is reset first)."""
        self.carry_count = 0
        self.init_to_one_list = []
        self.generic_visit(nodeToVisit)
        return self.init_to_one_list
468
class adv_nCounter(ast.NodeVisitor):
    """Count one-argument Advance32 calls and two-argument Advance calls."""

    def __init__(self):
        # Lets visit_Call run even if count() has not been called.
        self.adv_n_count = 0

    def visit_Call(self, callnode):
        self.generic_visit(callnode)
        if is_BuiltIn_Call(callnode, 'Advance32', 1):
            self.adv_n_count += 1
        if is_BuiltIn_Call(callnode, 'Advance', 2):
            self.adv_n_count += 1

    def count(self, nodeToVisit):
        """Return the number of such calls below nodeToVisit (counter is reset first)."""
        self.adv_n_count = 0
        self.generic_visit(nodeToVisit)
        return self.adv_n_count
[1211]480
#
# Base CCGO Class is also the null class, suitable for stream
# functions that have no carry-generating operations.
#
class CCGO:
    """Null carry-code generator: every hook emits nothing.

    String hooks return "", statement-list hooks return [], the if/while
    test hooks return the test expression unchanged, and
    GenerateCarryInAccess returns None.
    """

    def __init__(self):
        pass

    def GenerateCarryDecls(self):
        return ""

    def GenerateInitializations(self):
        return ""

    def GenerateStreamFunctionDecls(self):
        return ""

    def GenerateCarryInAccess(self, operation_no):
        return None

    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        return []

    def GenerateCarryIfTest(self, block_no, ifTest):
        return ifTest

    def GenerateCarryElseFinalization(self, block_no):
        return []

    def GenerateLocalDeclare(self, block_no):
        return []

    def GenerateCarryWhileTest(self, block_no, testExpr):
        return testExpr

    def EnterLocalWhileBlock(self, operation_offset):
        pass

    def ExitLocalWhileBlock(self):
        pass

    def GenerateCarryWhileFinalization(self, block_no):
        return []

    def GenerateStreamFunctionFinalization(self):
        return []
[2612]500
class testCCGO(CCGO):
    """Carry-code generator that keeps carries in a single carry-group object.

    carryInfoSet      -- a CarryInfoSetVisitor result for the stream function
    carryGroupVarName -- name of the carry-group variable in generated code

    carryIndex maps an operation number to its slot in the group's carry
    array.  Advance-by-n operations (n != 1) do not take a slot, so they
    share the index of the next operation that does.
    """

    def __init__(self, carryInfoSet, carryGroupVarName='carryQ'):
        self.carryInfoSet = carryInfoSet
        self.carryGroupVar = carryGroupVarName
        self.carryIndex = {}
        self.operation_offset = 0
        carry_counter = 0
        for op_no in range(carryInfoSet.operation_count):
            self.carryIndex[op_no] = carry_counter
            if self._takes_carry_slot(op_no):
                carry_counter += 1
        # Add a dummy entry for any possible final block that is empty.
        self.carryIndex[carryInfoSet.operation_count] = carry_counter

    # Helpers
    def _takes_carry_slot(self, op_no):
        """True unless op_no is an Advance with an amount other than 1."""
        return self.carryInfoSet.advance_amount.get(op_no, 1) == 1

    def _block_carry_test(self, block_no, testExpr):
        """Return ``testExpr or <group>.CarryTest(first, count)`` for a block.

        A block without operations gets testExpr back unchanged.
        """
        carry_count = self.carryInfoSet.block_op_count[block_no]
        if carry_count == 0:
            return testExpr
        first = self.carryIndex[self.carryInfoSet.block_first_op[block_no]]
        carry_test = mkCall(self.CarryGroupAtt('CarryTest'),
                            [ast.Num(first), ast.Num(carry_count)])
        return ast.BoolOp(ast.Or(), [testExpr, carry_test])

    def CarryGroupAtt(self, attname, CarryGroupVarName=""):
        """Return the AST for ``<group>.<attname>`` (default: the current group)."""
        if CarryGroupVarName == '':
            CarryGroupVarName = self.carryGroupVar
        return ast.Attribute(ast.Name(CarryGroupVarName, ast.Load()), attname, ast.Load())

    def GenerateCarryDecls(self):
        """Return the ``CarryArray<carries, adv_n> <group>;`` declaration text."""
        carry_counter = 0
        adv_n_counter = 0
        for op_no in range(self.carryInfoSet.block_op_count[0]):
            if self._takes_carry_slot(op_no):
                carry_counter += 1
            else:
                adv_n_counter += 1
        return "CarryArray<%i, %i> %s;" % (carry_counter, adv_n_counter, self.carryGroupVar)

    def GenerateInitializations(self):
        """Return C statements that flip each carry listed in init_one_list.

        The statements name self.carryGroupVar.  They used to hard-code
        ``carryQ``, which is the same text only for the default group name.
        """
        carry_counter = 0
        inits = ""
        for op_no in range(self.carryInfoSet.block_op_count[0]):
            if op_no in self.carryInfoSet.init_one_list:
                inits += "%s.cq[%s] = %s.carry_flip(%s.cq[%s]);\n" % (
                    self.carryGroupVar, carry_counter,
                    self.carryGroupVar, self.carryGroupVar, carry_counter)
            if self._takes_carry_slot(op_no):
                carry_counter += 1
        return inits

    def GenerateStreamFunctionDecls(self):
        return ""

    def GenerateCarryInAccess(self, operation_no):
        """Return the AST for ``<group>.get_carry_in(<index>)``."""
        carry_index = self.carryIndex[operation_no - self.operation_offset]
        return mkCall(self.carryGroupVar + "." + 'get_carry_in', [ast.Num(carry_index)])

    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        """Return ``[<group>.cq[<index>] = bitblock::srli<127>(carry_out_expr)]``."""
        carry_index = self.carryIndex[operation_no - self.operation_offset]
        target = ast.Subscript(self.CarryGroupAtt('cq'),
                               ast.Index(ast.Num(carry_index)), ast.Store())
        return [ast.Assign([target], mkCall("bitblock::srli<127>", [carry_out_expr]))]

    def GenerateCarryIfTest(self, block_no, ifTest):
        return self._block_carry_test(block_no, ifTest)

    def GenerateCarryElseFinalization(self, block_no):
        """Return the CarryDequeueEnqueue statement for a skipped if-block."""
        carry_count = self.carryInfoSet.block_op_count[block_no]
        if carry_count == 0:
            return []
        ifIndex = self.carryIndex[self.carryInfoSet.block_first_op[block_no]]
        return [mkCallStmt(self.CarryGroupAtt('CarryDequeueEnqueue'),
                           [ast.Num(ifIndex), ast.Num(carry_count)])]

    def GenerateLocalDeclare(self, block_no):
        """Return ``[LocalCarryDeclare(sub<group>, <ops in block>)]``."""
        local_carryvar = ast.Name("sub" + self.carryGroupVar, ast.Load())
        op_count = self.carryInfoSet.block_op_count[block_no]
        return [mkCallStmt('LocalCarryDeclare', [local_carryvar, ast.Num(op_count)])]

    def GenerateCarryWhileTest(self, block_no, testExpr):
        return self._block_carry_test(block_no, testExpr)

    def EnterLocalWhileBlock(self, operation_offset):
        """Switch to the ``sub``-prefixed group; indices become relative to operation_offset."""
        self.carryGroupVar = "sub" + self.carryGroupVar
        self.operation_offset = operation_offset

    def ExitLocalWhileBlock(self):
        """Undo EnterLocalWhileBlock: reset the offset and drop the ``sub`` prefix."""
        self.operation_offset = 0
        self.carryGroupVar = self.carryGroupVar[len("sub"):]

    def GenerateCarryWhileFinalization(self, block_no):
        """Return the CarryCombine statement merging ``sub<group>.cq`` into the group."""
        carry_count = self.carryInfoSet.block_op_count[block_no]
        if carry_count == 0:
            return []
        loopIndex = self.carryIndex[self.carryInfoSet.block_first_op[block_no]]
        local_cq = self.CarryGroupAtt('cq', "sub" + self.carryGroupVar)
        return [mkCallStmt(self.CarryGroupAtt('CarryCombine'),
                           [local_cq, ast.Num(loopIndex), ast.Num(carry_count)])]

    def GenerateStreamFunctionFinalization(self):
        """Return the CarryQ_Adjust statement, or [] when block 0 has no carries."""
        carry_count = self.carryInfoSet.countBlockCarrysWithAdv1(0)
        if carry_count == 0:
            return []
        return [mkCallStmt(self.CarryGroupAtt('CarryQ_Adjust'), [ast.Num(carry_count)])]
[2612]577
[2689]578
def Strategic_CCGO_Factory(carryInfoSet):
    """Return the carry-code generator for carryInfoSet.

    Currently always a testCCGO bound to the carry group named 'carryQ'.
    """
    return testCCGO(carryInfoSet, 'carryQ')
[2612]582
[2636]583
[753]584class CarryIntro(ast.NodeTransformer):
[2636]585  def __init__(self, ccgo, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
[753]586    self.carryvar = ast.Name(carryvar, ast.Load())
[921]587    self.carryin = carryin
588    self.carryout = carryout
[2636]589    self.ccgo = ccgo
[753]590  def xfrm_fndef(self, fndef):
[2636]591    self.block_no = 0
[2690]592    self.operation_count = 0
[753]593    self.current_carry = 0
[2206]594    self.current_adv_n = 0
[753]595    carry_count = CarryCounter().count(fndef)
596    if carry_count == 0: return fndef
597    self.generic_visit(fndef)
[2628]598  def xfrm_fndef_final(self, fndef):
[2636]599    self.block_no = 0
[2690]600    self.operation_count = 0
[2628]601    self.carryout = ""
602    self.current_carry = 0
603    self.current_adv_n = 0
604    carry_count = CarryCounter().count(fndef)
605    if carry_count == 0: return fndef
606    self.generic_visit(fndef)
[762]607#   
608#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
[753]609    return fndef
610  def generic_xfrm(self, node):
[2636]611    self.block_no = 0
[2690]612    self.operation_count = 0
[753]613    self.current_carry = 0
[2206]614    self.current_adv_n = 0
[2264]615    self.last_stmt = None
[2265]616    self.last_stmt_carries = 0
[753]617    carry_count = CarryCounter().count(node)
[2206]618    adv_n_count = adv_nCounter().count(node)
619    if carry_count == 0 and adv_n_count == 0: return node
[753]620    self.generic_visit(node)
621    return node
[2631]622   
623  def local_while_xfrm(self, local_carryvar, whileNode):
[2690]624    saved_state = (self.block_no, self.operation_count, self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n)
[2636]625    (self.carryvar, self.carryin, self.current_carry, self.current_adv_n) = (local_carryvar, '', 0, 0)
[2690]626    self.ccgo.EnterLocalWhileBlock(self.operation_count);
[2631]627    inner_while = self.generic_visit(whileNode)
[2636]628    self.ccgo.ExitLocalWhileBlock();
[2690]629    (self.block_no, self.operation_count, self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n) = saved_state
[2631]630    return inner_while
631   
[753]632  def visit_Call(self, callnode):
633    self.generic_visit(callnode)
[1571]634    #CARRYSET
[1997]635    #carry_args = [ast.Num(self.current_carry)]
[2206]636    #adv_n_args = [ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())]
[2221]637    #adv_n_pending = ast.Subscript(ast.Name(self.carryvar.id + '.pending64', ast.Load()), ast.Num(self.current_adv_n), ast.Load())
[1997]638    if self.carryin == "_ci":
639        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
[2221]640        adv_n_args = [mkCall(self.carryvar.id + "." + 'get_pending64', [ast.Num(self.current_adv_n)]), ast.Num(self.current_adv_n)]
[1997]641    else: 
642        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
[2206]643        adv_n_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_adv_n)]
[1997]644
[2206]645    if is_BuiltIn_Call(callnode, 'Advance', 2):         
646      #CARRYSET
647      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<%i>" % callnode.args[1].n
648      c = mkCall(rtn, [callnode.args[0]] + adv_n_args)
649      self.current_adv_n += 1
650      return c
[813]651    if is_BuiltIn_Call(callnode, 'Advance', 1):         
[1571]652      #CARRYSET
[1997]653      rtn = self.carryvar.id + "." + "BitBlock_advance_ci_co"
[924]654      c = mkCall(rtn, callnode.args + carry_args)
[753]655      self.current_carry += 1
656      return c
[1211]657    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
[1571]658      #CARRYSET
[2206]659      rtn = self.carryvar.id + "." + "BitBlock_advance_n_<32>"
660      c = mkCall(rtn, callnode.args + adv_n_args)
661      self.current_adv_n += 1
[1211]662      return c
[813]663    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
[1571]664      #CARRYSET
[1997]665      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
[924]666      c = mkCall(rtn, callnode.args + carry_args)
[753]667      self.current_carry += 1
668      return c
[2041]669    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanThru', 2):
670      #CARRYSET
671      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
672      c = mkCall(rtn, callnode.args + carry_args)
673      self.current_carry += 1
674      return c
675    elif is_BuiltIn_Call(callnode, 'AdvanceThenScanTo', 2):
676      #CARRYSET
677      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru" 
678      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
679      else: scanclass = mkCall('simd_not', [callnode.args[1]])
[2049]680      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
[2041]681      self.current_carry += 1
682      return c
683    elif is_BuiltIn_Call(callnode, 'SpanUpTo', 2):
684      #CARRYSET
685      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
686      c = mkCall(rtn, callnode.args + carry_args)
687      self.current_carry += 1
688      return c
689    elif is_BuiltIn_Call(callnode, 'InclusiveSpan', 2):
690      #CARRYSET
[2049]691#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
692#      c = mkCall('simd_or', [mkCall(rtn, callnode.args + carry_args), callnode.args[1]])
693      rtn = self.carryvar.id + "." + "BitBlock_inclusive_span"
694      c = mkCall(rtn, callnode.args + carry_args)
[2041]695      self.current_carry += 1
696      return c
697    elif is_BuiltIn_Call(callnode, 'ExclusiveSpan', 2):
698      #CARRYSET
[2049]699#      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
700#      c = mkCall('simd_andc', [mkCall(rtn, callnode.args + carry_args), callnode.args[0]])
701      rtn = self.carryvar.id + "." + "BitBlock_exclusive_span"
702      c = mkCall(rtn, callnode.args + carry_args)
[2041]703      self.current_carry += 1
704      return c
[902]705    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
[1520]706      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
707      # in having a separate BitBlock_scanto routine.
[1571]708      #CARRYSET
[1997]709      rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co" 
[1520]710      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
711      else: scanclass = mkCall('simd_not', [callnode.args[1]])
712      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
[902]713      self.current_carry += 1
714      return c
[1074]715    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
[1571]716      #CARRYSET
717      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
[1074]718      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
719      c = mkCall(rtn, callnode.args + carry_args)
720      self.current_carry += 1
721      return c
[1439]722    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
723      if self.carryout != "": 
724        # Non final block: atEOF(x) = 0.
[1916]725        return mkCall('simd<1>::constant<0>', [])
[1439]726      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
727    elif is_BuiltIn_Call(callnode, 'inFile', 1):
728      if self.carryout != "": 
729        # Non final block: inFile(x) = x.
730        return callnode.args[0]
731      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
[822]732    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
733      rtn = "StreamScan"           
734      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
735                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
736                                           ast.Name(callnode.args[1].id, ast.Load())])
737      return c
[1864]738    else:
[1865]739      #dump_Call(callnode)
[1864]740      return callnode
[765]741  def visit_BinOp(self, exprnode):
[753]742    self.generic_visit(exprnode)
[1571]743    carry_args = [ast.Num(self.current_carry)]
[2003]744    if self.carryin == "_ci":
745        carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
746    else: 
747        carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
[753]748    if isinstance(exprnode.op, ast.Sub):
[1571]749      #CARRYSET
[2003]750      rtn = self.carryvar.id + "." + "BitBlock_sub_ci_co"
[924]751      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
[753]752      self.current_carry += 1
753      return c
754    elif isinstance(exprnode.op, ast.Add):
[1571]755      #CARRYSET
[2004]756      rtn = self.carryvar.id + "." + "BitBlock_add_ci_co"
[924]757      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
[753]758      self.current_carry += 1
759      return c
760    else: return exprnode
[2612]761  def visit_Assign(self, assigNode):
762    self.last_stmt_carries = CarryCounter().count(assigNode)
[2636]763    f = CheckForBuiltin(assigNode.value)
764    if f == None or not experimentalMode: 
[2612]765            self.generic_visit(assigNode)
766            self.last_stmt = assigNode
767            return assigNode
[2636]768    elif isCarryGenerating(f) or (isAdvance(f) and ((len(assigNode.value.args) == 1) or (assigNode.value.args[1].n==1))):
[2612]769    # We have an assignment v = pablo.SomeCarryGeneratingFunction()
770    #elif f == 'ScanThru':
771            if self.carryin == "_ci":
[2690]772                carry_in_expr = self.ccgo.GenerateCarryInAccess(self.operation_count)
[2612]773            else: 
774                carry_in_expr = mkCall('simd<1>::constant<0>', [])
775            callnode = assigNode.value
[2636]776            if isAdvance(f):
777               pablo_routine_call = mkCall('pablo_blk_' + f, [assigNode.value.args[0], carry_in_expr, assigNode.targets[0]])
778            elif f in ['ScanTo', 'AdvanceThenScanTo']:
[2612]779               if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
780               else: scanclass = mkCall('simd_not', [callnode.args[1]])
781               pablo_routine_call = mkCall('pablo_blk_' +f[:-2] + 'Thru', [callnode.args[0], scanclass, carry_in_expr, assigNode.targets[0]])
782            else:
783               pablo_routine_call = mkCall('pablo_blk_' + f, assigNode.value.args + [carry_in_expr, assigNode.targets[0]])
784            self.last_stmt = pablo_routine_call
[2690]785            compiled = self.ccgo.GenerateCarryOutStore(self.operation_count, pablo_routine_call)
786            self.operation_count += 1
[2612]787            self.current_carry += 1
788            return compiled
789    else:
790            self.generic_visit(assigNode)
791            self.last_stmt = assigNode
[2690]792            self.operation_count += 1
[2612]793            return assigNode
[2636]794           
795           
[753]796  def visit_If(self, ifNode):
[2636]797    self.block_no += 1
[2637]798    this_block = self.block_no
[753]799    carry_base = self.current_carry
800    carries = CarryCounter().count(ifNode)
[2206]801    assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
[753]802    self.generic_visit(ifNode)
[2264]803    if carries == 0 or self.carryin == "": 
804      self.last_stmt = ifNode
805      return ifNode
[1571]806    #CARRYSET
807    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
[2636]808    #new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
[2637]809    #new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
810    new_test = self.ccgo.GenerateCarryIfTest(this_block, ifNode.test)
811    new_else_part = ifNode.orelse + self.ccgo.GenerateCarryElseFinalization(this_block)
[2264]812    newIf = ast.If(new_test, ifNode.body, new_else_part)
813    self.last_stmt = newIf
[2265]814    self.last_stmt_carries = carries
[2264]815    return newIf
816
[2265]817
818  def is_while_special_case(self, whileNode):
819    #
820    # Special case optimization for pattern:
821    #   m=pablo.scan...()
822    #   while m:
823    #      S
824    #      m=pablo.scan...()
825    #
826    # Determine the original test expression, now encloded in bitblock::any()
827    original_test_expr = whileNode.test.args[0]
828    if not isinstance(original_test_expr, ast.Name): return False
829    test_var = original_test_expr.id
830    if not isinstance(self.last_stmt, ast.Assign): return False
831    if not isinstance(whileNode.body[-1], ast.Assign): return False
832    if len(self.last_stmt.targets) != 1: return False
833    if len(whileNode.body[-1].targets) != 1: return False
834    if not isinstance(self.last_stmt.targets[0], ast.Name): return False
835    if not isinstance(whileNode.body[-1].targets[0], ast.Name): return False
836    if self.last_stmt.targets[0].id != test_var: return False
837    if whileNode.body[-1].targets[0].id != test_var: return False
[2269]838    if self.last_stmt_carries != 1: return False
[2265]839    if CarryCounter().count(whileNode.body[-1]) != 1: return False
840    return True
841
[753]842  def visit_While(self, whileNode):
[2265]843    # Determine the original test expression, now encloded in bitblock::any()
[2636]844    self.block_no += 1
[2638]845    this_block = self.block_no
[2265]846    original_test_expr = whileNode.test.args[0]
[941]847    if self.carryout == '':
[2265]848      whileNode.test.args[0] = mkCall("simd_and", [original_test_expr, ast.Name('EOF_mask', ast.Load())])
[753]849    carry_base = self.current_carry
[2206]850    assert adv_nCounter().count(whileNode) == 0, "Advance(x,n) within while: illegal\n"
[753]851    carries = CarryCounter().count(whileNode)
[2269]852#   Special Case Recognition
853    is_special = self.is_while_special_case(whileNode)
[1571]854    #CARRYSET
[753]855    if carries == 0: return whileNode
[2628]856    local_carryvar = ast.Name('sub' + self.carryvar.id, ast.Load())
[2631]857    inner_while = self.local_while_xfrm(local_carryvar, copy.deepcopy(whileNode))
[753]858    self.generic_visit(whileNode)
[2640]859    local_carry_decl = self.ccgo.GenerateLocalDeclare(this_block)
860    #local_carry_decl = mkCallStmt('LocalCarryDeclare', [local_carryvar, ast.Num(carries)])
[2689]861    inner_while.body = local_carry_decl + inner_while.body
[2638]862    #final_combine = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine', ast.Load()), [ast.Attribute(local_carryvar, 'cq', ast.Load()),ast.Num(carry_base), ast.Num(carries)])
863    final_combine = self.ccgo.GenerateCarryWhileFinalization(this_block)
[2689]864    inner_while.body += final_combine
[1571]865    #CARRYSET
[2269]866
867#   Special Case Optimization
868    if is_special:
869      # We combine the final carry into the one preceeding the loop.
870      combine1 = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine1', ast.Load()), [ast.Num(carry_base-1), ast.Num(carry_base+carries-1)])
871      while_body_extend = [inner_while, combine1]
872      # The carry test can skip the final case.
873      carry_test_arglist = [ast.Num(carry_base), ast.Num(carries-1)]
874    else: 
875      carry_test_arglist = [ast.Num(carry_base), ast.Num(carries)]
876      while_body_extend = [inner_while]
877
[921]878    if self.carryin == '': new_test = whileNode.test
[2640]879    else: new_test = self.ccgo.GenerateCarryWhileTest(this_block, whileNode.test)
880    else_part = [self.ccgo.GenerateCarryElseFinalization(this_block)]   
[2269]881    newIf = ast.If(new_test, whileNode.body + while_body_extend, else_part)
[2264]882    self.last_stmt = newIf
[2265]883    self.last_stmt_carries = carries
[2264]884    return newIf
[753]885
class StreamStructGen(ast.NodeVisitor):
  """
  Given a tree containing BitStreamSet subclasses, generate the equivalent
  C struct text, one entry per class definition.

  For the input

      class S1(BitStreamSet):
        a1 = 0
        a2 = 0

  StreamStructGen(True).gen(tree) yields the type definition

        struct S1 {
        BitBlock a1;
        BitBlock a2;
      };

  followed by an empty line, whereas StreamStructGen().gen(tree) yields the
  instance declaration "  struct S1 s1;" followed by an empty line.

  asType selects between the two forms; gen_struct_types and
  gen_struct_vars set it explicitly before generating.
  """
  def __init__(self, asType=False):
    self.asType = asType

  def _generate(self, tree):
    # Shared driver: reset the output buffer, walk the tree, hand back the text.
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode

  def gen(self, tree):
    """Generate using the current asType setting."""
    return self._generate(tree)

  def gen_struct_types(self, tree):
    """Generate full struct type definitions (sets asType to True)."""
    self.asType = True
    return self._generate(tree)

  def gen_struct_vars(self, tree):
    """Generate struct instance declarations (sets asType to False)."""
    self.asType = False
    return self._generate(tree)

  def visit_ClassDef(self, node):
    # The struct type is the class name with an upper-case initial, the
    # instance the same name with a lower-case initial.
    class_name = node.name[0].upper() + node.name[1:]
    instance_name = node.name[0].lower() + node.name[1:]
    pieces = ["  struct " + class_name]
    if self.asType:
      pieces.append(" {\n")
      # Every plain-name target of a class-level assignment becomes a field.
      for stmt in node.body:
        if isinstance(stmt, ast.Assign):
          pieces.extend("  BitBlock " + v.id + ";\n"
                        for v in stmt.targets if isinstance(v, ast.Name))
      pieces.append("}")
    else:
      pieces.append(" " + instance_name)
    pieces.append(";\n\n")
    self.Ccode += "".join(pieces)
[753]942 
class StreamFunctionDecl(ast.NodeVisitor):
  """Emit a `static inline void name(Type & arg, ...);` line per function.

  Each parameter that is an `ast.Name` is declared as a reference whose type
  is the parameter name with an upper-case initial.
  """
  def __init__(self):
    pass

  def gen(self, tree):
    """Return the declarations for the function definitions found in tree."""
    self.Ccode = ""
    self.generic_visit(tree)
    return self.Ccode

  def visit_FunctionDef(self, node):
    params = [arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
              for arg in node.args.args if isinstance(arg, ast.Name)]
    self.Ccode += "static inline void " + node.name + "(" + ", ".join(params) + ");\n"
958
class AssertCompiler(ast.NodeTransformer):
  """Turn `assert_0(stream, code)` expression statements into guarded calls.

  A matching statement becomes

      if bitblock::any(stream): <error_routine>(code, stream)

  where the routine expression is parsed once from the module-level
  `error_routine` string.  All other expression statements are left alone.
  """
  def __init__(self):
    self.assert_routine = ast.parse(error_routine).body[0].value

  def xfrm(self, t):
    return self.generic_visit(t)

  def visit_Expr(self, node):
    call = node.value
    if not isinstance(call, ast.Call):
      return node
    if not is_BuiltIn_Call(call, "assert_0", 2):
      return node
    err_stream = call.args[0]
    err_code = call.args[1]
    guard = mkCall('bitblock::any', [err_stream])
    report = ast.Expr(mkCall(self.assert_routine, [err_code, err_stream]))
    return ast.If(guard, [report], [])
974       
[810]975#
976# Adding Debugging Statements
977#
class Add_SIMD_Register_Dump(ast.NodeTransformer):
  """Follow every assignment with a print_register<BitBlock> call.

  The call receives the C rendering of the first assignment target as a
  string label, plus the target itself.
  """
  def xfrm(self, t):
    return self.generic_visit(t)

  def visit_Assign(self, t):
    self.generic_visit(t)
    target = t.targets[0]
    label = ast.Str(Cgen.py2C().gen(target))
    # Returning a list makes the transformer splice both statements in place.
    return [t, mkCallStmt(' print_register<BitBlock>', [label, target])]
[880]986
[1901]987#
988# Adding ASSERT_BITBLOCK_ALIGN macros
989#
class Add_Assert_BitBlock_Align(ast.NodeTransformer):
    """Follow every assignment with an ASSERT_BITBLOCK_ALIGN call on its
    first target."""

    def xfrm(self, t):
      return self.generic_visit(t)

    def visit_Assign(self, t):
      self.generic_visit(t)
      check_stmt = mkCallStmt(' ASSERT_BITBLOCK_ALIGN', [t.targets[0]])
      # Returning a list makes the transformer splice both statements in place.
      return [t, check_stmt]
998
class StreamFunctionCarryCounter(ast.NodeVisitor):
    """Build a map from stream function type name to its carry count.

    The type name is the function name with an upper-case initial; the count
    is whatever CarryCounter reports for the function definition.
    """

    def __init__(self):
        self.carry_count = {}

    def count(self, node):
        """Walk node and return the accumulated {type name: count} map."""
        self.generic_visit(node)
        return self.carry_count

    def visit_FunctionDef(self, node):
        initial, remainder = node.name[0], node.name[1:]
        type_name = initial.upper() + remainder
        self.carry_count[type_name] = CarryCounter().count(node)
1010     
class StreamFunctionCallXlator(ast.NodeTransformer):
    """Rewrite calls `Name(args)` into do_block / do_final_block invocations.

    A call whose function is a plain name is renamed to
    `<name>.do_block` (or `<name>_do_block` with C syntax), where <name> is
    the original name with a lower-case initial; with xlate_type "final" the
    suffix is `do_final_block` and `EOF_mask` is appended to the arguments.
    With C syntax the lower-cased name is also inserted as first argument.
    The restriction to names listed in stream_function_type_names is
    currently disabled, so every such call is rewritten.
    """

    def __init__(self, xlate_type="normal"):
        self.stream_function_type_names = []
        self.xlate_type = xlate_type
        # Defined up front so visit_Call is usable before xfrm has run.
        self.C_syntax = False

    def xfrm(self, node, stream_function_type_names, C_syntax):
        """Transform node in place and return it."""
        self.stream_function_type_names = stream_function_type_names
        self.C_syntax = C_syntax
        return self.generic_visit(node)

    def visit_Call(self, node):
        # Inner calls are rewritten first.
        self.generic_visit(node)

        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            name = lower1(node.func.id)
            is_final = self.xlate_type == "final"
            separator = "_" if self.C_syntax else "."
            routine = "do_final_block" if is_final else "do_block"
            node.func.id = name + separator + routine
            if self.C_syntax:
                node.args = [ast.Name(lower1(name), ast.Load())] + node.args
            if is_final:
                node.args = node.args + [ast.Name("EOF_mask", ast.Load())]

        return node
1033               
class StreamFunctionVisitor(ast.NodeVisitor):
    """Collect the function definitions reachable from a node.

    After construction, `stream_function_node` maps each function name
    (with an upper-case initial) to its FunctionDef node.
    """

    def __init__(self, node):
        self.stream_function_node = {}
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        fn_name = node.name
        self.stream_function_node[fn_name[0].upper() + fn_name[1:]] = node
[2260]1042
1043
[880]1044               
class StreamFunction():
    """Plain record describing one compiled stream function."""

    def __init__(self):
        self.carry_count = 0
        self.init_to_one_list = []
        self.adv_n_count = 0
        self.type_name = ""
        self.instance_name = ""
        self.parameters = ""
        self.declarations = ""
        self.initializations = ""
        # Filled in by the code that builds the record; defined here so a
        # fresh record can be read without raising AttributeError.
        self.statements = ""
        self.final_block_statements = ""
        self.ccgo = None

    def dump(self):
        """Print the type name and the carry-related counts to stdout."""
        # Single-argument parenthesised form: same output as the print
        # statement, and also accepted where print is a function.
        print("%s" % (self.type_name))
        print("%s=%s" % ("Carry Count", str(self.carry_count)))
        print("%s=[%s]" % ("Init to One List", ','.join(map(str, self.init_to_one_list))))
        print("%s=%s" % ("Adv n Count", str(self.adv_n_count)))
1065   
[813]1066#
[903]1067# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
1068# TODO Implement 'pretty print' indentation.   Low priority.
# TODO Migrate the Emitter() class to an Emitter module.  Medium priority.
[908]1070
def lower1(name):
    """Return name with its first character lower-cased.

    The empty string is returned unchanged.
    """
    # Slicing instead of indexing keeps the empty string from raising.
    return name[:1].lower() + name[1:]
def upper1(name):
    """Return name with its first character upper-cased.

    The empty string is returned unchanged.
    """
    # Slicing instead of indexing keeps the empty string from raising.
    return name[:1].upper() + name[1:]
1075
def escape_newlines(str):
  """Put a backslash in front of every newline (C macro line continuation)."""
  continuation = '\\\n'
  return continuation.join(str.split('\n'))
1078
class Emitter():
    """Renders stream functions as C or C++ source text.

    With use_C_syntax the block/segment routines are emitted as `#define`
    macros named `<type>_do_block` etc. next to a struct that only holds the
    carry declaration; otherwise they are emitted as member functions of a
    struct that also carries a constructor.  Carry declarations and
    initializations are obtained from the `ccgo` object passed in.
    """

    def __init__(self, use_C_syntax, ccgo):
        self.use_C_syntax = use_C_syntax
        self.ccgo = ccgo

    def definition(self, stream_function, icount=0):
        """Return the complete definition text for stream_function.

        Also records stream_function.type_name in self.type_name, which the
        C-syntax variants of the other methods use as a name prefix.
        """
        constructor = ""
        carry_declaration = ""
        self.type_name = stream_function.type_name

        if stream_function.carry_count > 0 or stream_function.adv_n_count > 0:
            constructor = self.constructor(stream_function.type_name,
                                           stream_function.carry_count,
                                           stream_function.init_to_one_list,
                                           stream_function.adv_n_count)
            carry_declaration = self.carry_declare('carryQ',
                                                   stream_function.carry_count,
                                                   stream_function.adv_n_count)

        do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters),
                                          stream_function.declarations,
                                          stream_function.initializations,
                                          stream_function.statements)

        do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters),
                                                      stream_function.declarations,
                                                      stream_function.initializations,
                                                      stream_function.final_block_statements)

        do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters),
                                              self.do_segment_args(stream_function.parameters))

        pad = self.indent(icount)
        separator = "\n" + pad
        header = "struct " + stream_function.type_name + " {"
        if self.use_C_syntax:
            # The struct holds only the carry declaration; the routines are
            # macros and follow the closing brace.
            return pad + separator.join([header,
                                         carry_declaration,
                                         "};\n",
                                         do_block_function,
                                         do_final_block_function,
                                         do_segment_function]) + "\n\n"
        return pad + separator.join([header,
                                     constructor,
                                     do_block_function,
                                     do_final_block_function,
                                     do_segment_function,
                                     carry_declaration,
                                     "};\n\n"])

    def constructor(self, type_name, carry_count, init_to_one_list, adv_n_count, icount=0):
        """Return the constructor text; its body is carry_init() followed by
        the initializations supplied by the ccgo object."""
        one_inits = self.ccgo.GenerateInitializations()
        return self.indent(icount) + "%s() { \n" % (type_name) + self.carry_init(carry_count) + one_inits + " }"

    def _prefix(self):
        # Routine names are prefixed with the lower-cased type name in C only.
        return lower1(self.type_name) + "_" if self.use_C_syntax else ""

    def _render_function(self, name, parameters, lines, icount, decorator=""):
        # Shared layout of do_block / do_final_block / do_segment: a
        # `#define ... do { ... } while (0)` macro with continuation lines in
        # C, a decorated `void` function otherwise.
        signature = self._prefix() + name + "(" + parameters
        pad = self.indent(icount)
        if self.use_C_syntax:
            eol = "\\\n"
            return "#define " + signature + ")" + eol + " do {" \
                + "".join([eol + pad + line for line in lines]) \
                + eol + self.indent(icount + 2) + "} while (0)"
        return pad + decorator + "void " + signature + ") {" \
            + "".join(["\n" + pad + line for line in lines]) \
            + "\n" + self.indent(icount + 2) + "}"

    def _body_sections(self, declarations, initializations, statements):
        # Inside a macro every embedded newline needs a continuation backslash.
        sections = [declarations, initializations, statements]
        if self.use_C_syntax:
            return [escape_newlines(section) for section in sections]
        return sections

    def do_block(self, parameters, declarations, initializations, statements, icount=0):
        """Return the text of the do_block routine."""
        return self._render_function("do_block", parameters,
                                     self._body_sections(declarations, initializations, statements),
                                     icount, do_block_inline_decorator)

    def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
        """Return the text of the do_final_block routine."""
        return self._render_function("do_final_block", parameters,
                                     self._body_sections(declarations, initializations, statements),
                                     icount, do_final_block_inline_decorator)

    def do_segment(self, parameters, do_block_call_args, icount=0):
        """Return the text of the do_segment routine, a loop calling do_block
        once per block of the segment."""
        lines = ["  int i;",
                 "  for (i = 0; i < segment_blocks; i++)",
                 "    " + self._prefix() + "do_block(" + do_block_call_args + ");"]
        return self._render_function("do_segment", parameters, lines, icount)

    def declaration(self, type_name, instance_name, icount=0):
        """Return the declaration of one instance of type_name."""
        if self.use_C_syntax:
            return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
        return self.indent(icount) + type_name + " " + instance_name + ";\n"

    def carry_init(self, carry_count, icount=0):
        #CARRY SET
        return ""

    def carry_declare(self, carry_variable, carry_count, adv_n_count=0, icount=0):
        """Return the carry declarations supplied by the ccgo object."""
        #CARRY SET
        return self.indent(icount) + self.ccgo.GenerateCarryDecls()

    def carry_test(self, carry_variable, carry_count, icount=0):
        #CARRY SET
        # NOTE(review): the variable name is fixed to carryQ; carry_variable
        # is not used - confirm this is intended.
        return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)

    def indent(self, icount):
        """Return icount spaces."""
        return " " * icount

    def do_block_parameters(self, parameters):
        if self.use_C_syntax:
            return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
        return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])

    def do_final_block_parameters(self, parameters):
        if self.use_C_syntax:
            return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["EOF_mask"])
        return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters] + ["BitBlock EOF_mask"])

    def do_segment_parameters(self, parameters):
        if self.use_C_syntax:
            return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["segment_blocks"])
        return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters] + ["int segment_blocks"])

    def do_segment_args(self, parameters):
        if self.use_C_syntax:
            return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
        return ", ".join([lower1(p) + "[i]" for p in parameters])
[908]1223
def main(infilename, outfile = sys.stdout):
  """Compile the Pablo source in infilename and write the result to outfile.

  The struct type definitions are written first, followed by the output of
  FunctionXlator for the same tree.
  """
  # Read through a context manager so the input file is closed promptly.
  with open(infilename) as infile:
    source = infile.read()
  t = ast.parse(source)
  outfile.write(StreamStructGen(True).gen(t))
  outfile.write(FunctionXlator().xlat(t))
1228
1229#
1230#  Routines for compatibility with the old compiler/template.
1231#  Quick and dirty hacks for now - Dec. 2010.
1232#
1233
[777]1234class MainLoopTransformer:
[2627]1235  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, dump_func_data=False, main_node_id='Main'):
[889]1236       
[777]1237    self.main_module = main_module
[880]1238    self.main_node_id = main_node_id
[1195]1239    self.use_C_syntax = C_syntax
[889]1240    self.add_dump_stmts = add_dump_stmts
[1901]1241    self.add_assert_bitblock_align = add_assert_bitblock_align
[2393]1242    self.dump_func_data = dump_func_data
[880]1243   
1244        # Gather and partition function definition nodes.
1245    stream_function_visitor = StreamFunctionVisitor(self.main_module)
1246    self.stream_function_node = stream_function_visitor.stream_function_node
[2215]1247    for key, node in self.stream_function_node.iteritems():
1248                AdvanceCombiner().xfrm(node)
[880]1249    self.main_node = self.stream_function_node[main_node_id]
1250    self.main_carry_count = CarryCounter().count(self.main_node)
[2206]1251    self.main_adv_n_count = adv_nCounter().count(self.main_node)
[2690]1252    self.main_carry_info_set = CarryInfoSetVisitor(self.main_node)
[2640]1253    self.main_ccgo = Strategic_CCGO_Factory(self.main_carry_info_set)
[2206]1254    assert self.main_adv_n_count == 0, "Advance32() in main not supported.\n"
[880]1255    del self.stream_function_node[self.main_node_id]
1256   
1257    self.stream_functions = {}
1258    for key, node in self.stream_function_node.iteritems():
1259                stream_function = StreamFunction()
1260                stream_function.carry_count = CarryCounter().count(node)
[2260]1261                stream_function.init_to_one_list = CarryInitToOneList().count(node)
[2206]1262                stream_function.adv_n_count = adv_nCounter().count(node)
[2690]1263                carry_info_set = CarryInfoSetVisitor(node)
[2640]1264                stream_function.ccgo = Strategic_CCGO_Factory(carry_info_set)
[880]1265                stream_function.type_name = node.name[0].upper() + node.name[1:]
1266                stream_function.instance_name = node.name[0].lower() + node.name[1:]
1267                stream_function.parameters = FunctionVars(node).params
1268                stream_function.declarations = BitBlock_decls_of_fn(node)
[2688]1269                stream_function.declarations += stream_function.ccgo.GenerateStreamFunctionDecls()
[880]1270                stream_function.initializations = StreamInitializations().xfrm(node) 
1271               
[2631]1272                t = TempifyBuiltins()
1273                t.xfrm(node)
1274                stream_function.declarations += "\n" + BitBlock_decls_from_vars(t.tempVars())
1275
[2630]1276                StringMatchCompiler().xfrm(node)
1277                AugAssignRemoval().xfrm(node)
1278
[2606]1279               
[880]1280                Bitwise_to_SIMD().xfrm(node)
[922]1281                final_block_node = copy.deepcopy(node)
[1195]1282                if self.use_C_syntax:
[1196]1283                        carryQname = stream_function.instance_name + ".carryQ"
[1195]1284                else: carryQname = "carryQ"
[2640]1285                CarryIntroVisitor = CarryIntro(stream_function.ccgo, carryQname)
[2628]1286                CarryIntroVisitor.xfrm_fndef(node)
1287                CarryIntroVisitor.xfrm_fndef_final(final_block_node)
[2640]1288                #
1289                # Compile asserts after carry intro so that generated if-statements
1290                # are ignored.
1291                AssertCompiler().xfrm(node)
1292                AssertCompiler().xfrm(final_block_node)
[889]1293                if self.add_dump_stmts: 
1294                        Add_SIMD_Register_Dump().xfrm(node)
[1849]1295                        Add_SIMD_Register_Dump().xfrm(final_block_node)
[1901]1296
1297                if self.add_assert_bitblock_align:
1298                        Add_Assert_BitBlock_Align().xfrm(node)
1299                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
1300
[2636]1301                #if stream_function.carry_count > 0:
1302                #       node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
[2640]1303                node.body += stream_function.ccgo.GenerateStreamFunctionFinalization()
[2636]1304
[889]1305               
[880]1306                stream_function.statements = Cgen.py2C(4).gen(node.body)
[922]1307                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
[880]1308                self.stream_functions[stream_function.type_name] = stream_function
[2393]1309               
1310    if self.dump_func_data:     
1311        for key, value in self.stream_functions.iteritems():
1312                        value.dump()
[2409]1313        sys.exit()
[2393]1314               
[2640]1315    self.emitter = Emitter(self.use_C_syntax, stream_function.ccgo)
[880]1316   
[777]1317  def any_carry_expr(self):
[880]1318       
1319        carry_test = []
1320       
1321        if self.main_carry_count > 0:
[1571]1322                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
[880]1323                        carry_test.append(" || ")
1324
1325        for key in self.stream_functions.keys():               
1326                if self.stream_functions[key].carry_count > 0:
[1571]1327                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
[880]1328                        carry_test.append(" || ")
1329
1330        if len(carry_test) > 0:
1331                carry_test.pop()
[1195]1332                return "".join(carry_test)
[880]1333        return "1"
1334
1335  def gen_globals(self):
[865]1336    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
[880]1337    for key in self.stream_functions.keys():
[2640]1338                sf = self.stream_functions[key]
1339                self.Cglobals += Emitter(self.use_C_syntax, sf.ccgo).definition(sf, 2)     
[880]1340                       
1341  def gen_declarations(self): 
[865]1342    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
[880]1343    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
1344    if self.main_carry_count > 0: 
[1571]1345        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
[880]1346               
[777]1347  def gen_initializations(self):
1348    self.Cinits = ""
[880]1349    if self.main_carry_count > 0: 
1350        self.Cinits += self.emitter.carry_init(self.main_carry_count)
[777]1351    self.Cinits += StreamInitializations().xfrm(self.main_module)
[1195]1352    if self.use_C_syntax:
1353                for key in self.stream_functions.keys():
[1196]1354                        if self.stream_functions[key].carry_count == 0: continue
1355                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
1356                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
1357    else:
1358                for key in self.stream_functions.keys():
1359                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
[1195]1360
[880]1361                       
[889]1362  def xfrm_block_stmts(self):
[2009]1363    StringMatchCompiler().xfrm(self.main_node)
[880]1364    AugAssignRemoval().xfrm(self.main_node)
1365    Bitwise_to_SIMD().xfrm(self.main_node)
[2099]1366    Bitwise_to_SIMD().xfrm(self.main_node)
[937]1367    final_block_main = copy.deepcopy(self.main_node)
[2690]1368    carry_info_set = CarryInfoSetVisitor(self.main_node)
[2636]1369    ccgo = Strategic_CCGO_Factory(carry_info_set)
1370    CarryIntroVisitor = CarryIntro(ccgo)
[2628]1371    CarryIntroVisitor.xfrm_fndef(self.main_node)
1372    CarryIntroVisitor.xfrm_fndef_final(final_block_main)
[2636]1373    AssertCompiler().xfrm(self.main_node)
[889]1374    if self.add_dump_stmts: 
[880]1375        Add_SIMD_Register_Dump().xfrm(self.main_node)
[939]1376        Add_SIMD_Register_Dump().xfrm(final_block_main)
[880]1377               
[1901]1378    if self.add_assert_bitblock_align:
[1917]1379        print "add_assert_bitblock_align"
[1901]1380        Add_Assert_BitBlock_Align().xfrm(self.main_node)
1381        Add_Assert_BitBlock_Align().xfrm(final_block_main)
1382
[1195]1383    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
1384    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
[889]1385   
[2636]1386    #if self.main_carry_count > 0:
[1571]1387                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
[2636]1388    self.main_node.body += ccgo.GenerateStreamFunctionFinalization()
1389   
[889]1390   
1391       
[880]1392    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
[937]1393    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
[880]1394   
if __name__ == "__main__":
  # Running this module as a script executes its embedded doctests.
  from doctest import testmod
  testmod()
Note: See TracBrowser for help on using the repository browser.