source: proto/Compiler/pablo.py @ 1917

Last change on this file since 1917 was 1917, checked in by cameron, 7 years ago

Add IDISA_INLINE to do_block

File size: 30.1 KB
Line 
1#
2# Pablo.py - parallel bitstream to bitblock
3#  2nd generation compiler
4#
5# Copyright 2010, 2011, Robert D. Cameron, Kenneth S. Herdy
6# All rights reserved.
7#
8import ast, copy, sys
9import Cgen
10
11name_substitution_map = {}
12
13def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='pablo'):
14        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == builtin_fnname
15        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
16                 iscall = fncall.func.value.id == builtin_fnmod_noprefix and fncall.func.attr == builtin_fnname
17        return iscall and len(fncall.args) == builtin_arg_cnt
18
19def dump_Call(fncall):
20        if isinstance(fncall.func, ast.Name): print "fn_name = %s\n" % fncall.func.id
21        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
22                print "fn_name = %s.%s\n" % (fncall.func.value.id, fncall.func.attr)
23        print "len(fncall.args) = %s\n" % len(fncall.args)
24
25def is_simd_not(e):
26  return isinstance(e, ast.Call) and isinstance(e.func, ast.Name) and e.func.id == 'simd_not'
27
28def mkQname(obj, field):
29  return ast.Attribute(ast.Name(obj, ast.Load()), field, ast.Load())
30
31def mkCall(fn_name, args):
32  if isinstance(fn_name, str): 
33        if name_substitution_map.has_key(fn_name): fn_name = name_substitution_map[fn_name]
34        fn_name = ast.Name(fn_name, ast.Load())
35  return ast.Call(fn_name, args, [], None, None)
36
37def mkCallStmt(fn_name, args):
38  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
39  return ast.Expr(ast.Call(fn_name, args, [], None, None))
40
41#
42# Reducing AugAssign, e.g.  x |= y becomes x = x | y
43#
44class AugAssignRemoval(ast.NodeTransformer):
45  def xfrm(self, t):
46    return self.generic_visit(t)
47  def visit_AugAssign(self, e):
48    self.generic_visit(e)
49    return ast.Assign([e.target], ast.BinOp(e.target, e.op, e.value))
50
51#
52# Introducing BitBlock logical operations
53#
54class Bitwise_to_SIMD(ast.NodeTransformer):
55  """
56  Make the following substitutions:
57     x & y => simd_and(x, y)
58     x & ~y => simd_andc(x, y)
59     x | y => simd_or(x, y)
60     x ^ y => simd_xor(x, y)
61     ~x    => simd_not(x)
62     0     => simd_const_1(0)
63     -1    => simd_const_1(1)
64     if x: => if bitblock::any(x):
65  while x: => while bitblock::any(x):
66  >>> ast_show(Bitwise_to_SIMD().xfrm(ast.parse("pfx = bit0 & bit1; sfx = bit0 &~ bit1")))
67 
68  pfx = simd_and(bit0, bit1)
69  sfx = simd_and(bit0, simd_not(bit1))
70  >>>
71  """
72  def xfrm(self, t):
73    return self.generic_visit(t)
74  def visit_UnaryOp(self, t):
75    self.generic_visit(t)
76    if isinstance(t.op, ast.Invert):
77      return mkCall('simd_not', [t.operand])
78    else: return t
79  def visit_BinOp(self, t):
80    self.generic_visit(t)
81    if isinstance(t.op, ast.BitOr):
82      return mkCall('simd_or', [t.left, t.right])
83    elif isinstance(t.op, ast.BitAnd):
84      if is_simd_not(t.right): return mkCall('simd_andc', [t.left, t.right.args[0]])
85      elif is_simd_not(t.left): return mkCall('simd_andc', [t.right, t.left.args[0]])
86      else: return mkCall('simd_and', [t.left, t.right])
87    elif isinstance(t.op, ast.BitXor):
88      return mkCall('simd_xor', [t.left, t.right])
89    else: return t
90  def visit_Num(self, numnode):
91    n = numnode.n
92    if n == 0: return mkCall('simd<1>::constant<0>', [])
93    elif n == -1: return mkCall('simd<1>::constant<1>', [])
94    else: return numnode
95  def visit_If(self, ifNode):
96    self.generic_visit(ifNode)
97    ifNode.test = mkCall('bitblock::any', [ifNode.test])
98    return ifNode
99  def visit_While(self, whileNode):
100    self.generic_visit(whileNode)
101    whileNode.test = mkCall('bitblock::any', [whileNode.test])
102    return whileNode
103  def visit_Subscript(self, numnode):
104    return numnode  # no recursive modifications of index expressions
105
106#
107#  Generating BitBlock declarations for Local Variables
108#
109class FunctionVars(ast.NodeVisitor):
110  def __init__(self,node):
111        self.params = []
112        self.stores = []
113        self.generic_visit(node)
114  def visit_Name(self, nm):
115    if isinstance(nm.ctx, ast.Param):
116      self.params.append(nm.id)
117    if isinstance(nm.ctx, ast.Store):
118      if nm.id not in self.stores: self.stores.append(nm.id)
119  def getLocals(self): 
120    return [v for v in self.stores if not v in self.params]
121
122MAX_LINE_LENGTH = 80
123
124def BitBlock_decls_from_vars(varlist):
125  global MAX_LINE_LENGTH
126  decls =  ""
127  if not len(varlist) == 0:
128          decls = "             BitBlock"
129          pending = ""
130          linelgth = 10
131          for v in varlist:
132            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
133              decls += pending + " " + v
134              linelgth += len(pending + v) + 1
135            else:
136              decls += ";\n             BitBlock " + v
137              linelgth = 11 + len(v)
138            pending = ","
139          decls += ";"
140  return decls
141
142def BitBlock_decls_of_fn(fndef):
143  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
144
145def BitBlock_header_of_fn(fndef):
146  Ccode = "static inline void " + fndef.name + "("
147  pending = ""
148  for arg in fndef.args.args:
149    if isinstance(arg, ast.Name):
150      Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
151      pending = ", "
152  if CarryCounter().count(fndef) > 0:
153    Ccode += pending + " CarryQtype & carryQ"
154  Ccode += ")"
155  return Ccode
156
157
158
159#
160#  Stream Initialization Statement Extraction
161#
162#  streamvar = 1 ==> streamvar = sisd_from_int(1) initially.
163class StreamInitializations(ast.NodeTransformer):
164  def xfrm(self, node):
165    self.stream_stmts = []
166    self.loop_post_inits = []
167    self.generic_visit(node)
168    return Cgen.py2C().gen(self.stream_stmts)
169  def visit_Assign(self, node):
170    if isinstance(node.value, ast.Num):
171      if node.value.n == 0: return node
172      elif node.value.n == -1: return node
173      else: 
174        stream_init = copy.deepcopy(node)
175        stream_init.value = mkCall('sisd_from_int', [node.value])
176        loop_init = copy.deepcopy(node)
177        loop_init.value.n = 0
178        self.stream_stmts.append(stream_init)
179        self.loop_post_inits.append(loop_init)
180        return None
181    else: return node
182  def visit_FunctionDef(self, node):
183    self.generic_visit(node)
184    node.body = node.body + self.loop_post_inits
185    return node
186
187#
188# Carry Introduction Transformation
189#
190class CarryCounter(ast.NodeVisitor):
191  def visit_Call(self, callnode):
192    self.generic_visit(callnode)
193    if is_BuiltIn_Call(callnode,'Advance', 1) or is_BuiltIn_Call(callnode,'ScanThru', 2) or is_BuiltIn_Call(callnode,'ScanTo', 2) or is_BuiltIn_Call(callnode,'ScanToFirst', 1):       
194      self.carry_count += 1
195  def visit_BinOp(self, exprnode):
196    self.generic_visit(exprnode)
197    if isinstance(exprnode.op, ast.Sub):
198      self.carry_count += 1
199    if isinstance(exprnode.op, ast.Add):
200      self.carry_count += 1
201  def count(self, nodeToVisit):
202    self.carry_count = 0
203    self.generic_visit(nodeToVisit)
204    return self.carry_count
205
206class Adv32Counter(ast.NodeVisitor):
207  def visit_Call(self, callnode):
208    self.generic_visit(callnode)
209    if is_BuiltIn_Call(callnode,'Advance32', 1):       
210      self.adv32_count += 1
211  def count(self, nodeToVisit):
212    self.adv32_count = 0
213    self.generic_visit(nodeToVisit)
214    return self.adv32_count
215
216class CarryIntro(ast.NodeTransformer):
217  def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
218    self.carryvar = ast.Name(carryvar, ast.Load())
219    self.carryin = carryin
220    self.carryout = carryout
221  def xfrm_fndef(self, fndef):
222    self.current_carry = 0
223    self.current_adv32 = 0
224    carry_count = CarryCounter().count(fndef)
225    if carry_count == 0: return fndef
226    self.generic_visit(fndef)
227#   
228#    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
229    return fndef
230  def generic_xfrm(self, node):
231    self.current_carry = 0
232    self.current_adv32 = 0
233    carry_count = CarryCounter().count(node)
234    adv32_count = Adv32Counter().count(node)
235    if carry_count == 0 and adv32_count == 0: return node
236    self.generic_visit(node)
237    return node
238  def visit_Call(self, callnode):
239    self.generic_visit(callnode)
240    #CARRYSET
241    carry_args = [ast.Num(self.current_carry)]
242    adv32_args = [ast.Subscript(ast.Name('pending32', ast.Load()), ast.Num(self.current_adv32), ast.Load())]
243    if is_BuiltIn_Call(callnode, 'Advance', 1):         
244      #CARRYSET
245      rtn = self.carryvar.id + "." + "BitBlock_advance%s_co" % (self.carryin)
246      c = mkCall(rtn, callnode.args + carry_args)
247      self.current_carry += 1
248      return c
249    elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
250      #CARRYSET
251      rtn = self.carryvar.id + "." + "BitBlock_advance32%s_co" % (self.carryin)
252      c = mkCall(rtn, callnode.args + adv32_args)
253      self.current_adv32 += 1
254      return c
255    elif is_BuiltIn_Call(callnode, 'ScanThru', 2):
256      #CARRYSET
257      rtn = self.carryvar.id + "." + "BitBlock_scanthru%s_co" % (self.carryin)
258      c = mkCall(rtn, callnode.args + carry_args)
259      self.current_carry += 1
260      return c
261    elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
262      # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
263      # in having a separate BitBlock_scanto routine.
264      #CARRYSET
265      rtn = self.carryvar.id + "." + "BitBlock_scanthru%s_co" % (self.carryin)
266      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
267      else: scanclass = mkCall('simd_not', [callnode.args[1]])
268      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
269      self.current_carry += 1
270      return c
271    elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
272      #CARRYSET
273      rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
274      #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
275      c = mkCall(rtn, callnode.args + carry_args)
276      self.current_carry += 1
277      return c
278    elif is_BuiltIn_Call(callnode, 'atEOF', 1):
279      if self.carryout != "": 
280        # Non final block: atEOF(x) = 0.
281        return mkCall('simd<1>::constant<0>', [])
282      else: return mkCall('simd_andc', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
283    elif is_BuiltIn_Call(callnode, 'inFile', 1):
284      if self.carryout != "": 
285        # Non final block: inFile(x) = x.
286        return callnode.args[0]
287      else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
288    elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
289      rtn = "StreamScan"           
290      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), 
291                                           ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
292                                           ast.Name(callnode.args[1].id, ast.Load())])
293      return c
294    else:
295      #dump_Call(callnode)
296      return callnode
297  def visit_BinOp(self, exprnode):
298    self.generic_visit(exprnode)
299    carry_args = [ast.Num(self.current_carry)]
300    if isinstance(exprnode.op, ast.Sub):
301      #CARRYSET
302      rtn = self.carryvar.id + "." + "BitBlock_sub%s_co" % (self.carryin)
303      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
304      self.current_carry += 1
305      return c
306    elif isinstance(exprnode.op, ast.Add):
307      #CARRYSET
308      rtn = self.carryvar.id + "." + "BitBlock_add%s_co" % (self.carryin)
309      c = mkCall(rtn, [exprnode.left, exprnode.right] + carry_args)
310      self.current_carry += 1
311      return c
312    else: return exprnode
313  def visit_If(self, ifNode):
314    carry_base = self.current_carry
315    carries = CarryCounter().count(ifNode)
316    assert Adv32Counter().count(ifNode) == 0, "Advance32() within if: illegal\n"
317    self.generic_visit(ifNode)
318    if carries == 0 or self.carryin == "": return ifNode
319    #CARRYSET
320    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
321    new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
322    new_else_part = ifNode.orelse + [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]
323    return ast.If(new_test, ifNode.body, new_else_part)
324  def visit_While(self, whileNode):
325    if self.carryout == '':
326      whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
327    carry_base = self.current_carry
328    assert Adv32Counter().count(whileNode) == 0, "Advance32() within while: illegal\n"
329    carries = CarryCounter().count(whileNode)
330    #CARRYSET
331    if carries == 0: return whileNode
332    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
333    local_carryvar = 'subcarryQ'
334    inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
335    self.generic_visit(whileNode)
336    local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
337    inner_while.body.insert(0, local_carry_decl)
338    final_combine = mkCallStmt('carryQ.CarryCombine', [ast.Name( '(ICarryQueue *) &' + local_carryvar, ast.Load()), ast.Num(carry_base), ast.Num(carries)])
339    inner_while.body.append(final_combine)
340    #CARRYSET
341    if self.carryin == '': new_test = whileNode.test
342    else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall('carryQ.CarryTest', carry_arglist)])
343    else_part = [mkCallStmt('carryQ.CarryDequeueEnqueue', carry_arglist)]   
344    return ast.If(new_test, whileNode.body + [inner_while], else_part)
345
346class StreamStructGen(ast.NodeVisitor):
347  """
348  Given a BitStreamSet subclass, generate the equivalent C struct.
349  >>> obj = ast.parse(r'''
350  ... class S1(BitStreamSet):
351  ...   a1 = 0
352  ...   a2 = 0
353  ...   a3 = 0
354  ...
355  ... class S2(BitStreamSet):
356  ...   x1 = 0
357  ...   x2 = 0
358  ... ''')
359  >>> print StreamStructGen().gen(obj)
360  struct S1 {
361    BitBlock a1;
362    BitBlock a2;
363    BitBlock a3;
364  }    self.current_adv32 = 0
365
366 
367  struct S2 {
368    BitBlock x1;
369    BitBlock x2;
370  }
371  """
372  def __init__(self, asType=False):
373    self.asType = asType
374  def gen(self, tree):
375    self.Ccode=""
376    self.generic_visit(tree)
377    return self.Ccode
378  def gen_struct_types(self, tree):
379    self.asType = True
380    self.Ccode=""
381    self.generic_visit(tree)
382    return self.Ccode
383  def gen_struct_vars(self, tree):
384    self.asType = False
385    self.Ccode=""
386    self.generic_visit(tree)
387    return self.Ccode
388  def visit_ClassDef(self, node):
389    class_name = node.name[0].upper() + node.name[1:]
390    instance_name = node.name[0].lower() + node.name[1:]
391    self.Ccode += "  struct " + class_name
392    if self.asType:
393            self.Ccode += " {\n"
394            for stmt in node.body:
395              if isinstance(stmt, ast.Assign):
396                for v in stmt.targets:
397                  if isinstance(v, ast.Name):
398                    self.Ccode += "  BitBlock " + v.id + ";\n"
399            self.Ccode += "}" 
400    else: self.Ccode += " " + instance_name
401    self.Ccode += ";\n\n"
402 
403class StreamFunctionDecl(ast.NodeVisitor):
404  def __init__(self):
405    pass
406  def gen(self, tree):
407    self.Ccode=""
408    self.generic_visit(tree)
409    return self.Ccode
410  def visit_FunctionDef(self, node):
411    self.Ccode += "static inline void " + node.name + "("
412    pending = ""
413    for arg in node.args.args:
414      if isinstance(arg, ast.Name):
415        self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + " & " + arg.id
416        pending = ", "
417    self.Ccode += ");\n"
418
419#
420# Adding Debugging Statements
421#
422class Add_SIMD_Register_Dump(ast.NodeTransformer):
423  def xfrm(self, t):
424    return self.generic_visit(t)
425  def visit_Assign(self, t):
426    self.generic_visit(t)
427    v = t.targets[0]
428    dump_stmt = mkCallStmt(' print_register<BitBlock>', [ast.Str(Cgen.py2C().gen(v)), v])
429    return [t, dump_stmt]
430
431#
432# Adding ASSERT_BITBLOCK_ALIGN macros
433#
434class Add_Assert_BitBlock_Align(ast.NodeTransformer):
435    def xfrm(self, t):
436      return self.generic_visit(t)
437    def visit_Assign(self, t):
438      self.generic_visit(t)
439      v = t.targets[0]
440      dump_stmt = mkCallStmt(' ASSERT_BITBLOCK_ALIGN', [v])
441      return [t, dump_stmt]
442
443class StreamFunctionCarryCounter(ast.NodeVisitor):
444  def __init__(self):
445        self.carry_count = {}
446       
447  def count(self, node):
448        self.generic_visit(node)
449        return self.carry_count
450                                   
451  def visit_FunctionDef(self, node):   
452        type_name = node.name[0].upper() + node.name[1:]                       
453        self.carry_count[type_name] = CarryCounter().count(node)
454     
455class StreamFunctionCallXlator(ast.NodeTransformer):
456  def __init__(self, xlate_type="normal"):
457        self.stream_function_type_names = []
458        self.xlate_type = xlate_type
459
460  def xfrm(self, node, stream_function_type_names, C_syntax):
461        self.stream_function_type_names = stream_function_type_names
462        self.C_syntax = C_syntax
463        self.generic_visit(node)
464       
465  def visit_Call(self, node):   
466        self.generic_visit(node)
467
468        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in self.stream_function_type_names:
469             name = lower1(node.func.id)
470             node.func.id = name + ("_" if self.C_syntax else ".") + ("do_final_block" if self.xlate_type == "final" else "do_block")
471             if self.C_syntax:
472                     node.args = [ast.Name(lower1(name), ast.Load())] + node.args
473             if self.xlate_type == "final":
474                   node.args = node.args + [ast.Name("EOF_mask", ast.Load())]
475                     
476        return node     
477               
478class StreamFunctionVisitor(ast.NodeVisitor):
479        def __init__(self,node):
480                self.stream_function_node = {}
481                self.generic_visit(node)
482                                                             
483        def visit_FunctionDef(self, node):                     
484                key = node.name[0].upper() + node.name[1:]
485                self.stream_function_node[key] = node
486               
487class StreamFunction():
488        def __init__(self):
489                self.carry_count = 0 
490                self.adv32_count = 0 
491                self.type_name = ""
492                self.instance_name = "" 
493                self.parameters = ""
494                self.declarations = "" 
495                self.initializations = "" 
496#
497# TODO Consolidate *all* C code generation into the Emitter class.   Medium priority.
498# TODO Implement 'pretty print' indentation.   Low priority.
499# TODO Migrate Emiter() class to Emitter module.  Medium priority.
500
501def lower1(name):
502    return name[0].lower() + name[1:]
503def upper1(name):
504    return name[0].upper() + name[1:]
505
506def escape_newlines(str):
507  return str.replace('\n', '\\\n')
508
509class Emitter():
510        def __init__(self, use_C_syntax):
511                self.use_C_syntax = use_C_syntax
512
513        def definition(self, stream_function, icount=0):
514               
515                constructor = ""
516                carry_declaration = ""
517                self.type_name = stream_function.type_name
518               
519                if stream_function.carry_count > 0 or stream_function.adv32_count > 0:
520                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.adv32_count)
521                        carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv32_count)
522
523                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), 
524                                                stream_function.declarations, 
525                                                stream_function.initializations, 
526                                                stream_function.statements)             
527
528                do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), 
529                                                stream_function.declarations, 
530                                                stream_function.initializations, 
531                                                stream_function.final_block_statements)                 
532
533                do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), 
534                                                self.do_segment_args(stream_function.parameters))       
535
536                if self.use_C_syntax:
537                        return self.indent(icount) + "struct " + stream_function.type_name + " {" \
538                               + "\n" + self.indent(icount) + carry_declaration \
539                               + "\n" + self.indent(icount) + "};\n" \
540                               + "\n" + self.indent(icount) + do_block_function \
541                               + "\n" + self.indent(icount) + do_final_block_function \
542                               + "\n" + self.indent(icount) + do_segment_function + "\n\n"
543                               
544                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
545                + "\n" + self.indent(icount) + constructor \
546                + "\n" + self.indent(icount) + do_block_function \
547                + "\n" + self.indent(icount) + do_final_block_function \
548                + "\n" + self.indent(icount) + do_segment_function \
549                + "\n" + self.indent(icount) + carry_declaration \
550                + "\n" + self.indent(icount) + "};\n\n"
551
552        def constructor(self, type_name, carry_count, adv32_count, icount=0):
553                adv32_decl = ""
554                for i in range(adv32_count): adv32_decl += self.indent(icount+2) + "pending32[%s] = 0;\n" % i   
555                return self.indent(icount) + "%s() { ""\n" % (type_name) + adv32_decl + self.carry_init(carry_count) + " }" 
556                       
557        def do_block(self, parameters, declarations, initializations, statements, icount=0):
558                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
559                if self.use_C_syntax:
560                        return "#define " + pfx + "do_block(" + parameters + ")\\\n do {" \
561                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
562                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
563                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
564                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
565                return self.indent(icount) + "IDISA_INLINE void " + pfx + "do_block(" + parameters + ") {" \
566                + "\n" + self.indent(icount) + declarations \
567                + "\n" + self.indent(icount) + initializations \
568                + "\n" + self.indent(icount) + statements \
569                + "\n" + self.indent(icount + 2) + "}" 
570
571
572
573
574        def do_final_block(self, parameters, declarations, initializations, statements, icount=0):
575                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
576                if self.use_C_syntax:
577                        return "#define " + pfx + "do_final_block(" + parameters + ")\\\n do {" \
578                        + "\\\n" + self.indent(icount) + escape_newlines(declarations) \
579                        + "\\\n" + self.indent(icount) + escape_newlines(initializations) \
580                        + "\\\n" + self.indent(icount) + escape_newlines(statements) \
581                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
582                return self.indent(icount) + "void " + pfx + "do_final_block(" + parameters + ") {" \
583                + "\n" + self.indent(icount) + declarations \
584                + "\n" + self.indent(icount) + initializations \
585                + "\n" + self.indent(icount) + statements \
586                + "\n" + self.indent(icount + 2) + "}" 
587
588        def do_segment(self, parameters, do_block_call_args, icount=0):
589                pfx = (lower1(self.type_name) + "_" if self.use_C_syntax else "")
590                if self.use_C_syntax:
591                        return "#define " + pfx + "do_segment(" + parameters + ")\\\n do {" \
592                        + "\\\n" + self.indent(icount) + "  int i;" \
593                        + "\\\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
594                        + "\\\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
595                        + "\\\n" + self.indent(icount + 2) + "} while (0)" 
596                return self.indent(icount) + "void " + pfx + "do_segment(" + parameters + ") {" \
597                + "\n" + self.indent(icount) + "  int i;" \
598                + "\n" + self.indent(icount) + "  for (i = 0; i < segment_blocks; i++)" \
599                + "\n" + self.indent(icount) + "    " + pfx + "do_block(" + do_block_call_args + ");" \
600                + "\n" + self.indent(icount + 2) + "}" 
601
602        def declaration(self, type_name, instance_name, icount=0):
603                if self.use_C_syntax: return self.indent(icount) + "struct " + type_name + " " + instance_name + ";\n"
604                return self.indent(icount) + type_name + " " + instance_name + ";\n"
605               
606        def carry_init(self, carry_count, icount=0):   
607                #CARRY SET
608                return ""
609               
610        def carry_declare(self, carry_variable, carry_count, adv32_count=0, icount=0):
611                adv32_decl = ""
612                if adv32_count > 0:
613                        adv32_decl = "\n" + self.indent(icount) + "uint32_t pending32[%s];" % adv32_count
614                #CARRY SET
615                return self.indent(icount) + "CarryArray<%i> %s;" % (carry_count, carry_variable) + adv32_decl
616
617        def carry_test(self, carry_variable, carry_count, icount=0):
618                #CARRY SET
619                return self.indent(icount) + "carryQ.CarryTest(0, %i)" % (carry_count)         
620               
621        def indent(self, icount):
622                s = ""
623                for i in range(0,icount): s += " "
624                return s       
625               
626        def do_block_parameters(self, parameters):
627                if self.use_C_syntax:
628                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters])
629                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters])
630                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters])
631               
632        def do_final_block_parameters(self, parameters):
633                if self.use_C_syntax:
634                        #return ", ".join([self.type_name + " * " + self.instance_name] + [upper1(p) + " * " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
635                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters]+ ["EOF_mask"])
636                else: return ", ".join([upper1(p) + " & " + lower1(p) for p in parameters]+ ["BitBlock EOF_mask"])
637               
638        def do_segment_parameters(self, parameters):
639                if self.use_C_syntax:
640                        #return ", ".join([self.type_name + " * " + + self.instance_name] + [upper1(p) + " " + lower1(p) + "[]" for p in parameters])
641                        return ", ".join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ["int segment_blocks"])
642                else: return ", ".join([upper1(p) + " " + lower1(p) + "[]" for p in parameters] + ["int segment_blocks"])
643
644        def do_segment_args(self, parameters):
645                if self.use_C_syntax:
646                        return ", ".join([lower1(self.type_name)] + [lower1(p) + "[i]" for p in parameters])
647                else: return ", ".join([lower1(p) + "[i]" for p in parameters])
648
649def main(infilename, outfile = sys.stdout):
650  t = ast.parse(file(infilename).read())
651  outfile.write(StreamStructGen(True).gen(t))
652  outfile.write(FunctionXlator().xlat(t))
653
654#
655#  Routines for compatibility with the old compiler/template.
656#  Quick and dirty hacks for now - Dec. 2010.
657#
658
659class MainLoopTransformer:
660  def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, main_node_id='Main'):
661       
662    self.main_module = main_module
663    self.main_node_id = main_node_id
664    self.use_C_syntax = C_syntax
665    self.add_dump_stmts = add_dump_stmts
666    self.add_assert_bitblock_align = add_assert_bitblock_align
667   
668        # Gather and partition function definition nodes.
669    stream_function_visitor = StreamFunctionVisitor(self.main_module)
670    self.stream_function_node = stream_function_visitor.stream_function_node
671    self.main_node = self.stream_function_node[main_node_id]
672    self.main_carry_count = CarryCounter().count(self.main_node)
673    self.main_adv32_count = Adv32Counter().count(self.main_node)
674    assert self.main_adv32_count == 0, "Advance32() in main not supported.\n"
675    del self.stream_function_node[self.main_node_id]
676   
677    self.stream_functions = {}
678    for key, node in self.stream_function_node.iteritems():
679                stream_function = StreamFunction()
680                stream_function.carry_count = CarryCounter().count(node)
681                stream_function.adv32_count = Adv32Counter().count(node)
682                stream_function.type_name = node.name[0].upper() + node.name[1:]
683                stream_function.instance_name = node.name[0].lower() + node.name[1:]
684                stream_function.parameters = FunctionVars(node).params
685                stream_function.declarations = BitBlock_decls_of_fn(node)
686                stream_function.initializations = StreamInitializations().xfrm(node) 
687               
688                AugAssignRemoval().xfrm(node)
689                Bitwise_to_SIMD().xfrm(node)
690                final_block_node = copy.deepcopy(node)
691                if self.use_C_syntax:
692                        carryQname = stream_function.instance_name + ".carryQ"
693                else: carryQname = "carryQ"
694                CarryIntro(carryQname).xfrm_fndef(node)
695                CarryIntro(carryQname, "_ci", "").xfrm_fndef(final_block_node)
696
697                if self.add_dump_stmts: 
698                        Add_SIMD_Register_Dump().xfrm(node)
699                        Add_SIMD_Register_Dump().xfrm(final_block_node)
700
701                if self.add_assert_bitblock_align:
702                        Add_Assert_BitBlock_Align().xfrm(node)
703                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
704
705                if stream_function.carry_count > 0:
706                        node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
707               
708                stream_function.statements = Cgen.py2C(4).gen(node.body)
709                stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
710                self.stream_functions[stream_function.type_name] = stream_function
711       
712    self.emitter = Emitter(self.use_C_syntax)
713   
714  def any_carry_expr(self):
715       
716        carry_test = []
717       
718        if self.main_carry_count > 0:
719                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count)) 
720                        carry_test.append(" || ")
721
722        for key in self.stream_functions.keys():               
723                if self.stream_functions[key].carry_count > 0:
724                        carry_test.append(self.stream_functions[key].instance_name + "." + self.emitter.carry_test('carryQ',self.stream_functions[key].carry_count))# TODO Update self.emitter.carry_test
725                        carry_test.append(" || ")
726
727        if len(carry_test) > 0:
728                carry_test.pop()
729                return "".join(carry_test)
730        return "1"
731
732  def gen_globals(self):
733    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
734    for key in self.stream_functions.keys():
735                self.Cglobals += Emitter(self.use_C_syntax).definition(self.stream_functions[key],2)       
736                       
737  def gen_declarations(self): 
738    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
739    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
740    if self.main_carry_count > 0: 
741        self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
742               
743  def gen_initializations(self):
744    self.Cinits = ""
745    if self.main_carry_count > 0: 
746        self.Cinits += self.emitter.carry_init(self.main_carry_count)
747    self.Cinits += StreamInitializations().xfrm(self.main_module)
748    if self.use_C_syntax:
749                for key in self.stream_functions.keys():
750                        if self.stream_functions[key].carry_count == 0: continue
751                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
752                        self.Cinits += "CarryInit(" + self.stream_functions[key].instance_name + ".carryQ, %i);\n" % (self.stream_functions[key].carry_count)
753    else:
754                for key in self.stream_functions.keys():
755                        self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
756
757                       
758  def xfrm_block_stmts(self):
759    AugAssignRemoval().xfrm(self.main_node)
760    Bitwise_to_SIMD().xfrm(self.main_node)
761    final_block_main = copy.deepcopy(self.main_node)
762    CarryIntro().xfrm_fndef(self.main_node)
763    CarryIntro('carryQ', '_ci', '').xfrm_fndef(final_block_main)
764    if self.add_dump_stmts: 
765        Add_SIMD_Register_Dump().xfrm(self.main_node)
766        Add_SIMD_Register_Dump().xfrm(final_block_main)
767               
768    if self.add_assert_bitblock_align:
769        print "add_assert_bitblock_align"
770        Add_Assert_BitBlock_Align().xfrm(self.main_node)
771        Add_Assert_BitBlock_Align().xfrm(final_block_main)
772
773    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
774    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
775   
776    if self.main_carry_count > 0:
777                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
778                self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
779   
780   
781       
782    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
783    self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
784   
785if __name__ == "__main__":
786                import doctest
787                doctest.testmod()
Note: See TracBrowser for help on using the repository browser.