source: proto/Compiler/pablo.py @ 3589

Last change on this file since 3589 was 3571, checked in by nmedfort, 6 years ago

start of error rewriting work. some clean up done to pablo.py; a few classes in it were moved to pablo_util.py.

File size: 37.6 KB
Line 
1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3
4import ast
5import copy
6import sys
7import mkast
8import Cgen
9from carryInfo import *
10import CCGO
11import CCGO_HMCPS
12import lookAhead
13from pablo_util import *
14from pablo_error_handling import RewriteErrorStatements
15from pablo_optimizer import Optimize
16
17do_block_inline_decorator = 'IDISA_INLINE '
18do_final_block_inline_decorator = ''
19error_routine = 'raise_assert'
20experimentalMode = False
21pablo_char_type = 'char'
22rewrite_errors = False
23
24def dump_Call(fncall):
25    if isinstance(fncall.func, ast.Name):
26        print 'fn_name = %s\n' % fncall.func.id
27    elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
28        print 'fn_name = %s.%s\n' % (fncall.func.value.id, fncall.func.attr)
29    print 'len(fncall.args) = %s\n' % len(fncall.args)
30
31
32class AdvanceCombiner(ast.NodeTransformer):
33
34    def xfrm(self, t):
35        return self.generic_visit(t)
36
37    def visit_if(self, ifNode):
38        return IfNode
39
40    def visit_While(self, whileNode):
41        return whileNode
42
43    def visit_Call(self, callnode):
44        self.generic_visit(callnode)
45        if len(callnode.args) == 0:
46            return callnode
47        if not isinstance(callnode.args[0], ast.Call):
48            return callnode
49        if is_BuiltIn_Call(callnode, 'Advance', 1):
50            if is_BuiltIn_Call(callnode.args[0], 'Advance', 1):
51                callnode.args = [callnode.args[0].args[0], ast.Num(2)]
52            elif is_BuiltIn_Call(callnode.args[0], 'Advance', 2):
53                if isinstance(callnode.args[0].args[1], ast.Num):
54                    callnode.args = [callnode.args[0].args[0], ast.Num(callnode.args[0].args[1].n + 1)]
55                else:
56                    callnode.args = [callnode.args[0].args[0], ast.BinOp(callnode.args[0].args[1], ast.Add(), ast.Num(1))]
57        return callnode
58
59
60CharNameMap = {
61    '[': 'LBrak',
62    ']': 'RBrak',
63    '{': 'LBrace',
64    '}': 'LBrace',
65    '(': 'LParen',
66    ')': 'RParen',
67    '!': 'Exclam',
68    '"': 'DQuote',
69    '#': 'Hash',
70    '$': 'Dollar',
71    '%': 'PerCent',
72    '&': 'RefStart',
73    "'": 'SQuote',
74    '*': 'Star',
75    '+': 'Plus',
76    ',': 'Comma',
77    '-': 'Hyphen',
78    '.': 'Dot',
79    '/': 'Slash',
80    ':': 'Colon',
81    ';': 'Semicolon',
82    '=': 'Equals',
83    '?': 'QMark',
84    '@': 'AtSign',
85    '\\': 'BackSlash',
86    '^': 'Caret',
87    '_': 'Underscore',
88    '|': 'VBar',
89    '~': 'Tilde',
90    ' ': 'SP',
91    '\t': 'HT',
92    '\m': 'CR',
93    '\n': 'LF',
94    }
95
96
97def GetCharName(char):
98    if char >= 'a' and char <= 'z' or char >= 'A' and char <= 'Z':
99        return 'letter_' + char
100    elif char >= '0' and char <= '9':
101        return 'digit_' + char
102    else:
103        return CharNameMap[char]
104
105
106def MkCharStream(char):
107    return mkast.Qname('lex', GetCharName(char))
108
109
110def MkLookAheadExpr(v, i):
111    return mkast.call(mkast.Qname('pablo', 'LookAhead'), [v, ast.Num(i)])
112
113
114def CompileMatch(match_var, string_to_match):
115    expr = mkast.call('simd_and', [match_var, MkCharStream(string_to_match[0])])
116    for i in range(1, len(string_to_match)):
117        expr = mkast.call('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
118    return expr
119
120
121class StringMatchCompiler(ast.NodeTransformer):
122
123    def xfrm(self, t):
124        return self.generic_visit(t)
125
126    def visit_Call(self, callnode):
127        if is_BuiltIn_Call(callnode, 'match', 2):
128            ast.dump(callnode)
129            assert isinstance(callnode.args[0], ast.Str)
130            string_to_match = callnode.args[0].s
131            match_var = callnode.args[1]
132            expr = mkast.call('simd_and', [match_var, MkCharStream(string_to_match[0])])
133            for i in range(1, len(string_to_match)):
134                expr = mkast.call('simd_and', [expr, MkLookAheadExpr(MkCharStream(string_to_match[i]), i)])
135            return expr
136        else:
137            return callnode
138
139class FunctionVars(ast.NodeVisitor):
140
141    def __init__(self, node):
142        self.params = []
143        self.stores = []
144        self.generic_visit(node)
145
146    def visit_Name(self, nm):
147        if isinstance(nm.ctx, ast.Param):
148            self.params.append(nm.id)
149        if isinstance(nm.ctx, ast.Store):
150            if nm.id not in self.stores:
151                self.stores.append(nm.id)
152
153    def getLocals(self):
154        return [v for v in self.stores if not v in self.params]
155
156
157MAX_LINE_LENGTH = 80
158
159
160def BitBlock_decls_from_vars(varlist):
161    global MAX_LINE_LENGTH
162    decls = ''
163    if not len(varlist) == 0:
164        decls = "               BitBlock"
165        pending = ''
166        linelgth = 10
167        for v in varlist:
168            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
169                decls += pending + ' ' + v
170                linelgth += len(pending + v) + 1
171            else:
172                decls += ";\n           BitBlock " + v
173                linelgth = 11 + len(v)
174            pending = ','
175        decls += ';'
176    return decls
177
178
179def BitBlock_decls_of_fn(fndef):
180    return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
181
182
183def BitBlock_header_of_fn(fndef):
184    Ccode = 'static inline void ' + fndef.name + '('
185    pending = ''
186    for arg in fndef.args.args:
187        if isinstance(arg, ast.Name):
188            Ccode += pending + arg.id.upper()[0] + arg.id[1:] + ' & ' + arg.id
189            pending = ', '
190    if CarryCounter().count(fndef) > 0:
191        Ccode += pending + ' CarryQtype & carryQ'
192    Ccode += ')'
193    return Ccode
194
195
196class StreamInitializations(ast.NodeTransformer):
197
198    def xfrm(self, node):
199        self.stream_stmts = []
200        self.loop_post_inits = []
201        self.generic_visit(node)
202        return Cgen.py2C().gen(self.stream_stmts)
203
204    def visit_Assign(self, node):
205        if isinstance(node.value, ast.Num):
206            if node.value.n == 0:
207                return node
208            elif node.value.n == -1:
209                return node
210            else:
211                stream_init = copy.deepcopy(node)
212                stream_init.value = mkast.call('sisd_from_int', [node.value])
213                loop_init = copy.deepcopy(node)
214                loop_init.value.n = 0
215                self.stream_stmts.append(stream_init)
216                self.loop_post_inits.append(loop_init)
217                return None
218        else:
219            return node
220
221    def visit_FunctionDef(self, node):
222        self.generic_visit(node)
223        node.body = node.body + self.loop_post_inits
224        return node
225
226
227import CCGO_While
228
229
230def Strategic_CCGO_Factory(carryInfoSet):
231    BLOCK_SIZE = 128
232    if multicarryWhileMode:
233        ccgo = CCGO_While.CCGO_While1(BLOCK_SIZE, carryInfoSet)
234    elif experimentalMode:
235        ops = carryInfoSet.operation_count
236        if ops == 0:
237            ccgo = CCGO.CCGO()
238        elif ops <= 2:
239            ccgo = CCGO_HMCPS.HMCPS_CCGO2(BLOCK_SIZE, 64, carryInfoSet, 'carryG', '__c')
240        elif ops <= 4:
241            ccgo = CCGO_HMCPS.HMCPS_CCGO2(BLOCK_SIZE, 32, carryInfoSet, 'carryG', '__c')
242        else:
243
244            ccgo = CCGO_HMCPS.HMCPS_CCGO_BitPack2(BLOCK_SIZE, 8, carryInfoSet, 'carryG', '__c')
245    else:
246        ccgo = CCGO.testCCGO(BLOCK_SIZE, carryInfoSet, 'carryQ')
247    ccgo.allocate_all()
248    return ccgo
249
250
251class CarryIntro(ast.NodeTransformer):
252
253    def __init__(
254        self,
255        ccgo,
256        carryvar='carryQ',
257        carryin='_ci',
258        carryout='_co',
259        ):
260        self.carryvar = ast.Name(carryvar, ast.Load())
261        self.carryin = carryin
262        self.carryout = carryout
263        self.ccgo = ccgo
264
265    def xfrm_fndef(self, fndef):
266        self.block_no = 0
267        self.operation_count = 0
268        self.current_carry = 0
269        self.current_adv_n = 0
270        self.generic_visit(fndef)
271
272    def xfrm_fndef_final(self, fndef):
273        self.block_no = 0
274        self.operation_count = 0
275        self.carryout = ''
276        self.current_carry = 0
277        self.current_adv_n = 0
278        self.generic_visit(fndef)
279        return fndef
280
281    def generic_xfrm(self, node):
282        self.block_no = 0
283        self.operation_count = 0
284        self.current_carry = 0
285        self.current_adv_n = 0
286        self.last_stmt = None
287        self.last_stmt_carries = 0
288        self.generic_visit(node)
289        return node
290
291    def local_while_xfrm(self, local_carryvar, whileNode):
292        saved_state = (
293            self.block_no,
294            self.operation_count,
295            self.carryvar,
296            self.carryin,
297            self.carryout,
298            self.current_carry,
299            self.current_adv_n,
300            )
301        (self.carryvar, self.carryin, self.current_carry, self.current_adv_n) = (local_carryvar, '', 0, 0)
302
303        self.ccgo.EnterLocalWhileBlock(self.operation_count)
304        inner_while = self.generic_visit(whileNode)
305        self.ccgo.ExitLocalWhileBlock()
306        (
307            self.block_no,
308            self.operation_count,
309            self.carryvar,
310            self.carryin,
311            self.carryout,
312            self.current_carry,
313            self.current_adv_n,
314            ) = saved_state
315        return inner_while
316
317    def visit_Call(self, callnode):
318        self.generic_visit(callnode)
319
320        if is_BuiltIn_Call(callnode, 'StreamScan', 2):
321            rtn = 'StreamScan'
322            c = mkast.call(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()), ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()), ast.Name(callnode.args[1].id, ast.Load())])
323            return c
324        elif is_BuiltIn_Call(callnode, 'match', 3):
325            assert isinstance(callnode.args[1], ast.Str)
326            string_to_match = callnode.args[1].s
327            match_len = len(string_to_match)
328            match_var = callnode.args[2]
329            expr = mkast.call('pablo_blk_match<%s>' % pablo_char_type, [callnode.args[0], callnode.args[1], match_var, ast.Num(match_len)])
330            return expr
331        else:
332
333            return callnode
334
335    def visit_BinOp(self, exprnode):
336        self.generic_visit(exprnode)
337        carry_args = [ast.Num(self.current_carry)]
338        if self.carryin == '_ci':
339            carry_args = [mkast.call(self.carryvar.id + '.' + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
340        else:
341            carry_args = [mkast.call('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
342        if isinstance(exprnode.op, ast.Sub):
343            assert False, 'Subtraction no longer supported - use pablo.SpanUpTo ...'
344        elif isinstance(exprnode.op, ast.Add):
345            assert False, 'Addition no longer supported - use pablo.Scan ...'
346        else:
347            return exprnode
348
349    def visit_Assign(self, assigNode):
350        self.last_stmt_carries = CarryCounter().count(assigNode)
351        f = CheckForBuiltin(assigNode.value)
352        if f == None:
353            self.generic_visit(assigNode)
354            self.last_stmt = assigNode
355            return assigNode
356        elif isCarryGenerating(f) or isAdvance(f) and (len(assigNode.value.args) == 1 or assigNode.value.args[1].n == 1):
357
358            if self.carryin == '_ci':
359                carry_in_expr = self.ccgo.GenerateCarryInAccess(self.operation_count)
360            else:
361                carry_in_expr = mkast.var('carry_value_0')
362            callnode = assigNode.value
363            if isAdvance(f):
364                pablo_routine_call = mkast.call('pablo_blk_' + f, [assigNode.value.args[0], carry_in_expr, assigNode.targets[0]])
365            elif f in ['ScanTo', 'AdvanceThenScanTo']:
366                if self.carryout == '':
367                    scanclass = mkast.call('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
368                else:
369                    scanclass = mkast.call('simd_not', [callnode.args[1]])
370                pablo_routine_call = mkast.call('pablo_blk_' + f[:-2] + 'Thru', [callnode.args[0], scanclass, carry_in_expr, assigNode.targets[0]])
371            else:
372                pablo_routine_call = mkast.call('pablo_blk_' + f, assigNode.value.args + [carry_in_expr, assigNode.targets[0]])
373            self.last_stmt = pablo_routine_call
374            compiled = self.ccgo.GenerateCarryOutStore(self.operation_count, pablo_routine_call)
375            self.operation_count += 1
376            self.current_carry += 1
377            return compiled
378        elif isAdvance(f):
379            if self.carryin == '_ci':
380                carry_in_expr = self.ccgo.GenerateAdvanceInAccess(self.operation_count)
381            else:
382                carry_in_expr = mkast.var('carry_value_0')
383            callnode = assigNode.value
384            pablo_routine_call = mkast.call('pablo_blk_Advance_n_<%i>' % assigNode.value.args[1].n, [assigNode.value.args[0], carry_in_expr, assigNode.targets[0]])
385            self.last_stmt = pablo_routine_call
386            compiled = self.ccgo.GenerateAdvanceOutStore(self.operation_count, pablo_routine_call)
387            self.operation_count += 1
388            self.current_adv_n += 1
389            return compiled
390        else:
391
392            self.generic_visit(assigNode)
393            self.last_stmt = assigNode
394            self.operation_count += 1
395            return assigNode
396
397    def visit_If(self, ifNode):
398        self.block_no += 1
399        this_block = self.block_no
400        carry_base = self.current_carry
401        carries = CarryCounter().count(ifNode)
402
403        self.generic_visit(ifNode)
404        if carries == 0:  # or self.carryin == "":
405            self.last_stmt = ifNode
406            return ifNode
407
408        carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
409
410        if self.carryin == '':
411            new_test = ifNode.test
412        else:
413            new_test = self.ccgo.GenerateCarryIfTest(this_block, ifNode.test)
414        new_then_part = ifNode.body + self.ccgo.GenerateCarryThenFinalization(this_block)
415        new_else_part = ifNode.orelse + self.ccgo.GenerateCarryElseFinalization(this_block)
416        newIf = ast.If(new_test, new_then_part, new_else_part)
417        self.last_stmt = newIf
418        self.last_stmt_carries = carries
419        return newIf
420
421    def is_while_special_case(self, whileNode):
422
423        original_test_expr = whileNode.test.args[0]
424        if not isinstance(original_test_expr, ast.Name):
425            return False
426        test_var = original_test_expr.id
427        if not isinstance(self.last_stmt, ast.Assign):
428            return False
429        if not isinstance(whileNode.body[-1], ast.Assign):
430            return False
431        if len(self.last_stmt.targets) != 1:
432            return False
433        if len(whileNode.body[-1].targets) != 1:
434            return False
435        if not isinstance(self.last_stmt.targets[0], ast.Name):
436            return False
437        if not isinstance(whileNode.body[-1].targets[0], ast.Name):
438            return False
439        if self.last_stmt.targets[0].id != test_var:
440            return False
441        if whileNode.body[-1].targets[0].id != test_var:
442            return False
443        if self.last_stmt_carries != 1:
444            return False
445        if CarryCounter().count(whileNode.body[-1]) != 1:
446            return False
447        return True
448
449    def multicarry_While(self, whileNode):
450        self.block_no += 1
451        this_block = self.block_no
452        original_test_expr = whileNode.test.args[0]
453        if self.carryout == '':
454            whileNode.test.args[0] = mkast.call('simd_and', [original_test_expr, ast.Name('EOF_mask', ast.Load())])
455
456        self.generic_visit(whileNode)
457        local_carry_decl = self.ccgo.GenerateLocalDeclare(this_block)
458        whileNode.body = local_carry_decl + whileNode.body
459        whileNode.test = self.ccgo.GenerateCarryWhileTest(this_block, whileNode.test)
460        final_combine = self.ccgo.GenerateCarryWhileFinalization(this_block)
461        return [whileNode] + final_combine
462
463    def visit_While(self, whileNode):
464        if multicarryWhileMode:
465            return self.multicarry_While(whileNode)
466
467        self.block_no += 1
468        this_block = self.block_no
469        original_test_expr = whileNode.test.args[0]
470        if self.carryout == '':
471            whileNode.test.args[0] = mkast.call('simd_and', [original_test_expr, ast.Name('EOF_mask', ast.Load())])
472        carry_base = self.current_carry
473        assert adv_nCounter().count(whileNode) == 0, 'Advance(x,n) within while: illegal\n'
474        carries = CarryCounter().count(whileNode)
475
476        is_special = self.is_while_special_case(whileNode)
477
478        if carries == 0:
479            return whileNode
480
481        local_carryvar = ast.Name('sub' + self.carryvar.id, ast.Load())
482        inner_while = self.local_while_xfrm(local_carryvar, copy.deepcopy(whileNode))
483        self.generic_visit(whileNode)
484        local_carry_decl = self.ccgo.GenerateLocalDeclare(this_block)
485
486        inner_while.body = local_carry_decl + inner_while.body
487
488        final_combine = self.ccgo.GenerateCarryWhileFinalization(this_block)
489        inner_while.body += final_combine
490
491        if is_special:
492
493            combine1 = mkast.callStmt(ast.Attribute(self.carryvar, 'CarryCombine1', ast.Load()), [ast.Num(carry_base - 1), ast.Num((carry_base + carries) - 1)])
494            while_body_extend = [inner_while, combine1]
495
496            carry_test_arglist = [ast.Num(carry_base), ast.Num(carries - 1)]
497        else:
498            carry_test_arglist = [ast.Num(carry_base), ast.Num(carries)]
499            while_body_extend = [inner_while]
500
501        if self.carryin == '':
502            new_test = whileNode.test
503        else:
504            new_test = self.ccgo.GenerateCarryWhileTest(this_block, whileNode.test)
505        else_part = [self.ccgo.GenerateCarryElseFinalization(this_block)]
506        newIf = ast.If(new_test, whileNode.body + while_body_extend, else_part)
507        self.last_stmt = newIf
508        self.last_stmt_carries = carries
509        return newIf
510
511
512class StreamStructGen(ast.NodeVisitor):
513
514    """
515  Given a BitStreamSet subclass, generate the equivalent C struct.
516  >>> obj = ast.parse(r'''
517  ... class S1(BitStreamSet):
518  ...   a1 = 0
519  ...   a2 = 0
520  ...   a3 = 0
521  ...
522  ... class S2(BitStreamSet):
523  ...   x1 = 0
524  ...   x2 = 0
525  ... ''')
526  >>> print StreamStructGen().gen(obj)
527  struct S1 {
528    BitBlock a1;
529    BitBlock a2;
530    BitBlock a3;
531  }    self.current_adv_n = 0
532
533 
534  struct S2 {
535    BitBlock x1;
536    BitBlock x2;
537  }
538  """
539
540    def __init__(self, asType=False):
541        self.asType = asType
542
543    def gen(self, tree):
544        self.Ccode = ''
545        self.generic_visit(tree)
546        return self.Ccode
547
548    def gen_struct_types(self, tree):
549        self.asType = True
550        self.Ccode = ''
551        self.generic_visit(tree)
552        return self.Ccode
553
554    def gen_struct_vars(self, tree):
555        self.asType = False
556        self.Ccode = ''
557        self.generic_visit(tree)
558        return self.Ccode
559
560    def visit_ClassDef(self, node):
561        class_name = node.name[0].upper() + node.name[1:]
562        instance_name = node.name[0].lower() + node.name[1:]
563        self.Ccode += '  struct ' + class_name
564        if self.asType:
565            self.Ccode += ' {\n'
566            for stmt in node.body:
567                if isinstance(stmt, ast.Assign):
568                    for v in stmt.targets:
569                        if isinstance(v, ast.Name):
570                            self.Ccode += '  BitBlock ' + v.id + ';\n'
571            self.Ccode += '}'
572        else:
573            self.Ccode += ' ' + instance_name
574        self.Ccode += ''';
575
576'''
577
578
579class StreamFunctionDecl(ast.NodeVisitor):
580
581    def __init__(self):
582        pass
583
584    def gen(self, tree):
585        self.Ccode = ''
586        self.generic_visit(tree)
587        return self.Ccode
588
589    def visit_FunctionDef(self, node):
590        self.Ccode += 'static inline void ' + node.name + '('
591        pending = ''
592        for arg in node.args.args:
593            if isinstance(arg, ast.Name):
594                self.Ccode += pending + arg.id.upper()[0] + arg.id[1:] + ' & ' + arg.id
595                pending = ', '
596        self.Ccode += ');\n'
597
598
599class AssertCompiler(ast.NodeTransformer):
600
601    def __init__(self):
602        self.assert_routine = ast.parse(error_routine).body[0].value
603
604    def xfrm(self, t):
605        return self.generic_visit(t)
606
607    def visit_Expr(self, node):
608        if isinstance(node.value, ast.Call):
609            if is_BuiltIn_Call(node.value, 'assert_0', 2):
610                err_stream = node.value.args[0]
611                err_code = node.value.args[1]
612                return ast.If(mkast.call('bitblock::any', [err_stream]), [ast.Expr(mkast.call(self.assert_routine, [err_code, err_stream]))], [])
613            else:
614                return node
615        else:
616            return node
617
618
619class Add_SIMD_Register_Dump(ast.NodeTransformer):
620
621    def xfrm(self, t):
622        return self.generic_visit(t)
623
624    def visit_Assign(self, t):
625        self.generic_visit(t)
626        v = t.targets[0]
627        dump_stmt = mkast.callStmt('print_register<BitBlock>', [ast.Str(Cgen.py2C().gen(v)), v])
628        return [t, dump_stmt]
629
630
631class Add_Assert_BitBlock_Align(ast.NodeTransformer):
632
633    def xfrm(self, t):
634        return self.generic_visit(t)
635
636    def visit_Assign(self, t):
637        self.generic_visit(t)
638        v = t.targets[0]
639        dump_stmt = mkast.callStmt(' ASSERT_BITBLOCK_ALIGN', [v])
640        return [t, dump_stmt]
641
642
643class StreamFunctionCarryCounter(ast.NodeVisitor):
644
645    def __init__(self):
646        self.carry_count = {}
647
648    def count(self, node):
649        self.generic_visit(node)
650        return self.carry_count
651
652    def visit_FunctionDef(self, node):
653        type_name = node.name[0].upper() + node.name[1:]
654        self.carry_count[type_name] = CarryCounter().count(node)
655
656
657class StreamFunctionCallXlator(ast.NodeTransformer):
658
659    def __init__(self, xlate_type='normal'):
660        self.stream_function_type_names = []
661        self.xlate_type = xlate_type
662
663    def xfrm(self, node, stream_function_type_names, C_syntax):
664        self.stream_function_type_names = stream_function_type_names
665        self.C_syntax = C_syntax
666        self.generic_visit(node)
667
668    def visit_Call(self, node):
669        self.generic_visit(node)
670
671        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):  # and node.func.id in self.stream_function_type_names:
672            name = lower1(node.func.id)
673            node.func.id = name + ('_' if self.C_syntax else '.') + ('do_final_block' if self.xlate_type == 'final' else 'do_block')
674            if self.C_syntax:
675                node.args = [ast.Name(lower1(name), ast.Load())] + node.args
676            if self.xlate_type == 'final':
677                node.args = node.args + [ast.Name('EOF_mask', ast.Load())]
678
679        return node
680
681
682class StreamFunctionVisitor(ast.NodeVisitor):
683
684    def __init__(self, node):
685        self.stream_function_node = {}
686        self.generic_visit(node)
687
688    def visit_FunctionDef(self, node):
689        key = node.name[0].upper() + node.name[1:]
690        self.stream_function_node[key] = node
691
692
693class StreamFunction:
694
695    def __init__(self):
696        self.carry_count = 0
697        self.init_to_one_list = []
698        self.adv_n_count = 0
699        self.type_name = ''
700        self.instance_name = ''
701        self.parameters = ''
702        self.declarations = ''
703        self.initializations = ''
704
705    def dump(self):
706        print '%s' % self.type_name
707        print '%s=%s' % ('Carry Count', str(self.carry_count))
708        print '%s=[%s]' % ('Init to One List', ','.join(map(str, self.init_to_one_list)))
709        print '%s=%s' % ('Adv n Count', str(self.adv_n_count))
710
711
712def lower1(name):
713    return name[0].lower() + name[1:]
714
715
716def upper1(name):
717    return name[0].upper() + name[1:]
718
719
720def escape_newlines(str):
721    return str.replace('\n', '\\\n')
722
723
724class Emitter:
725
726    def __init__(self, use_C_syntax, strm_fn):
727        self.use_C_syntax = use_C_syntax
728        self.strm_fn = strm_fn
729
730    def definition(self, stream_function, icount=0):
731
732        constructor = ''
733        carry_declaration = ''
734        self.type_name = stream_function.type_name
735
736        if stream_function.carry_count > 0 or stream_function.adv_n_count > 0:
737            constructor = self.constructor(stream_function.type_name, stream_function.carry_count, stream_function.init_to_one_list, stream_function.adv_n_count)
738            carry_declaration = self.carry_declare('carryQ', stream_function.carry_count, stream_function.adv_n_count)
739
740        do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters), stream_function.declarations, stream_function.initializations, stream_function.statements)
741        clear_function = self.mk_clear(stream_function.carry_count)
742
743        do_final_block_function = self.do_final_block(self.do_final_block_parameters(stream_function.parameters), stream_function.declarations, stream_function.initializations, stream_function.final_block_statements)
744
745        do_segment_function = self.do_segment(self.do_segment_parameters(stream_function.parameters), self.do_segment_args(stream_function.parameters))
746
747        if self.use_C_syntax:
748            return self.indent(icount) + 'struct ' + stream_function.type_name + ' {' + '\n' + self.indent(icount) + carry_declaration + '\n' + self.indent(icount) + '};\n' + '\n' + self.indent(icount) + do_block_function + '\n' + self.indent(icount) + do_final_block_function + '''
749
750'''
751
752        return self.indent(icount) + 'struct ' + stream_function.type_name + ' {' + '\n' + self.indent(icount) + constructor + '\n' + self.indent(icount) + do_block_function + '\n' + self.indent(icount) + clear_function + '\n' + self.indent(icount) + do_final_block_function + '\n' + self.indent(icount) + carry_declaration + '\n' + self.indent(icount) + '''};
753
754'''
755
756    def constructor(
757        self,
758        type_name,
759        carry_count,
760        init_to_one_list,
761        adv_n_count,
762        icount=0,
763        ):
764        one_inits = self.strm_fn.ccgo.GenerateInitializations()
765
766        adv_n_decl = ''
767
768        return self.indent(icount) + '%s() { \n' % type_name + adv_n_decl + self.carry_init(carry_count) + one_inits + ' }'
769
770    def mk_clear(self, carry_count, icount=0):
771        one_inits = self.strm_fn.ccgo.GenerateInitializations()
772        return self.indent(icount) + 'IDISA_INLINE void clear() { \n' + self.carry_init(carry_count) + one_inits + ' }'
773
774    def do_block(
775        self,
776        parameters,
777        declarations,
778        initializations,
779        statements,
780        icount=0,
781        ):
782        pfx = lower1(self.type_name) + '_' if self.use_C_syntax else ''
783        if self.use_C_syntax:
784            return '#define ' + pfx + 'do_block(' + parameters + ')\\\n do {' + '\\\n' + self.indent(icount) + escape_newlines(declarations) + '\\\n' + self.indent(icount) + escape_newlines(initializations) + '\\\n' + self.indent(icount) + escape_newlines(statements) + '\\\n' + self.indent(icount + 2) + '} while (0)'
785        return self.indent(icount) + do_block_inline_decorator + 'void ' + pfx + 'do_block(' + parameters + ') {' + '\n' + self.indent(icount) + declarations + '\n' + self.indent(icount) + initializations + '\n' + self.indent(icount) + statements + '\n' + self.indent(icount + 2) + '}'
786
787    def do_final_block(
788        self,
789        parameters,
790        declarations,
791        initializations,
792        statements,
793        icount=0,
794        ):
795        pfx = lower1(self.type_name) + '_' if self.use_C_syntax else ''
796        if self.use_C_syntax:
797            return '#define ' + pfx + 'do_final_block(' + parameters + ')\\\n do {' + '\\\n' + self.indent(icount) + escape_newlines(declarations) + '\\\n' + self.indent(icount) + escape_newlines(initializations) + '\\\n' + self.indent(icount) + escape_newlines(statements) + '\\\n' + self.indent(icount + 2) + '} while (0)'
798        return self.indent(icount) + do_final_block_inline_decorator + 'void ' + pfx + 'do_final_block(' + parameters + ') {' + '\n' + self.indent(icount) + declarations + '\n' + self.indent(icount) + initializations + '\n' + self.indent(icount) + statements + '\n' + self.indent(icount + 2) + '}'
799
800    def do_segment(
801        self,
802        parameters,
803        do_block_call_args,
804        icount=0,
805        ):
806        pfx = lower1(self.type_name) + '_' if self.use_C_syntax else ''
807        if self.use_C_syntax:
808            return '#define ' + pfx + 'do_segment(' + parameters + ')\\\n do {' + '\\\n' + self.indent(icount) + '  int i;' + '\\\n' + self.indent(icount) + '  for (i = 0; i < segment_blocks; i++)' + '\\\n' + self.indent(icount) + '    ' + pfx + 'do_block(' + do_block_call_args + ');' + '\\\n' + self.indent(icount + 2) + '} while (0)'
809        return self.indent(icount) + 'void ' + pfx + 'do_segment(' + parameters + ') {' + '\n' + self.indent(icount) + '  int i;' + '\n' + self.indent(icount) + '  for (i = 0; i < segment_blocks; i++)' + '\n' + self.indent(icount) + '    ' + pfx + 'do_block(' + do_block_call_args + ');' + '\n' + self.indent(icount + 2) + '}'
810
811    def declaration(
812        self,
813        type_name,
814        instance_name,
815        icount=0,
816        ):
817        if self.use_C_syntax:
818            return self.indent(icount) + 'struct ' + type_name + ' ' + instance_name + ';\n'
819        return self.indent(icount) + type_name + ' ' + instance_name + ';\n'
820
821    def carry_init(self, carry_count, icount=0):
822
823        return ''
824
825    def carry_declare(
826        self,
827        carry_variable,
828        carry_count,
829        adv_n_count=0,
830        icount=0,
831        ):
832        adv_n_decl = ''
833
834        return self.indent(icount) + self.strm_fn.ccgo.GenerateCarryDecls()
835
836    def carry_test(
837        self,
838        carry_variable,
839        carry_count,
840        icount=0,
841        ):
842
843        return self.indent(icount) + 'carryQ.CarryTest(0, %i)' % carry_count
844
845    def indent(self, icount):
846        s = ''
847        for i in range(0, icount):
848            s += ' '
849        return s
850
851    def do_block_parameters(self, parameters):
852        if self.use_C_syntax:
853
854            return ', '.join([lower1(self.type_name)] + [lower1(p) for p in parameters])
855        else:
856            normal_parms = [upper1(p) + ' & ' + lower1(p) for p in parameters]
857            lookahead_parms = [upper1(p) + ' & ' + lower1(p) + '_ahead' for p in parameters if self.strm_fn.lookahead_info.LookAheadSet.has_key(p)]
858            return ', '.join(normal_parms + lookahead_parms)
859
860    def do_final_block_parameters(self, parameters):
861        if self.use_C_syntax:
862
863            return ', '.join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ['EOF_mask'])
864        else:
865            return ', '.join([upper1(p) + ' & ' + lower1(p) for p in parameters] + ['BitBlock EOF_mask'])
866
867    def do_segment_parameters(self, parameters):
868        if self.use_C_syntax:
869
870            return ', '.join([lower1(self.type_name)] + [lower1(p) for p in parameters] + ['segment_blocks'])
871        else:
872            return ', '.join([upper1(p) + ' ' + lower1(p) + '[]' for p in parameters] + ['int segment_blocks'])
873
874    def do_segment_args(self, parameters):
875        if self.use_C_syntax:
876            return ', '.join([lower1(self.type_name)] + [lower1(p) + '[i]' for p in parameters])
877        else:
878            return ', '.join([lower1(p) + '[i]' for p in parameters])
879
880
881def main(infilename, outfile=sys.stdout):
882    t = ast.parse(file(infilename).read())
883    outfile.write(StreamStructGen(True).gen(t))
884    outfile.write(FunctionXlator().xlat(t))
885
886
887class MainLoopTransformer:
888
889    def __init__(self, main_module, C_syntax=False, add_dump_stmts=False, add_assert_bitblock_align=False, dump_func_data=False, main_node_id='Main'):
890        self.main_module = main_module
891        self.main_node_id = main_node_id
892        self.use_C_syntax = C_syntax
893        mkast.use_C_syntax = self.use_C_syntax
894        self.add_dump_stmts = add_dump_stmts
895        self.add_assert_bitblock_align = add_assert_bitblock_align
896        self.dump_func_data = dump_func_data
897       
898        stream_function_visitor = StreamFunctionVisitor(self.main_module)
899        self.stream_function_node = stream_function_visitor.stream_function_node
900        for (key, node) in self.stream_function_node.iteritems():
901            AdvanceCombiner().xfrm(node)
902        self.main_node = self.stream_function_node[main_node_id]
903        self.main_carry_count = CarryCounter().count(self.main_node)
904        self.main_adv_n_count = adv_nCounter().count(self.main_node)
905        self.main_carry_info_set = CarryInfoSetVisitor(self.main_node)
906        self.main_ccgo = Strategic_CCGO_Factory(self.main_carry_info_set)
907        assert self.main_adv_n_count == 0, 'Advance32() in main not supported.\n'
908        del self.stream_function_node[self.main_node_id]
909
910        self.stream_functions = {}
911        for (key, node) in self.stream_function_node.iteritems():
912            stream_function = StreamFunction()
913            stream_function.carry_count = CarryCounter().count(node)
914            stream_function.init_to_one_list = CarryInitToOneList().count(node)
915            stream_function.adv_n_count = adv_nCounter().count(node)
916            carry_info_set = CarryInfoSetVisitor(node)
917
918            stream_function.lookahead_info = lookAhead.LookAheadInfoSetVisitor(node)
919            stream_function.ccgo = Strategic_CCGO_Factory(carry_info_set)
920            stream_function.type_name = node.name[0].upper() + node.name[1:]
921            stream_function.instance_name = node.name[0].lower() + node.name[1:]
922            stream_function.parameters = FunctionVars(node).params
923            stream_function.declarations = BitBlock_decls_of_fn(node)
924            stream_function.declarations += '\n' + stream_function.ccgo.GenerateStreamFunctionDecls()
925            stream_function.initializations = StreamInitializations().xfrm(node)
926
927            tempifier = TempifyBuiltins()
928           
929            AugAssignRemoval().xfrm(node)
930            if self.add_dump_stmts:
931                Add_SIMD_Register_Dump().xfrm(node)           
932            tempifier.xfrm(node)
933
934            Bitwise_to_SIMD().xfrm(node)
935            final_block_node = copy.deepcopy(node)
936           
937            RewriteEOF().xfrm(node, False)
938            if rewrite_errors:
939                RewriteErrorStatements(tempifier).xfrm(node)           
940            Optimize().xfrm(node)
941           
942            RewriteEOF().xfrm(final_block_node, True) 
943            Optimize().xfrm(final_block_node)
944           
945            if self.use_C_syntax:
946                carryQname = stream_function.instance_name + '.carryQ'
947            else:
948                carryQname = 'carryQ'
949            CarryIntroVisitor = CarryIntro(stream_function.ccgo, carryQname)
950
951            lookAhead.LookAheadTransformer(stream_function, 'nonfinal').xfrm(node)
952            lookAhead.LookAheadTransformer(stream_function, 'final').xfrm(final_block_node)
953
954            stream_function.declarations += '\n' + BitBlock_decls_from_vars(tempifier.tempVars())
955           
956            CarryIntroVisitor.xfrm_fndef(node)
957            CarryIntroVisitor.xfrm_fndef_final(final_block_node)
958
959            AssertCompiler().xfrm(node)
960            AssertCompiler().xfrm(final_block_node)
961
962            if self.add_assert_bitblock_align:
963                Add_Assert_BitBlock_Align().xfrm(node)
964                Add_Assert_BitBlock_Align().xfrm(final_block_node)
965
966            node.body += stream_function.ccgo.GenerateStreamFunctionFinalization()
967
968            stream_function.statements = Cgen.py2C(4).gen(node.body)
969            stream_function.final_block_statements = Cgen.py2C(4).gen(final_block_node.body)
970            self.stream_functions[stream_function.type_name] = stream_function
971
972        if self.dump_func_data:
973            for (key, value) in self.stream_functions.iteritems():
974                value.dump()
975            sys.exit()
976
977        self.emitter = Emitter(self.use_C_syntax, stream_function)
978
979    def any_carry_expr(self):
980
981        tests = [self.stream_functions[key].ccgo.GenerateTestAll(self.stream_functions[key].instance_name) for key in self.stream_functions.keys()]
982        return ' || '.join([Cgen.py2C().gen(t) for t in tests])
983
984    def gen_globals(self):
985        self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
986        for key in self.stream_functions.keys():
987            sf = self.stream_functions[key]
988            self.Cglobals += Emitter(self.use_C_syntax, sf).definition(sf, 2)
989
990    def gen_declarations(self):
991        self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
992        self.Cdecls += BitBlock_decls_of_fn(self.main_node) + '\n' + self.main_ccgo.GenerateStreamFunctionDecls()
993        if self.main_carry_count > 0:
994            self.Cdecls += self.emitter.carry_declare('carryQ', self.main_carry_count)
995
996    def gen_initializations(self):
997        self.Cinits = ''
998        if self.main_carry_count > 0:
999            self.Cinits += self.emitter.carry_init(self.main_carry_count)
1000        self.Cinits += StreamInitializations().xfrm(self.main_module)
1001        if self.use_C_syntax:
1002            for key in self.stream_functions.keys():
1003                if self.stream_functions[key].carry_count == 0:
1004                    continue
1005                self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
1006                self.Cinits += 'CarryInit(' + self.stream_functions[key].instance_name + '.carryQ, %i);\n' % self.stream_functions[key].carry_count
1007        else:
1008            for key in self.stream_functions.keys():
1009                self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
1010
1011    def xfrm_block_stmts(self):
1012
1013        AugAssignRemoval().xfrm(self.main_node)
1014        Bitwise_to_SIMD().xfrm(self.main_node)       
1015        final_block_main = copy.deepcopy(self.main_node)
1016       
1017        RewriteEOF().xfrm(self.main_node, False) 
1018        RewriteEOF().xfrm(final_block_main, True)
1019       
1020        carry_info_set = CarryInfoSetVisitor(self.main_node)
1021
1022        AssertCompiler().xfrm(self.main_node)
1023        if self.add_dump_stmts:
1024            Add_SIMD_Register_Dump().xfrm(self.main_node)
1025            Add_SIMD_Register_Dump().xfrm(final_block_main)
1026
1027        if self.add_assert_bitblock_align:
1028            print 'add_assert_bitblock_align'
1029            Add_Assert_BitBlock_Align().xfrm(self.main_node)
1030            Add_Assert_BitBlock_Align().xfrm(final_block_main)
1031
1032        StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys(), self.use_C_syntax)
1033        StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
1034
1035        self.Cstmts = Cgen.py2C().gen(self.main_node.body)
1036        self.Cfinal_stmts = Cgen.py2C().gen(final_block_main.body)
1037
1038
1039if __name__ == '__main__':
1040    import doctest
1041    doctest.testmod()
Note: See TracBrowser for help on using the repository browser.