source: proto/parabix2/Compiler/py2bitexpr.py @ 352

Last change on this file since 352 was 352, checked in by eamiri, 10 years ago

arrays indexed by string are supported

File size: 49.9 KB
Line 
1#
2# py2bitexpr.py
3#
4# (c) 2009 Robert D. Cameron with modifications by Ehsan Amiri
5# All rights reserved.
6# Licensed to International Characters, Inc. under Academic Free License 3.0
7#
8# Translate unbounded bitstream code in "PyBit" form into our internal
9# representation.
10# Pybit form: series of sequential assignment statements using
11# bitwise logice operations, Advance, ScanThru.
12#
13# Requires ast module of python 2.6
14
15import ast, bitexpr, copy, basic_block
16
17AllOne = 'AllOne'
18AllZero = 'AllZero'
19
20class PyBitError(Exception):
21        pass
22
23#############################################################################################
24## Function Inlining
25#############################################################################################
26TEMP_VAR_TEMPLATE = "InlineTemp%i"
27
28
29def replace_in_exp(exp, translation):
30    if isinstance(exp, ast.BinOp):
31        exp.left = replace_in_exp(exp.left, translation)
32        exp.right = replace_in_exp(exp.right, translation)
33    elif isinstance(exp, ast.UnaryOp):
34        exp.operand = replace_in_exp(ast.UnaryOp)
35    elif isinstance(exp, ast.Name):
36        exp.id = translation.setdefault(exp.id, exp.id)
37    elif isinstance(exp.ast.Attribute):
38        exp.value.id = translation.setdefault(exp.value.id, exp.value.id)
39    elif isinstance(exp, ast.Subscript):
40        exp.value.id = translation.setdefault(exp.value.id, exp.value.id)
41    else:
42        assert(1==0)
43
44    return exp
45
46def get_all_calls(main):
47    call_list = []
48    for index, loc in enumerate(main.body):
49        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
50            call_list.append(index)
51        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
52            call_list.append(index)
53    return call_list
54
55
56def prepare_code(callee, args, unique_num):
57    #check number of args matches
58    assert (len(args)==len(callee.args.args))
59    translation = {}
60    for index in range(len(args)):
61        translation[translate_var(callee.args.args[index])] = translate_var(args[index])
62
63    #print translation
64    rets = {}
65    for index, loc in enumerate(callee.body):
66        if isinstance(loc, ast.Assign):
67            for index, var in enumerate(loc.targets):
68                var = ast.Name(id=translation.setdefault(translate_var(var), translate_var(var)))
69            loc.value = replace_in_exp(loc.value, translation)
70        elif isinstance(loc, ast.AugAssign):
71            loc.target = ast.Name(id=translation.setdefault(translate_var(loc.target), translate_var(loc.target)))
72            loc.value = replace_in_exp(loc.value, translation)
73        elif isinstance(loc, ast.Return):
74            rets[index] = ast.Assign([ast.Name(id=TEMP_VAR_TEMPLATE%unique_num)], loc.value)
75
76        else:
77            assert(1==0)
78
79        for key in rets:
80            callee.body[key] = rets[index]
81
82    return callee
83
84def finalize(main, inline_code, call_list):
85    for list_index, main_index in enumerate(reversed(call_list)):
86        main.body[main_index].value = ast.Name(id=TEMP_VAR_TEMPLATE%main_index)
87        main.body = main.body[:main_index]+inline_code[-list_index-1].body+main.body[main_index:]
88    return main
89   
90def do_inlining(module):
91    func_dict = {}
92    for index, func in enumerate(module.body):
93        func_dict[func.name] = index
94
95    main = module.body[func_dict["main"]]
96    call_list = get_all_calls(main)
97
98    inline_code = []
99    for line_no in call_list:
100        callee = copy.deepcopy(module.body[func_dict[main.body[line_no].value.func.id]])
101        args = main.body[line_no].value.args
102        inline_code.append(prepare_code(callee, args, line_no))
103    main = finalize(main, inline_code, call_list)
104
105    return call_list, module.body[func_dict["main"]].body
106
107#############################################################################################
108## Translation to bitExpr classes --- Pass 2
109#############################################################################################
110
111def translate_index(ast_value):
112        if isinstance(ast_value, ast.Str): return ast_value.s
113        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
114        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
115
116def is_Advance_Call(fncall):
117        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
118        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
119                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
120        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
121
122def is_ScanThru_Call(fncall):
123        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
124        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
125                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
126        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
127
128def translate(ast_expr):
129        if isinstance(ast_expr, ast.Name):
130                if ast_expr.id=="allzero":
131                        return bitexpr.FalseLiteral()
132                if ast_expr.id=="allone":
133                        return bitexpr.TrueLiteral()
134                return bitexpr.Var(ast_expr.id)
135        elif isinstance(ast_expr, ast.Attribute):
136                if isinstance(ast_expr.value, ast.Name):
137                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
138                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
139        elif isinstance(ast_expr, ast.Subscript):
140                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice.value, ast.Num):
141                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
142                elif isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice.value, ast.Str):
143                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
144                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
145        elif isinstance(ast_expr, ast.UnaryOp):
146                e1 = translate(ast_expr.operand)
147                if isinstance(ast_expr.op, ast.Invert):
148                        return bitexpr.make_not(e1)
149                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
150        elif isinstance(ast_expr, ast.BinOp):
151                e1 = translate(ast_expr.left)
152                e2 = translate(ast_expr.right)
153                if isinstance(ast_expr.op, ast.BitOr):
154                        return bitexpr.make_or(e1, e2)
155                elif isinstance(ast_expr.op, ast.BitAnd):
156                        return bitexpr.make_and(e1, e2)
157                elif isinstance(ast_expr.op, ast.BitXor):
158                        return bitexpr.make_xor(e1, e2)
159                elif isinstance(ast_expr.op, ast.Add):
160                        return bitexpr.make_add(e1, e2)
161                elif isinstance(ast_expr.op, ast.Sub):
162                        return bitexpr.make_sub(e1, e2)
163                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
164        elif isinstance(ast_expr, ast.Call):
165                if is_Advance_Call(ast_expr):
166                        e = translate(ast_expr.args[0])
167                        temp = bitexpr.make_add(e,e)
168                        return temp
169                elif is_ScanThru_Call(ast_expr):
170                        e0 = translate(ast_expr.args[0])
171                        e1 = translate(ast_expr.args[1])
172                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
173                else: raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
174        elif isinstance(ast_expr, ast.Compare):
175                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
176                        e0 = translate(ast_expr.left)
177                        return bitexpr.isNoneZero(e0)
178                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
179        else: raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
180
181def translate_var(v):
182        if isinstance(v, ast.Name):
183                return v.id
184        elif isinstance(v, ast.Attribute):
185                if isinstance(v.value, ast.Name):
186                        return "%s.%s" % (v.value.id, v.attr)
187                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
188        elif isinstance(v, ast.Subscript):
189                if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
190                        if isinstance(v.slice.value, ast.Num):
191                            return "%s[%s]" % (v.value.id, translate_index(v.slice.value))
192                        elif isinstance(v.slice.value, ast.Str):
193                            return "%s.%s" % (v.value.id, translate_index(v.slice.value))
194                        else:
195                           
196                            assert(1==0)
197                else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
198
199
200def translate_stmts(ast_stmts):
201        translated = []
202        for s in ast_stmts:
203                if isinstance(s, ast.Expr):
204                        if (s.value.func.id=='optimize'):
205                            target = translate_var(s.value.args[0])
206                            replace = translate_var(s.value.args[1])
207                            translated.append(bitexpr.Reduce(target, replace))
208                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
209                elif isinstance(s, ast.Assign):
210                        e = translate(s.value)
211                        for astv in s.targets: 
212                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
213                elif isinstance(s, ast.AugAssign):
214                        v = bitexpr.Var(translate_var(s.target))
215                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
216                elif isinstance(s, ast.While):
217                        e = translate(s.test)
218                        body = translate_stmts(s.body)
219                        translated.append(bitexpr.WhileLoop(e, body))
220                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
221        return translated
222
223#############################################################################################
224## Conversion to SSA form --- Pass 3
225#############################################################################################
226
227def extract_vars(rhs):
228    if isinstance(rhs, bitexpr.Var):
229        return [rhs.varname]
230    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
231        return []
232    if isinstance(rhs, bitexpr.Not):
233        return extract_vars(rhs.operand1)
234    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
235    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
236
237def update_lineno(inner_table, shift):
238    for key in inner_table:
239        for index in range(len(inner_table[key][0])):
240            inner_table[key][0][index] += shift
241        for index in range(len(inner_table[key][1])):
242            inner_table[key][1][index] += shift
243    return inner_table
244
245def merge_tables(table, inner_table):
246    for key in inner_table:
247        if key in table:
248            table[key][0] += inner_table[key][0]
249            table[key][1] += inner_table[key][1]
250        else:
251            table[key] = inner_table[key]
252    return table
253
254def gen_sym_table(code, goInside = False):
255    """Generate a simple symbol table for a three address code
256        each entry is of this form: var_name:[[defs][uses]]
257    """
258    table = {}
259    index = 0
260    for stmt in code:
261        if isinstance(stmt, bitexpr.Reduce):
262            if stmt.target in table:
263                table[stmt.target][1].append(index)
264            else:
265                table[stmt.target] = [[], [index]]
266            index += 1
267        elif isinstance(stmt, bitexpr.BitAssign):
268            current = stmt.LHS.varname
269            if current in table:
270                table[current][0].append(index)
271            else:
272                table[current] = [[index],[]]
273
274            varlist = extract_vars(stmt.RHS)
275            for var in varlist:
276                if var in table:
277                    table[var][1].append(index)
278                else:
279                    table[var] = [[], [index]]
280            index += 1
281        elif isinstance(stmt, bitexpr.WhileLoop):
282            cond_var = stmt.control_expr.var.varname
283            if cond_var in table:
284                table[cond_var][1].append(index)
285            else:
286                #while loop conditioned on an undefined variable
287                assert(1==0)
288            if goInside:
289                inner_table = gen_sym_table(stmt.stmts, True)
290                inner_table = update_lineno(inner_table, index+1)
291                table = merge_tables(table, inner_table)
292                index += (1+len(stmt.stmts))
293            else:
294                index += 1
295        else:
296            assert(1==0)
297    return table
298##################################################################################
299def get_line(code, line):
300
301    lineno = 0
302    for stmt in code:
303        if isinstance(stmt, bitexpr.BitAssign):
304            if lineno == line:
305                return stmt
306            lineno += 1
307        elif isinstance(stmt, bitexpr.WhileLoop):
308            if lineno == line:
309                return stmt
310            elif lineno+len(stmt.stmts) < line:
311                lineno += 1
312                line -= len(stmt.stmts)
313                continue
314            else:
315                return get_line(stmt.stmts, line-(lineno+1))
316
317        elif isinstance(stmt, bitexpr.Reduce):
318            if lineno == line:
319                return stmt
320            lineno += 1
321        else:
322            assert(1==0)
323
324def update_def(code, var, line, suffix_num):
325    #for i in code:
326    #    print i
327    #print "------------------------", len(code), line, var
328    loc = code[line]
329    if isinstance(loc, bitexpr.BitAssign):
330        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
331        return loc
332    else:
333        #either Reduce or While Loop in both cases it's a use
334        pass
335
336def update_rhs(rhs, varname, newname):
337    if isinstance(rhs, bitexpr.Var):
338        if rhs.varname == varname:
339            #if varname == "u8.unibyte":
340            #    print ")))))))))))))))))"
341           
342            rhs.varname = newname
343        return rhs
344
345    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
346        return rhs
347
348    if isinstance(rhs, bitexpr.Not):
349        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
350        return rhs
351    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
352    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
353    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
354    return rhs
355
356def update_use(code, var, line, suffix_num):
357    loc = code[line]
358    if isinstance(loc, bitexpr.BitAssign):
359        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
360    elif isinstance(loc, bitexpr.Reduce):
361        pass
362    elif isinstance(loc, bitexpr.WhileLoop):
363        loc.control_expr.var.varname += "_%i"%suffix_num
364    return loc
365
366def parse_var(var):
367    index=var.find('.')
368    if index >= 0:
369        return ('struct', var[0:index], var[index+1:])
370               
371    index = var.find('[')
372    if index >= 0:
373        right_index = var.find(']')
374        if var[index+1:right_index].isdigit():
375           return ('array', var[0:index], var[index+1:right_index])     
376        else:
377            return ('struct', var[0:index], var[index+1:right_index])
378   
379    if var.startswith("carry") or var.startswith("Carry"):
380        return('int', var, None)
381   
382    return ('bitblock', var, None)
383
384def simplify_name(var):
385    (vartype, name, extra) = parse_var(var)
386    if vartype == 'bitblock':
387        return name
388    if vartype == 'array':
389        return "a_%s_%s"%(name, extra)
390    if vartype == 'struct':
391        return "s_%s_%s"%(name, extra)
392    if vartype == 'int':
393        assert(1==0)
394    assert(1==0)
395
396def pairs(lst):
397    if lst == []:
398            return []
399    return zip(lst,lst[1:]+[lst[0]])
400
401
402def make_SSA(code, st):
403    new_vars = []
404    total_lines = len(code)
405    for var in st:
406        st[var][0].append(total_lines)
407        st[var][1].append(total_lines)
408
409    unique_number = 0
410    for var in st:
411        use_index = 0
412        for current, next in pairs(st[var][0])[0:-2]:
413            code[current] = update_def(code, var, current, unique_number)
414            uline = st[var][1][use_index]
415
416            while uline <= next and uline < total_lines:
417                if uline > current:
418                    code[uline] = update_use(code, var, uline, unique_number)
419                use_index += 1
420                uline = st[var][1][use_index]
421            unique_number += 1
422    return code
423
424#################################################################################################################
425## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
426#################################################################################################################
427def get_opt_list(s):
428    opt_list = []
429    remove_index = []
430    for line, stmt in enumerate(s):
431        if isinstance(stmt, bitexpr.Reduce):
432            opt_pair = (stmt.target, stmt.replace)
433            opt_list.append(opt_pair)
434            remove_index.append(line)
435    for ind in reversed(remove_index):
436        del s[ind]
437
438    return opt_list
439
440def break_to_segs(s, breaks):
441    segs = []
442    breaks = [-1]+breaks+[len(s)]
443    for start, stop in pairs(breaks)[:-1]:
444        #print start, stop
445        next = s[start+1:stop]
446        if len(next)>0:
447            segs.append(next)
448    return segs
449
450def extract_pragmas(s):
451    targets = []
452    exprs = []
453    lines = []
454
455    #extracting pragmas and their targets
456    for line, loc in enumerate(s):
457        if isinstance(loc, bitexpr.Reduce):
458            targets.append(loc.target)
459            exprs.append(loc)
460            lines.append(line)
461
462    #removing pragmas from the source code
463    for i in reversed(lines):
464        del s[i]
465
466    return targets, exprs
467
468def get_defs(targets, s):
469    temp = [(x,-1) for x in targets]
470    targ_dic = dict(temp)
471    for line, loc in enumerate(s):
472        if loc.LHS.varname in targets:
473            targ_dic[loc.LHS.varname] = line
474
475    items = targ_dic.items()
476    items = sorted(items, key=(lambda x: x[1]))
477
478    lineno = [key for value, key in items]
479    sorted_targets = [value for value, key in items]
480
481    return sorted_targets, lineno
482
483def get_boundary(def_line, lastuse_line):
484    earliest = min(def_line)
485
486    # The indices of all variables defined in the same line of code and earlier than all other variables
487    # There is more than one such variable only if these variables are input variables
488    all_early = [] 
489    all_early_cnt = 0
490    j = 0
491
492    while earliest in def_line[j:]:
493        m = def_line[j:].index(earliest)
494        all_early_cnt += 1
495        all_early.append(m+j)
496        j += m+1
497
498    #all_early contains the indices of the all earliest defined vars
499    lasts = []
500    map(lambda y: lasts.append(lastuse_line[y]), all_early)
501
502    latest = max(lasts)
503    return earliest, latest
504
505def adjust_latest(s, latest):
506    line = 0
507    for stmt in s:
508        if isinstance(stmt, bitexpr.BitAssign):
509            if latest == line:
510                return latest
511            line += 1
512        if isinstance(stmt, bitexpr.WhileLoop):
513            end_of_loop = line+len(stmt.stmts)
514            if latest <= end_of_loop:
515                return end_of_loop
516            else:
517                line = end_of_loop+1
518    return line
519
520#There is slight difference between the function below and adjust_latest (look at the second return)
521def adjust_earliest(s, earliest):
522    line = 0
523    if earliest == -1:
524        return -1
525
526    for stmt in s:
527        if isinstance(stmt, bitexpr.BitAssign):
528            if earliest == line:
529                return earliest
530            line += 1
531        if isinstance(stmt, bitexpr.WhileLoop):
532            end_of_loop = line+len(stmt.stmts)
533            if earliest <= end_of_loop:
534                return end_of_loop+1
535            else:
536                line = end_of_loop+1
537
538def chop_code(s, start, stop):
539    line = 0
540    if start == -1:
541        cut1 = -1
542   
543    if stop == len(s):
544        cut2 = len(s)
545   
546    for index, stmt in enumerate(s):
547        if isinstance(stmt, bitexpr.BitAssign):
548            if line == start:
549                cut1 = index
550            if line == stop:
551                cut2 = index
552            line += 1
553
554        if isinstance(stmt, bitexpr.WhileLoop):
555            end_of_loop = line+len(stmt.stmts)
556            if line == start:
557                cut1 = index
558            if end_of_loop == stop:
559                cut2 = index
560            line = end_of_loop+1
561
562    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
563
564def count_lines(code):
565    cnt = len(code)
566    for stmt in code:
567        if isinstance(stmt, bitexpr.WhileLoop):
568            cnt += count_lines(stmt.stmts)
569    return cnt
570
571
572def gen_bb(s, opt_list, def_line, lastuse_line):
573    assert (len(opt_list) == len(def_line) == len(lastuse_line))
574    if opt_list==[]:
575        return s
576
577    earliest, latest = get_boundary(def_line, lastuse_line)
578
579    indices = []
580    for index, i in enumerate(def_line): #was enumerate(lastuse_line)
581        if i <= latest:
582            indices.append(index)
583    #if latest or earliest are inside a while loop they are pushed to the end of while loop or beginning of the next block (respectively)
584
585    e = earliest
586    l = latest
587
588    opt1 = {}
589    def1 = {}
590    use1 = {}
591    for index, i in enumerate(opt_list):
592        opt1.setdefault(index in indices, []).append(i)
593    for index, i in enumerate(def_line):
594        def1.setdefault(index in indices, []).append(i)
595    for index, i in enumerate(lastuse_line):
596        use1.setdefault(index in indices, []).append(i)
597
598    latest = max(use1[True])
599    latest = adjust_latest(s, latest)
600    earliest = adjust_earliest(s, earliest)
601
602    #use earliest to construct an initial block with no if-then-else
603    #use use1[True], def1[True], opt1[True] to construct if then else and recurse on inner blocks
604    #use use1[False], def1[False], opt1[False] to construct the block after if-then-else and recurse on that block
605    first, second, third = chop_code(s, earliest, latest)
606    the_opt = None
607    the_index = None
608    for index, value in enumerate(opt1[True]):
609        if e==def1[True][index] and l==use1[True][index]:
610            the_opt = value
611            the_index = index
612    del opt1[True][the_index]
613    del use1[True][the_index]
614    del def1[True][the_index]
615
616    second = gen_bb(second, opt1[True], def1[True],use1[True])
617
618    if the_opt[1] == 'allzero':
619        cond_obj = bitexpr.isAllZero(the_opt[0])
620    if the_opt[1] == 'allone':
621        cond_obj = bitexpr.isAllOne(the_opt[0])
622
623    change = count_lines(first) + count_lines(second)
624    new_def = []
625    new_use = []
626
627    if (False in opt1):
628        for i in def1[False]:
629            new_def.append(max(i-change, -1))
630        for i in use1[False]:
631            new_use.append(i-change)
632
633        third = gen_bb(third, opt1[False], new_def, new_use)
634
635    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))]+third
636
637    return result
638
639def partition2bb(s):
640    basicblocks = []
641    lineno = []
642    #####
643    def_line = []
644    lastuse_line = []
645    opt_list = get_opt_list(s)
646
647    st = gen_sym_table(s, True)
648    st2 = gen_sym_table(s)
649    total_lines = count_lines(s)
650
651    for item in opt_list:
652        if len(st2[item[0]][0]) > 0:
653            #The last definition of the variable is extracted
654            def_line.append(st2[item[0]][0][-1])
655        else:
656            #This variable is an input variable and not defined by the programmer
657            def_line.append(-1)
658
659        #if len(st[item[0]][1]) > 0:
660            #last use of the variables is extracted
661        #    lastuse_line.append(st[item[0]][1][-1])
662        #else:
663            #The variable is not used in the code anywhere, we assume it is needed at the end
664        lastuse_line.append(total_lines)
665
666    return gen_bb(s, opt_list, def_line, lastuse_line)
667
668#################################################################################################################
669## Generating declarations for variables. This pass is based on the syntax of C programming language ---
670## Normalizing the code
671#################################################################################################################
672
673def normalize(s, predec = {}, ccelim=True):
674    if len(s)==0:
675        return []
676    if isinstance(s[0], bitexpr.If):
677        gc = basic_block.BasicBlock.gensym_counter
678        cc = basic_block.BasicBlock.carry_counter
679        bc = basic_block.BasicBlock.brw_counter
680        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
681        maxgc = basic_block.BasicBlock.gensym_counter
682        maxcc = basic_block.BasicBlock.carry_counter
683        maxbc = basic_block.BasicBlock.brw_counter
684        basic_block.BasicBlock.gensym_counter = gc
685        basic_block.BasicBlock.carry_counter = cc
686        basic_block.BasicBlock.brw_counter = bc
687        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
688        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
689        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
690        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
691       
692        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
693    if isinstance(s[0], bitexpr.WhileLoop):
694        orig = copy.deepcopy(predec)
695        s[0].stmts = normalize(s[0].stmts, predec, False)
696        return [s[0]]+normalize(s[1:], orig)
697    if isinstance(s[0], bitexpr.BitAssign):
698        code, next = recurse_forward(s)
699        bb = basic_block.BasicBlock(predec)
700        predec.update(bb.normalize(code, ccelim))
701        return bb.get_code() + normalize(next, copy.deepcopy(predec))
702
703#################################################################################################################
704## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
705#################################################################################################################
706def get_vars(stmt):
707        ints = set([])
708        bitblocks = set(['AllOne', 'AllZero'])
709        arrays = {}
710        structs = {}
711
712        #for stmt in code:
713
714        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
715
716        if isinstance(stmt.RHS, bitexpr.Add): 
717                ints.add(stmt.RHS.carry)
718        if isinstance(stmt.RHS, bitexpr.Sub):
719                ints.add(stmt.RHS.brw)
720
721        for var in all_vars:
722                (var_type, name, extra) = parse_var(var.varname)
723
724                if (var_type == "bitblock"):
725                        bitblocks.add(name)
726
727                if var_type == "array":
728                        if not name in arrays:
729                                arrays[name]= extra
730                        else:
731                                arrays[name] = max(arrays[name], extra)
732
733                if var_type == "struct":
734                        if not name in structs:
735                                structs[name] = set([extra])
736                        else:
737                                structs[name].add(extra)
738                if var_type == "int":
739                    ints.add(name)
740        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
741
742def merge_var_dic(first, second):
743    res = {}
744    #print first
745    #print second
746    res['int'] = first['int'].union(second['int'])
747    res['bitblock'] = first['bitblock'].union(second['bitblock'])
748
749    res['array'] = first['array']
750    temp = dict({'array':copy.copy(second['array'])})
751    for i in res['array']:
752        if i in temp['array']:
753            res['array'][i] = max(res['array'][i], temp['array'][i])
754            del temp['array'][i]
755    res['array'].update(temp['array'])
756
757
758    res['struct'] = first['struct']
759    temp = dict({'struct':copy.copy(second['struct'])})
760    for i in res['struct']:
761        if i in temp['struct']:
762            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
763            del temp['struct'][i]
764    res['struct'].update(temp['struct'])
765    return res
766
767def gen_output(var_dic):
768        s = ''
769        for i in  var_dic['int']:
770                s+="int %s=0;\n"%i
771
772        for i in var_dic['bitblock']:
773                if i == AllOne:
774                        s += "BitBlock %s = simd_const_1(1);\n"%i
775                elif i== AllZero:
776                        s+="BitBlock %s = simd_const_1(0);\n"%i
777                else:
778                        s+="BitBlock %s;\n"%i
779
780        for i in var_dic['array']:
781                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
782
783        for i in var_dic['struct']:
784                s+="struct __%s__{\n"%i
785                for j in var_dic['struct'][i]:
786                        s+= "\tBitBlock %s;\n"%j
787                s+="};\n"
788                s+="struct __%s__ %s;\n"%(i,i)
789
790        return s
791
792def gen_var_dic(s):
793    if len(s)==0:
794        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
795    if isinstance(s[0], bitexpr.BitAssign):
796        vd = get_vars(s[0])
797        more = gen_var_dic(s[1:])
798        vd = merge_var_dic(vd, more)
799        return vd
800
801    if isinstance(s[0], bitexpr.If):
802        vd1 = gen_var_dic(s[0].true_branch)
803        vd2 = gen_var_dic(s[0].false_branch)
804        vd = merge_var_dic(vd1, vd2)
805        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
806        vd =  merge_var_dic(vd, vd3)
807        more = gen_var_dic(s[1:])
808        vd  = merge_var_dic(vd, more)
809        return vd
810
811    if isinstance(s[0], bitexpr.WhileLoop):
812        vd = gen_var_dic(s[0].stmts)
813        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
814        vd =  merge_var_dic(vd, vd1)
815        more = gen_var_dic(s[1:])
816        vd  = merge_var_dic(vd, more)
817        return vd
818
819    """for loc in s[1:]:
820        more = get_vars(loc)
821
822    vd = get_vars(bb[0].code)
823    for block in bb[1:]:
824        more = get_vars(block.code)
825        vd = merge_var_dic(vd, more)
826    declarations = gen_output(vd)
827    return declarations
828    """
829
830def gen_declarations(s):
831    vd = gen_var_dic(s)
832    return gen_output(vd)
833#################################################################################################################
834## This class replaces all occurences of a reduced variable to its value --- Pass 7
835## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
836## *** the new AST notation                                                                                   ***
837#################################################################################################################
838def replace_in_rhs(rhs, target, replace):
839    if isinstance(rhs, bitexpr.Var):
840        if rhs.varname == target:
841            rhs = replace
842            return replace
843
844    if isinstance(rhs.operand1, bitexpr.Var):
845        if rhs.operand1.varname == target:
846            rhs.operand1 = replace
847    else:
848        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
849            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
850
851    if isinstance(rhs.operand2, bitexpr.Var):
852        if rhs.operand2.varname == target:
853            rhs.operand2 = replace
854    else:
855        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
856            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
857    return rhs
858
859def apply_single_opt(code, target, replace):
860    if len(code) == 0:
861        return []
862
863    if replace=='AllZero':
864        replace = bitexpr.FalseLiteral()
865    if replace == 'AllOne':
866        replace = bitexpr.TrueLiteral()
867
868    if isinstance(code[0], bitexpr.BitAssign):
869        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
870
871    if isinstance(code[0], bitexpr.If):
872        if code[0].control_expr.var.varname == target:
873            code[0].control_expr.var = replace
874        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
875        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
876
877    if isinstance(code[0], bitexpr.WhileLoop):
878        if code[0].control_expr.var.varname == target:
879            code[0].control_expr.var = replace
880        code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
881
882    return [code[0]]+apply_single_opt(code[1:], target, replace)
883
884def apply_all_opt(s):
885    if len(s) == 0:
886        return []
887    elif isinstance(s[0], bitexpr.If):
888        target = s[0].control_expr.var.varname
889        replace = s[0].control_expr.val
890        apply_single_opt(s[0].true_branch, target, replace)
891        apply_all_opt(s[0].true_branch)
892        apply_all_opt(s[0].false_branch)
893
894    elif isinstance(s[0], bitexpr.WhileLoop):
895        apply_all_opt(s[0].stmts)
896
897    return [s[0]]+apply_all_opt(s[1:])
898 
899#################################################################################################################
900## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
901## method prune is supposed to remove unreachable branches but it is incomplete
902#################################################################################################################
903
904def prune(fixed, tree):
905    """removes unreachable branches of the tree"""
906    target = tree[0].target
907    replace = tree[0].replace
908    if target in fixed:
909        if replace == 'allone':
910            if isinstance(fixed[target], bitexpr.TrueLiteral):
911                tree[0] = None
912                tree[1].join(tree[2][1])
913                del tree[3]
914                del tree[2]
915                fixed.update(tree[1].simplify(fixed))
916            if isinstance(fixed[target], bitexpr.FalseLiteral):
917                tree[0] = None
918                tree[1].join(tree[3][1])
919                del tree[3]
920                del tree[2]
921                fixed.update(tree[1].simplify(fixed))
922        if replace == 'allzero':
923            if isinstance(fixed[target], bitexpr.FalseLiteral):
924                tree[0] = None
925                tree[1].join(tree[2][1])
926                del tree[3]
927                del tree[2]
928                fixed.update(tree[1].simplify(fixed))
929            if isinstance(fixed[target], bitexpr.TrueLiteral):
930                tree[0] = None
931                tree[1].join(tree[3][1])
932                del tree[3]
933                del tree[2]
934                fixed.update(tree[1].simplify(fixed))
935
936    empty_list = []
937    for index, branch in enumerate(tree[2:]):
938        if branch[0] is None:
939            if branch[1].code == []:
940                empty_list.append(2+index)
941
942    empty_list.reverse()
943    for i in empty_list:
944        del tree[i]
945def filter_fixed(fixed, stmts):
946    for loc in stmts:
947        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
948            del fixed[loc.LHS.varname]
949        if isinstance(loc, bitexpr.If):
950            fixed = filter_fixed(fixed, loc.true_branch)
951            fixed = filter_fixed(fixed, loc.false_branch)
952        if isinstance(loc, bitexpr.WhileLoop):
953            fixed = filter_fixed(fixed, loc.stmts)
954    return fixed
955           
956def simplify_tree(code, fixed = {}, prev = []):
957    #print len(code)
958    #print "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
959    if len(code) == 0:
960        return []
961
962    assumptions = {}
963
964    if isinstance(code[0], bitexpr.BitAssign):
965        this, next = recurse_forward(code)
966        fixed.update(basic_block.simplify(this, fixed))
967        #assumptions = basic_block.calc_implications(this, copy.deepcopy(fixed))
968        #fixed.update(assumptions)
969        return this+simplify_tree(next, fixed, this)
970
971    elif isinstance(code[0], bitexpr.If):
972        fixed1 = copy.deepcopy(fixed)
973        fixed2 = copy.deepcopy(fixed)
974        if isinstance(code[0].control_expr, bitexpr.isAllOne):
975            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
976        if isinstance(code[0].control_expr, bitexpr.isAllZero):
977            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
978        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
979        #print assumptions, len(prev), prev[0].LHS.varname
980        fixed1.update(assumptions)
981        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
982        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
983        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
984
985    elif isinstance(code[0], bitexpr.WhileLoop):
986        fixed = filter_fixed(fixed, code[0].stmts)
987        fixed1 = copy.deepcopy(fixed)
988        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
989        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
990
991#################################################################################################################
992## Dead Code Elimination ---- Pass 9
993## *** Changes required here is the same as Pass 7
994#################################################################################################################
995def get_effective_name(varname):
996    dot = varname.find('.')
997    bracket = varname.find('[')
998    if dot > 0: return varname[:dot]
999    if bracket > 0: return varname[:bracket]
1000    return varname
1001   
1002   
1003def check_loc(loc, must_liv):
1004    effective_name = get_effective_name(loc.LHS.varname)
1005    if effective_name in must_liv:
1006        if isinstance(loc.RHS, bitexpr.Not):
1007            return set([loc.RHS.operand1.varname]), False
1008        elif isinstance(loc.RHS, bitexpr.Var):
1009            return set([loc.RHS.varname]), False
1010        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1011            return set([]), False
1012        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1013            return set([]), False
1014        else:
1015            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1016    else:
1017        return set([]), True
1018
1019def remove_copies(bb):
1020    """removes all copy statements e.g. var1 = var2"""
1021    lhs = [x.LHS.varname for x in bb]
1022    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
1023    for i in rhs:
1024        if i[1] in lhs:
1025            line = lhs.index(i[1])
1026            if i[0] > line:
1027                bb[i[0]].RHS = bb[line].RHS
1028    return bb
1029
1030def remove_dead(bb, must_live):
1031    #eliminates dead code from a basic block
1032    bb = remove_copies(bb)
1033    my_alives = set([])
1034    dead = []
1035
1036    for line, loc in reversed(list(enumerate(bb))):
1037        #print line, my_alives.union(must_live)
1038        new_lives, removable = check_loc(loc, my_alives.union(must_live))
1039
1040        if removable:
1041            dead.append(line)
1042        else:
1043            my_alives = my_alives.union(new_lives)
1044
1045    for i in dead:
1046        del bb[i]
1047
1048    return my_alives, bb
1049
1050def eliminate_dead_code(tree, must_live):
1051
1052    if len(tree) == 0:
1053        return [], []
1054
1055    last = len(tree) - 1
1056    new_alives = set([])
1057    bb = []
1058
1059    for loc in tree:
1060        if isinstance(loc, bitexpr.WhileLoop):
1061            if not loc.carry_expr is None:
1062                must_live.add(loc.carry_expr.var.varname)
1063
1064    if isinstance(tree[-1], bitexpr.BitAssign):
1065        first = 0
1066        for i in reversed(range(len(tree))):
1067            if not isinstance(tree[i], bitexpr.BitAssign):
1068                first = i+1
1069                break
1070        new_alives, bb = remove_dead(tree[first:], must_live)
1071        last = first
1072
1073    elif isinstance(tree[-1], bitexpr.If):
1074        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1075
1076        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1077        bb = [tree[-1]]
1078        new_alives.add(tree[-1].control_expr.var.varname)
1079
1080    elif isinstance(tree[-1], bitexpr.WhileLoop):
1081        must_live.add(tree[-1].control_expr.var.varname)
1082        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1083        bb = [tree[-1]]
1084
1085    all_alives = new_alives.union(must_live)
1086    all_lives, new_tree = eliminate_dead_code(tree[:last], all_alives)
1087    tree = new_tree+bb
1088    return all_alives, tree
1089
1090#################################################################################################################
1091## This pass processes the code in the while loop and adds extra code required for handling carry variables
1092#################################################################################################################
1093carry_suffix = "_i"
1094
1095def fix_the_loop(loop):
1096    carries = []
1097    for loc in loop.stmts:
1098        if isinstance(loc.RHS, bitexpr.Add):
1099            carries.append(loc.RHS.carry)
1100    for item in carries:
1101        newvar = bitexpr.Var(item+carry_suffix)
1102        loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), "int")))
1103    for item in carries:
1104        loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral("int") ))
1105
1106    return loop, carries
1107
1108def process_while_loops(code):
1109    all = []
1110    update = {}
1111    for index, loc in enumerate(code):
1112        if isinstance(loc, bitexpr.WhileLoop):
1113            update[index] = fix_the_loop(loc)
1114            for i in update[index][1]:
1115                all.append(i)
1116                all.append(i+carry_suffix)
1117
1118    keys = [k for k in update]
1119    keys.sort(reverse=True)
1120    for key in keys:
1121        code[key] = update[key][0]
1122        carry_variable = bitexpr.Var("CarryTemp"+str(key))
1123        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral("int")))
1124        for item in update[key][1][2:]:
1125            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), "int")))
1126        if len(update[key][1]) == 1:
1127            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1128        elif len(update[key][1]) > 1:
1129            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), "int")))
1130        for item in update[key][1]:
1131            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral("int")))
1132        for item in update[key][1]:
1133            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
1134
1135        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1136    return code, all
1137
1138
1139
1140
1141
1142
1143#################################################################################################################
1144## This pass factors out the code that is common between the true branch and false branch of an if
1145## statement.
1146#################################################################################################################
1147
1148def are_the_same(one, two):
1149    if one.__class__ != two.__class__:
1150        return False
1151
1152    if isinstance(one, bitexpr.BitAssign):
1153        if (one.LHS.varname != two.LHS.varname):
1154            return False
1155        if one.RHS.__class__ != two.RHS.__class__:
1156            return False
1157        if isinstance(one.RHS, bitexpr.FalseLiteral):
1158            return True
1159        if isinstance(one.RHS, bitexpr.TrueLiteral):
1160            return True
1161        if isinstance(one.RHS, bitexpr.Var):
1162            return (one.RHS.varname == two.RHS.varname)
1163
1164        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1165
1166    if isinstance(one, bitexpr.If):
1167        if (one.control_expr.__class__ != two.control_expr.__class__):
1168            return False
1169        if (one.control_expr.var.varname != two.control_expr.var.varname):
1170            return False
1171        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1172            return False
1173        for ind in range(len(one.true_branch)):
1174            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1175                return False
1176        for ind in range(len(one.false_branch)):
1177            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1178                return False
1179        return True
1180
1181    if isinstance(one, bitexpr.WhileLoop):
1182        if (one.control_expr.__class__ != two.control_expr.__class__):
1183            return False
1184        if (one.control_expr.var.varname != two.control_expr.var.varname):
1185            return False
1186        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1187            return False
1188        if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1189            return False
1190        if (len(one.stmts) != len(two.stmts)):
1191            return False
1192        for ind in range(len(one.stmts)):
1193            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1194                return False
1195        return True
1196
1197def get_factorable(cond):
1198    earliest = 0
1199    #print cond.control_expr, cond.control_expr.var.varname
1200    l = min(len(cond.true_branch), len(cond.false_branch))
1201    for index in reversed(range(-l,0)):
1202        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1203            earliest = index
1204        else:
1205            break
1206
1207    common_length = -earliest
1208    return common_length
1209
1210def get_common_code(cond, common):
1211    true_len = len(cond.true_branch)-common
1212    false_len = len(cond.false_branch)-common
1213
1214    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1215    common = cond.true_branch[true_len:]
1216
1217    return [new_cond], common
1218
1219def do_factorization(code, pos):
1220    indices = [key for key in pos]
1221    indices.sort()
1222
1223    for index in reversed(indices):
1224        new_cond, common = get_common_code(code[index], pos[index])
1225        code = code[:index]+new_cond+common+code[index+1:]
1226    return code
1227
1228def factor_out(code):
1229    for index, loc in enumerate(code):
1230        if isinstance(loc, bitexpr.If):
1231            code[index].true_branch = factor_out(code[index].true_branch)
1232            code[index].false_branch = factor_out(code[index].false_branch)
1233        if isinstance(loc, bitexpr.WhileLoop):
1234            code[index].stmts = factor_out(code[index].stmts)
1235
1236    pos = {}
1237    for index, loc in enumerate(code):
1238        if isinstance(loc, bitexpr.If):
1239            common_length = get_factorable(code[index])
1240            pos[index] = common_length
1241
1242    return do_factorization(code, pos)
1243
1244#################################################################################################################
1245## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
1246## ***Changes needed here are the same as Pass 7 and Pass 10***
1247#################################################################################################################
1248
1249def generate_condition(expr):
1250    if isinstance(expr, bitexpr.isAllZero):
1251        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1252    elif isinstance(expr, bitexpr.isAllOne):
1253        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1254    elif isinstance(expr, bitexpr.isNoneZero):
1255        return "bitblock_has_bit(%s)"%(expr.var.varname)
1256    else:
1257        print expr
1258        assert (1==0)
1259
1260def generate_statement(expr, indent, cond_stmt):
1261    if isinstance(expr, bitexpr.isAllZero):
1262        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1263    elif isinstance(expr, bitexpr.isAllOne):
1264        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1265    elif isinstance(expr, bitexpr.isNoneZero):
1266        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1267    else:
1268        print expr
1269        assert (1==0)
1270
1271def print_prog(s, indent = 0):
1272    indent_unit = 4
1273    code = ""
1274    if len(s) == 0:
1275        return ""
1276    if isinstance(s[0], bitexpr.If):
1277        code = generate_statement(s[0].control_expr, indent, "if")
1278        code += print_prog(s[0].true_branch, indent+indent_unit)
1279        code += " "*indent+"}\n"
1280        code += " "*indent+"else {\n"
1281        code += print_prog(s[0].false_branch, indent+indent_unit)
1282        code += " "*indent+"}\n"
1283
1284    if isinstance(s[0], bitexpr.WhileLoop):
1285        #code = generate_statement(s[0].control_expr, indent, "while")
1286        code = "\n%s(%s|%s) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname+">0")
1287        code += print_prog(s[0].stmts, indent+indent_unit)
1288        code += " "*indent+"}\n"
1289
1290    if isinstance(s[0], bitexpr.BitAssign):
1291        code += " "*indent + s[0].LHS.varname
1292        code += " = "
1293
1294        if s[0].RHS.data_type == "vector":
1295            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1296                code += s[0].RHS.operand1.varname
1297                code += ";\n"
1298            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1299                code += s[0].RHS.operand1.varname
1300                code += ";\n"
1301            elif isinstance(s[0].RHS, bitexpr.Var):
1302                code += s[0].RHS.operand1.varname
1303                code += ";\n"
1304            elif isinstance(s[0].RHS, bitexpr.Add):
1305                code += s[0].RHS.op_C + "("
1306                code += s[0].RHS.operand1.varname
1307                code += ','
1308                code += s[0].RHS.operand2.varname
1309                code += ','
1310                code += s[0].RHS.carry
1311                code += ");\n"
1312
1313            else:
1314                code += s[0].RHS.op_C + "("
1315                code += s[0].RHS.operand1.varname
1316                code += ','
1317                code += s[0].RHS.operand2.varname
1318                code += ");\n"
1319        elif s[0].RHS.data_type == "int":
1320            if isinstance(s[0].RHS, bitexpr.Or):
1321                code += s[0].RHS.operand1.varname
1322                code += '|'
1323                code += s[0].RHS.operand2.varname
1324                code += ";\n"
1325            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1326                code += "0;\n"
1327
1328    s.pop(0)
1329    return code+print_prog(s, indent)
1330
1331#################################################################################################################
1332## Auxiliary Functions
1333#################################################################################################################
1334
1335def get_bb_of(code, index):
1336    if index < 0 or index >= len(code):
1337        return None
1338   
1339    if isinstance(code[index], bitexpr.If):
1340        return code[index]
1341    if isinstance(code[index], bitexpr.WhileLoop):
1342        return code[index]
1343    if isinstance(code[index], bitexpr.BitAssign):
1344        last = index
1345        for loc in code[index+1:]:
1346            if isinstance(loc, bitexpr.BitAssign):
1347                last += 1
1348            else:
1349                break
1350 
1351        first = index
1352        for loc in reversed(code[:index]):
1353            if isinstance(loc, bitexpr.BitAssign):
1354                first -= 1
1355            else:
1356                break
1357
1358        return code[first:last+1]
1359
1360def get_previous_bb(code, index):
1361    if index < 0:
1362        return None
1363    if index >= len(code):
1364        return get_bb_of(code, len(code)-1)
1365    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1366        return get_bb_of(code, index-1)
1367    if isinstance(code[index], bitexpr.BitAssign):
1368        for ind, loc in reversed(enumerate(code[:index])):
1369            if not isinstance(loc, bitexpr.BitAssign):
1370                return get_bb_of(code, ind)
1371        return get_bb_of(code, -1)
1372    assert(1==0)
1373
1374def recurse_forward(code):
1375    if isinstance(code[0], bitexpr.If):
1376        return code[0], code[1:]
1377    if isinstance(code[0], bitexpr.WhileLoop):
1378        return code[0], code[1:]
1379    if isinstance(code[0], bitexpr.BitAssign):
1380        last = 0
1381        for loc in code[1:]:
1382            if isinstance(loc, bitexpr.BitAssign):
1383                last += 1
1384            else:
1385                break
1386        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.