source: proto/Compiler/pablo.py @ 1571

Last change on this file since 1571 was 1571, checked in by ksherdy, 8 years ago

Updated Pablo for use with the new carryQ implementation.

File size: 29.0 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11name_substitution_map = {}
12
13def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
14        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == builtin_fnname
15        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
16                 iscall = fncall.func.value.id == builtin_fnmod_noprefix and fncall.func.attr == builtin_fnname
17        return iscall and len(fncall.args) == builtin_arg_cnt and fncall.kwargs == None and fncall.starargs == None
18
19def is_simd_not(e):
20  return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
21
22def mkQname(obj, field):
23  return ast.Attribute(ast.Name(obj, ast.Load()), field, ast.Load())
24
25def mkCall(fn_name, args):
26  if isinstance(fn_name, str): 
27        if name_substitution_map.has_key(fn_name): fn_name = name_substitution_map[fn_name]
28        fn_name = ast.Name(fn_name, ast.Load())
29  return ast.Call(fn_name, args, [], None, None)
30
31def mkCallStmt(fn_name, args):
32  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
33  return ast.Expr(ast.Call(fn_name, args, [], None, None))
34
35#
36# Reducing AugAssign, e.g.  x |= y becomes x = x | y
37#
38class AugAssignRemoval(ast.NodeTransformer):
39  def xfrm(self, t):
40    return self.generic_visit(t)
41  def visit_AugAssign(self, e):
42    self.generic_visit(e)
43    return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
44
45#
46# Introducing BitBlock logical operations
47#
48class Bitwise_to_SIMD(ast.NodeTransformer):
49  """
50  Make the following substitutions:
51     x & y => simd_and(x, y)
52     x & ~y => simd_andc(x, y)
53     x | y => simd_or(x, y)
54     x ^ y => simd_xor(x, y)
55     ~x    => simd_not(x)
56     0     => simd_const_1(0)
57     -1    => simd_const_1(1)
58     if x: => if bitblock_has_bit(x):
59  while x: => while bitblock_has_bit(x):
60  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1")))
61 
62  pfx = simd_and(bit0, bit1)
63  sfx = simd_and(bit0, simd_not(bit1))
64  >>>
65  """
66  def xfrm(self, t):
67    return self.generic_visit(t)
68  def visit_UnaryOp(self, t):
69    self.generic_visit(t)
70    if isinstance(t.op, ast.Invert):
71      return mkCall('simd_not', [t.operand])
72    else: return t
73  def visit_BinOp(self, t):
74    self.generic_visit(t)
75    if isinstance(t.op, ast.BitOr):
76      return mkCall('simd_or', [t.left, t.right])
77    elif isinstance(t.op, ast.BitAnd):
78      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
79      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
80      else: return mkCall('simd_and', [t.left, t.right])
81    elif isinstance(t.op, ast.BitXor):
82      return mkCall('simd_xor', [t.left, t.right])
83    else: return t
84  def visit_Num(self, numnode):
85    n = numnode.n
86    if n == 0: return mkCall('simd_const_1', [numnode])
87    elif n == -1: return mkCall('simd_const_1', [ast.Num(1)])
88    else: return numnode
89  def visit_If(self, ifNode):
90    self.generic_visit(ifNode)
91    ifNode.test = mkCall('bitblock_has_bit', [ifNode.test])
92    return ifNode
93  def visit_While(self, whileNode):
94    self.generic_visit(whileNode)
95    whileNode.test = mkCall('bitblock_has_bit', [whileNode.test])
96    return whileNode
97  def visit_Subscript(self, numnode):
98    return numnode  # no recursive modifications of index expressions
99
100#
101#  Generating BitBlock declarations for Local Variables
102#
103class FunctionVars(ast.NodeVisitor):
104  def __init__(self,node):
105        self.params = []
106        self.stores = []
107        self.generic_visit(node)
108  def visit_Name(self, nm):
109    if isinstance(nm.ctx, ast.Param):
110      self.params.append(nm.id)
111    if isinstance(nm.ctx, ast.Store):
112      if nm.id not in self.stores: self.stores.append(nm.id)
113  def getLocals(self): 
114    return [v for v in self.stores if not v in self.params]
115
116MAX_LINE_LENGTH = 80
117
118def BitBlock_decls_from_vars(varlist):
119  global MAX_LINE_LENGTH
120  decls =  ""
121  if not len(varlist) == 0:
122          decls = "             BitBlock"
123          pending = ""
124          linelgth = 10
125          for v in varlist:
126            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
127              decls += pending + " " + v
128              linelgth += len(pending + v) + 1
129            else:
130              decls += ";\n             BitBlock " + v
131              linelgth = 11 + len(v)
132            pending = ","
133          decls += ";"
134  return decls
135
136def BitBlock_decls_of_fn(fndef):
137  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
138
139def BitBlock_header_of_fn(fndef):
140  Ccode = "static inline void " + fndef.name + "("
141  pending = ""
142  for arg in fndef.args.args:
143    if isinstance(arg, ast.Name):
144      Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
145      pending = ", "
146  if CarryCounter().count(fndef) > 0:
147    Ccode += pending + " CarryQtype & carryQ"
148  Ccode += ")"
149  return Ccode
150
151
152
153#
154#  Stream Initialization Statement Extraction
155#
156#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
157class StreamInitializations(ast.NodeTransformer):
158  def xfrm(self, node):
159    self.stream_stmts = []
160    self.loop_post_inits = []
161    self.generic_visit(node)
162    return Cgen.py2C().gen(self.stream_stmts)
163  def visit_Assign(self, node):
164    if isinstance(node.value, ast.Num):
165      if node.value.n == 0: return node
166      elif node.value.n == -1: return node
167      else: 
168        stream_init = copy.deepcopy(node)
169        stream_init.value = mkCall('sisd_from_int', [node.value])
170        loop_init = copy.deepcopy(node)
171        loop_init.value.n = 0
172        self.stream_stmts.append(stream_init)
173        self.loop_post_inits.append(loop_init)
174        return None
175    else: return node
176  def visit_FunctionDef(self, node):
177    self.generic_visit(node)
178    node.body = node.body + self.loop_post_inits
179    return node
180
181#
182# Carry Introduction Transformation
183#
184class CarryCounter(ast.NodeVisitor):
185  def visit_Call(self, callnode):
186    self.generic_visit(callnode)
187    if is_BuiltIn_Call(callnode,'Advance', 1) or is_BuiltIn_Call(callnode,'ScanThru', 2) or is_BuiltIn_Call(callnode,'ScanTo', 2) or is_BuiltIn_Call(callnode,'ScanToFirst', 1):       
188      self.carry_count += 1
189  def visit_BinOp(self, exprnode):
190    self.generic_visit(exprnode)
191    if isinstance(exprnode.op, ast.Sub):
192      self.carry_count += 1
193    if isinstance(exprnode.op, ast.Add):
194      self.carry_count += 1
195  def count(self, nodeToVisit):
196    self.carry_count = 0
197    self.generic_visit(nodeToVisit)
198    return self.carry_count
199
200class Adv32Counter(ast.NodeVisitor):
201  def visit_Call(self, callnode):
202    self.generic_visit(callnode)
203    if is_BuiltIn_Call(callnode,'Advance32', 1):       
204      self.adv32_count += 1
205  def count(self, nodeToVisit):
206    self.adv32_count = 0
207    self.generic_visit(nodeToVisit)
208    return self.adv32_count
209
210class CarryIntro(ast.NodeTransformer):
211  def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
212    self.carryvar = ast.Name(carryvar, ast.Load())
213    self.carryin = carryin
214    self.carryout = carryout
215  def xfrm_fndef(self, fndef):
216    self.current_carry = 0
217    self.current_adv32 = 0
218    carry_count = CarryCounter().count(fndef)
219    if carry_count == 0: return fndef
220    self.generic_visit(fndef)
221#   
222#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
223    return fndef
224  def generic_xfrm(self, node):
225    self.current_carry = 0
226    self.current_adv32 = 0
227    carry_count = CarryCounter().count(node)
228    adv32_count = Adv32Counter().count(node)
229    if carry_count == 0 and adv32_count == 0: return node
230    self.generic_visit(node)
231    return node
232  def visit_Call(self, callnode):
233    self.generic_visit(callnode)
234    #CARRYSET
235    carry_args = [ast.Num(self.current_carry)]
236    adv32_args = [ast.Subscript(ast.Name('pending32', ast.Load()), ast.Num(self.current_adv32), ast.Load())]
237    if is_BuiltIn_Call(callnode, 'Advance', 1):         
238      #CARRYSET
239      rtn = self.carryvar.id + "." + "BitBlock_advance%s_co" % (self.carryin)
240      c = mkCall(rtn, callnode.args + carry_args)
241      self.current_carry += 1
242      return c
243    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
244      #CARRYSET
245      rtn = self.carryvar.id + "." + "BitBlock_advance32%s_co" % (self.carryin)
246      c = mkCall(rtn, callnode.args + adv32_args)
247      self.current_adv32 += 1
248      return c
249    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
250      #CARRYSET
251      rtn = self.carryvar.id + "." + "BitBlock_scanthru%s_co" % (self.carryin)
252      c = mkCall(rtn, callnode.args + carry_args)
253      self.current_carry += 1
254      return c
255    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
256      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
257      # in having a separate BitBlock_scanto routine.
258      #CARRYSET
259      rtn = self.carryvar.id + "." + "BitBlock_scanthru%s_co" % (self.carryin)
260      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
261      else: scanclass = mkCall('simd_not', [callnode.args[1]])
262      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
263      self.current_carry += 1
264      return c
265    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
266      #CARRYSET
267      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
268      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
269      c = mkCall(rtn, callnode.args + carry_args)
270      self.current_carry += 1
271      return c
272    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
273      if self.carryout != "": 
274        # Non final block: atEOF(x) = 0.
275        return mkCall('simd_const_1', [ast.Num(0)])
276      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
277    elif is_BuiltIn_Call(callnode, 'inFile', 1):
278      if self.carryout != "": 
279        # Non final block: inFile(x) = x.
280        return callnode.args[0]
281      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
282    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
283      rtn = "StreamScan"           
284      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
285                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
286                                           ast.Name(callnode.args[1].id, ast.Load())])
287      return c
288    else: return callnode
289  def visit_BinOp(self, exprnode):
290    self.generic_visit(exprnode)
291    carry_args = [ast.Num(self.current_carry)]
292    if isinstance(exprnode.op, ast.Sub):
293      #CARRYSET
294      rtn = self.carryvar.id + "." + "BitBlock_sub%s_co" % (self.carryin)
295      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
296      self.current_carry += 1
297      return c
298    elif isinstance(exprnode.op, ast.Add):
299      #CARRYSET
300      rtn = self.carryvar.id + "." + "BitBlock_add%s_co" % (self.carryin)
301      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
302      self.current_carry += 1
303      return c
304    else: return exprnode
305  def visit_If(self, ifNode):
306    carry_base = self.current_carry
307    carries = CarryCounter().count(ifNode)
308    assert Adv32Counter().count(ifNode) == 0, "Advance32() within if: illegal\n"
309    self.generic_visit(ifNode)
310    if carries == 0 or self.carryin == "": return ifNode
311    #CARRYSET
312    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
313    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
314    new_else_part = ifNode.orelse + [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]
315    return ast.If(new_test, ifNode.body, new_else_part)
316  def visit_While(self, whileNode):
317    if self.carryout == '':
318      whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
319    carry_base = self.current_carry
320    assert Adv32Counter().count(whileNode) == 0, "Advance32() within while: illegal\n"
321    carries = CarryCounter().count(whileNode)
322    #CARRYSET
323    if carries == 0: return whileNode
324    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
325    local_carryvar = 'subcarryQ'
326    inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
327    self.generic_visit(whileNode)
328    local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
329    inner_while.body.insert(0, local_carry_decl)
330    final_combine = mkCallStmt('carryQ.CarryCombine', [ast.Name(local_carryvar + '.array()', ast.Load()), ast.Num(carry_base), ast.Num(carries)])
331    inner_while.body.append(final_combine)
332    #CARRYSET
333    if self.carryin == '': new_test = whileNode.test
334    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
335    else_part = [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]   
336    return ast.If(new_test, whileNode.body + [inner_while], else_part)
337
338class StreamStructGen(ast.NodeVisitor):
339  """
340  Given a BitStreamSet subclass, generate the equivalent C struct.
341  >>> obj = ast.parse(r'''
342  ... class S1(BitStreamSet):
343  ...   a1 = 0
344  ...   a2 = 0
345  ...   a3 = 0
346  ...
347  ... class S2(BitStreamSet):
348  ...   x1 = 0
349  ...   x2 = 0
350  ... ''')
351  >>> print StreamStructGen().gen(obj)
352  struct S1 {
353    BitBlock a1;
354    BitBlock a2;
355    BitBlock a3;
356  }    self.current_adv32 = 0
357
358 
359  struct S2 {
360    BitBlock x1;
361    BitBlock x2;
362  }
363  """
364  def __init__(self, asType=False):
365    self.asType = asType
366  def gen(self, tree):
367    self.Ccode=""
368    self.generic_visit(tree)
369    return self.Ccode
370  def gen_struct_types(self, tree):
371    self.asType = True
372    self.Ccode=""
373    self.generic_visit(tree)
374    return self.Ccode
375  def gen_struct_vars(self, tree):
376    self.asType = False
377    self.Ccode=""
378    self.generic_visit(tree)
379    return self.Ccode
380  def visit_ClassDef(self, node):
381    class_name = node.name[0].upper() + node.name[1:]
382    instance_name = node.name[0].lower() + node.name[1:]
383    self.Ccode += "  struct " + class_name
384    if self.asType:
385            self.Ccode += " {\n"
386            for stmt in node.body:
387              if isinstance(stmt, ast.Assign):
388                for v in stmt.targets:
389                  if isinstance(v, ast.Name):
390                    self.Ccode += "  BitBlock " + v.id + ";\n"
391            self.Ccode += "}" 
392    else: self.Ccode += " " + instance_name
393    self.Ccode += ";\n\n"
394 
395class StreamFunctionDecl(ast.NodeVisitor):
396  def __init__(self):
397    pass
398  def gen(self, tree):
399    self.Ccode=""
400    self.generic_visit(tree)
401    return self.Ccode
402  def visit_FunctionDef(self, node):
403    self.Ccode += "static inline void " + node.name + "("
404    pending = ""
405    for arg in node.args.args:
406      if isinstance(arg, ast.Name):
407        self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
408        pending = ", "
409    self.Ccode += ");\n"
410
411#
412# Adding Debugging Statements
413#
414class Add_SIMD_Register_Dump(ast.NodeTransformer):
415  def xfrm(self, t):
416    return self.generic_visit(t)
417  def visit_Assign(self, t):
418    self.generic_visit(t)
419    v = t.targets[0]
420    dump_stmt = mkCallStmt('print_simd_register', [ast.Str(Cgen.py2C().gen(v)), v])
421    return [t, dump_stmt]
422
423class StreamFunctionCarryCounter(ast.NodeVisitor):
424  def __init__(self):
425        self.carry_count = {}
426       
427  def count(self, node):
428        self.generic_visit(node)
429        return self.carry_count
430                                   
431  def visit_FunctionDef(self, node):   
432        type_name = node.name[0].upper() + node.name[1:]                       
433        self.carry_count[type_name] = CarryCounter().count(node)
434     
435class StreamFunctionCallXlator(ast.NodeTransformer):
436  def __init__(self, xlate_type="normal"):
437        self.stream_function_type_names = []
438        self.xlate_type = xlate_type
439
440  def xfrm(self, node, stream_function_type_names, C_syntax):
441        self.stream_function_type_names = stream_function_type_names
442        self.C_syntax = C_syntax
443        self.generic_visit(node)
444       
445  def visit_Call(self, node):   
446        self.generic_visit(node)
447
448        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in self.stream_function_type_names:
449             name = lower1(node.func.id)
450             node.func.id = name + ("_" if self.C_syntax else ".") + ("do_final_block" if self.xlate_type == "final" else "do_block")
451             if self.C_syntax:
452                     node.args = [ast.Name(lower1(name), ast.Load())] + node.args
453             if self.xlate_type == "final":
454                   node.args = node.args + [ast.Name("EOF_mask", ast.Load())]
455                     
456        return node     
457               
458class StreamFunctionVisitor(ast.NodeVisitor):
459        def __init__(self,node):
460                self.stream_function_node = {}
461                self.generic_visit(node)
462                                                             
463        def visit_FunctionDef(self, node):                     
464                key = node.name[0].upper() + node.name[1:]
465                self.stream_function_node[key] = node
466               
467class StreamFunction():
468        def __init__(self):
469                self.carry_count = 0 
470                self.adv32_count = 0 
471                self.type_name = ""
472                self.instance_name = "" 
473                self.parameters = ""
474                self.declarations = "" 
475                self.initializations = "" 
476#
477# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
478# TODO Implement 'pretty print' indentation.   Low priority.
479# TODO Migrate Emiter() class to Emitter module.  Medium priority.
480
481def lower1(name):
482    return name[0].lower() + name[1:]
483def upper1(name):
484    return name[0].upper() + name[1:]
485
486def escape_newlines(str):
487  return str.replace('\n', '\\\n')
488
489class Emitter():
490        def __init__(self, use_C_syntax):
491                self.use_C_syntax = use_C_syntax
492
493        def definition(self, stream_function, icount=0):
494               
495                constructor = ""
496                carry_declaration = ""
497                self.type_name = stream_function.type_name
498               
499                if stream_function.carry_count > 0 or stream_function.adv32_count > 0:
500                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.adv32_count)
501                        carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv32_count)
502
503                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), 
504                                                stream_function.declarations, 
505                                                stream_function.initializations, 
506                                                stream_function.statements)             
507
508                do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), 
509                                                stream_function.declarations, 
510                                                stream_function.initializations, 
511                                                stream_function.final_block_statements)                 
512
513                do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), 
514                                                self.do_segment_args(stream_function.parameters))       
515
516                if self.use_C_syntax:
517                        return self.indent(icount) + "struct " + stream_function.type_name + " {" \
518                               + "\n" + self.indent(icount) + carry_declaration \
519                               + "\n" + self.indent(icount) + "};\n" \
520                               + "\n" + self.indent(icount) + do_block_function \
521                               + "\n" + self.indent(icount) + do_final_block_function \
522                               + "\n" + self.indent(icount) + do_segment_function + "\n\n"
523                               
524                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
525                + "\n" + self.indent(icount) + constructor \
526                + "\n" + self.indent(icount) + do_block_function \
527                + "\n" + self.indent(icount) + do_final_block_function \
528                + "\n" + self.indent(icount) + do_segment_function \
529                + "\n" + self.indent(icount) + carry_declaration \
530                + "\n" + self.indent(icount) + "};\n\n"
531
532        def constructor(self, type_name, carry_count, adv32_count, icount=0):
533                adv32_decl = ""
534                for i in range(adv32_count): adv32_decl += self.indent(icount+2) + "pending32[%s] = 0;\n" % i   
535                return self.indent(icount) + "%s() { ""\n" % (type_name) + adv32_decl + self.carry_init(carry_count) + " }" 
536                       
537        def do_block(self, parameters, declarations, initializations, statements, icount=0):
538                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
539                if self.use_C_syntax:
540                        return "#define " + pfx + "do_block(" + parameters + ")\\\n do {" \
541                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
542                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
543                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
544                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
545                return self.indent(icount) + "void " + pfx + "do_block(" + parameters + ") {" \
546                + "\n" + self.indent(icount) + declarations \
547                + "\n" + self.indent(icount) + initializations \
548                + "\n" + self.indent(icount) + statements \
549                + "\n" + self.indent(icount + 2) + "}" 
550
551
552
553
554        def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
555                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
556                if self.use_C_syntax:
557                        return "#define " + pfx + "do_final_block(" + parameters + ")\\\n do {" \
558                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
559                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
560                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
561                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
562                return self.indent(icount) + "void " + pfx + "do_final_block(" + parameters + ") {" \
563                + "\n" + self.indent(icount) + declarations \
564                + "\n" + self.indent(icount) + initializations \
565                + "\n" + self.indent(icount) + statements \
566                + "\n" + self.indent(icount + 2) + "}" 
567
568        def do_segment(self, parameters, do_block_call_args, icount=0):
569                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
570                if self.use_C_syntax:
571                        return "#define " + pfx + "do_segment(" + parameters + ")\\\n do {" \
572                        + "\\\n" + self.indent(icount) + "  int i;" \
573                        + "\\\n" + self.indent(icount) + "  for (i = 0; i < SEGMENT_BLOCKS; i++)" \
574                        + "\\\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
575                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
576                return self.indent(icount) + "void " + pfx + "do_segment(" + parameters + ") {" \
577                + "\n" + self.indent(icount) + "  int i;" \
578                + "\n" + self.indent(icount) + "  for (i = 0; i < SEGMENT_BLOCKS; i++)" \
579                + "\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
580                + "\n" + self.indent(icount + 2) + "}" 
581
582        def declaration(self, type_name, instance_name, icount=0):
583                if self.use_C_syntax: return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
584                return self.indent(icount) + type_name + " " + instance_name + ";\n"
585               
586        def carry_init(self, carry_count, icount=0):   
587                #CARRY SET
588                return ""
589               
590        def carry_declare(self, carry_variable, carry_count, adv32_count=0, icount=0):
591                adv32_decl = ""
592                if adv32_count > 0:
593                        adv32_decl = "\n" + self.indent(icount) + "uint32_t pending32[%s];" % adv32_count
594                #CARRY SET
595                return self.indent(icount) + "CarryArray<%i> %s;" % (carry_count, carry_variable) + adv32_decl
596
597        def carry_test(self, carry_variable, carry_count, icount=0):
598                #CARRY SET
599                return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)         
600               
601        def indent(self, icount):
602                s = ""
603                for i in range(0,icount): s += " "
604                return s       
605               
606        def do_block_parameters(self, parameters):
607                if self.use_C_syntax:
608                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters])
609                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
610                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])
611               
612        def do_final_block_parameters(self, parameters):
613                if self.use_C_syntax:
614                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
615                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters]+ ["EOF_mask"])
616                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
617               
618        def do_segment_parameters(self, parameters):
619                if self.use_C_syntax:
620                        #return ", ".join([self.type_name + " * " + + self.instance_name] + [upper1(p) + " " + lower1(p) + "[]" for p in parameters])
621                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
622                else: return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters])
623
624        def do_segment_args(self, parameters):
625                if self.use_C_syntax:
626                        return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
627                else: return ", ".join([lower1(p) + "[i]" for p in parameters])
628
629def main(infilename, outfile = sys.stdout):
630  t = ast.parse(file(infilename).read())
631  outfile.write(StreamStructGen(True).gen(t))
632  outfile.write(FunctionXlator().xlat(t))
633
634#
635#  Routines for compatibility with the old compiler/template.
636#  Quick and dirty hacks for now - Dec. 2010.
637#
638
639class MainLoopTransformer:
640  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, main_node_id='Main'):
641       
642    self.main_module = main_module
643    self.main_node_id = main_node_id
644    self.use_C_syntax = C_syntax
645    self.add_dump_stmts = add_dump_stmts
646   
647        # Gather and partition function definition nodes.
648    stream_function_visitor = StreamFunctionVisitor(self.main_module)
649    self.stream_function_node = stream_function_visitor.stream_function_node
650    self.main_node = self.stream_function_node[main_node_id]
651    self.main_carry_count = CarryCounter().count(self.main_node)
652    self.main_adv32_count = Adv32Counter().count(self.main_node)
653    assert self.main_adv32_count == 0, "Advance32() in main not supported.\n"
654    del self.stream_function_node[self.main_node_id]
655   
656    self.stream_functions = {}
657    for key, node in self.stream_function_node.iteritems():
658                stream_function = StreamFunction()
659                stream_function.carry_count = CarryCounter().count(node)
660                stream_function.adv32_count = Adv32Counter().count(node)
661                stream_function.type_name = node.name[0].upper() + node.name[1:]
662                stream_function.instance_name = node.name[0].lower() + node.name[1:]
663                stream_function.parameters = FunctionVars(node).params
664                stream_function.declarations = BitBlock_decls_of_fn(node)
665                stream_function.initializations = StreamInitializations().xfrm(node) 
666               
667                AugAssignRemoval().xfrm(node)
668                Bitwise_to_SIMD().xfrm(node)
669                final_block_node = copy.deepcopy(node)
670                if self.use_C_syntax:
671                        carryQname = stream_function.instance_name + ".carryQ"
672                else: carryQname = "carryQ"
673                CarryIntro(carryQname).xfrm_fndef(node)
674                CarryIntro(carryQname, "_ci", "").xfrm_fndef(final_block_node)
675
676                if self.add_dump_stmts: 
677                        Add_SIMD_Register_Dump().xfrm(node)
678               
679                if stream_function.carry_count > 0:
680                        node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
681               
682                stream_function.statements = Cgen.py2C(4).gen(node.body)
683                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
684                self.stream_functions[stream_function.type_name] = stream_function
685       
686    self.emitter = Emitter(self.use_C_syntax)
687   
688  def any_carry_expr(self):
689       
690        carry_test = []
691       
692        if self.main_carry_count > 0:
693                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
694                        carry_test.append(" || ")
695
696        for key in self.stream_functions.keys():               
697                if self.stream_functions[key].carry_count > 0:
698                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
699                        carry_test.append(" || ")
700
701        if len(carry_test) > 0:
702                carry_test.pop()
703                return "".join(carry_test)
704        return "1"
705
706  def gen_globals(self):
707    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
708    for key in self.stream_functions.keys():
709                self.Cglobals += Emitter(self.use_C_syntax).definition(self.stream_functions[key],2)       
710                       
711  def gen_declarations(self): 
712    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
713    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
714    if self.main_carry_count > 0: 
715        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
716               
717  def gen_initializations(self):
718    self.Cinits = ""
719    if self.main_carry_count > 0: 
720        self.Cinits += self.emitter.carry_init(self.main_carry_count)
721    self.Cinits += StreamInitializations().xfrm(self.main_module)
722    if self.use_C_syntax:
723                for key in self.stream_functions.keys():
724                        if self.stream_functions[key].carry_count == 0: continue
725                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
726                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
727    else:
728                for key in self.stream_functions.keys():
729                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
730
731                       
732  def xfrm_block_stmts(self):
733    AugAssignRemoval().xfrm(self.main_node)
734    Bitwise_to_SIMD().xfrm(self.main_node)
735    final_block_main = copy.deepcopy(self.main_node)
736    CarryIntro().xfrm_fndef(self.main_node)
737    CarryIntro('carryQ', '_ci', '').xfrm_fndef(final_block_main)
738    if self.add_dump_stmts: 
739        Add_SIMD_Register_Dump().xfrm(self.main_node)
740        Add_SIMD_Register_Dump().xfrm(final_block_main)
741               
742    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
743    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
744   
745    if self.main_carry_count > 0:
746                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
747                self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
748   
749   
750       
751    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
752    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
753   
754if __name__ == "__main__":
755                import doctest
756                doctest.testmod()
Note: See TracBrowser for help on using the repository browser.