source: proto/Compiler/py2bitexpr.py @ 555

Last change on this file since 555 was 555, checked in by eamiri, 9 years ago

a bug in dead code elimination fixed

File size: 59.8 KB
Line 
1# -*- coding: utf-8 -*-
2import ast, bitexpr, copy, basic_block
3
4type_dict = {}
5structs = {}
6arrays = {}
7
8AllOne = 'AllOne'
9AllZero = 'AllZero'
10
11STRUCT_TEMPLATE="strct_%s__%s_"
12ARRAY_TEMPLATE="array_%s__%s_"
13
14class PyBitError(Exception):
15        pass
16
17#############################################################################################
18## Extracting list of live variables
19#############################################################################################
20
21def get_lives(s):
22    for item in s.body:
23        if isinstance(item, ast.FunctionDef):
24            if (item.name =="main"):
25                assert (isinstance(item.body[-1], ast.Return))
26                ret = item.body[-1].value
27                if isinstance(ret, ast.Tuple):
28                    del item.body[-1]
29                    return [translate_var(x) for x in ret.elts],s
30                else:
31                    assert(isinstance(ret, ast.Name))
32                    del item.body[-1]
33                    return [translate_var(ret)],s
34
35#############################################################################################
36## Function Inlining
37#############################################################################################
38
39TEMP_VAR_TEMPLATE = "InlineTemp%i"
40
41def get_var_name(var):
42    var_name = translate_var(var)
43    if isinstance(var, ast.Name):
44        type_name = "simple"
45        extra = None
46    if isinstance(var, ast.Attribute):
47        type_name = "structure"
48        extra = var.attr
49    if isinstance(var, ast.Index):
50        type_name = "array"
51        extra = translate_index(v.slice.value)
52
53    return type_name, var_name, extra
54
55def register_var(final_name, original_name, type_name, extra):
56    pass
57
58def collect_functions(module, file_path):
59    imps = []
60    to_remove = []
61    for index, item in enumerate(module.body):
62        if isinstance(item, ast.Import):
63            imps.append(item)
64            to_remove.append(index)
65
66    for n in reversed(to_remove):
67        del module.body[n]
68
69    for imp in imps:
70        for mod in imp.names:
71            mod_file = open(file_path + mod.name + '.py')
72            mod_code = ""
73            for line in mod_file:
74                mod_code += line
75            mod_file.close()
76            mod_parsed = ast.parse(mod_code)
77            modified = []
78            for item in mod_parsed.body:
79                if isinstance(item, ast.FunctionDef):
80                    item.name = translate_var(ast.Attribute(value=ast.Name(id=mod.name), attr=item.name))
81            module.body += mod_parsed.body
82    return module
83
84
85def remove_unsupported_code(callee, fn_names, predefined, main=False):
86    to_remove = []
87    for index, loc in enumerate(callee.body):
88        if isinstance(loc, ast.While) or isinstance(loc, ast.If):
89            if not main:
90                to_remove.append(index)
91            continue
92
93        if isinstance(loc, ast.Expr):
94            if isinstance(loc.value, ast.Call):
95                if main and not (translate_var(loc.value.func) in fn_names+predefined):
96                    to_remove.append(index)
97                if not main and not translate_var(loc.value.func) in predefined:
98                    to_remove.append(index)
99
100            else:
101                to_remove.append(index)
102
103            continue
104
105        if not (isinstance(loc, ast.Assign) or isinstance(loc, ast.AugAssign) or isinstance(loc, ast.Return)):
106            to_remove.append(index)
107
108        #we do not support AugAssign with function calls
109        if isinstance(loc, ast.AugAssign):
110            if isinstance(loc.value, ast.Call):
111                assert (1==0)
112                to_remove.append(index)
113
114        #we do not support function calls in functions other than main
115        #these are OK if they are object construction, otherwise error.
116        if isinstance(loc, ast.Assign):
117            if isinstance(loc.value, ast.Call):
118                if main and not (translate_var(loc.value.func) in fn_names+predefined):
119                    to_remove.append(index)
120                if not main and not translate_var(loc.value.func) in predefined:
121                    to_remove.append(index)
122    to_remove.sort(reverse=True)
123    for line_no in to_remove:
124        del callee.body[line_no]
125
126    return callee
127
128
129def clean_functions(module, fn_names, predefined):
130    new = []
131    for item in module.body:
132        if isinstance(item, ast.FunctionDef):
133            new.append(remove_unsupported_code(item, fn_names, predefined, item.name == "main")) 
134        else:
135            new.append(item)
136    module.body = new
137    return module
138
139def collect_names(body):
140    names = {}
141    for index, item in enumerate(body):
142        if isinstance(item, ast.FunctionDef):
143            names[item.name] = index
144
145    return names
146
147def modify_name(var, args, cargs, unique_prefix):
148    carg_names = [x.id for x in cargs]
149    arg_names = [translate_var(x) for x in args]
150
151    varname = translate_var(var)
152    if isinstance(var, ast.Name):
153        if varname in carg_names:
154            return ast.Name(id=arg_names[carg_names.index(varname)])
155        return ast.Name(id=unique_prefix+varname)
156    if isinstance(var, ast.Attribute) or isinstance(var, ast.Subscript):
157        if var.value.id in carg_names:
158            var.value.id = arg_names[carg_names.index(var.value.id)]
159            return ast.Name(id=translate_var(var))
160        return ast.Name(id=unique_prefix+translate_var(var))
161
162def replace_in_exp(exp, args, cargs, unique_prefix):
163    if isinstance(exp, ast.BinOp):
164        exp.left = replace_in_exp(exp.left, args, cargs, unique_prefix)
165        exp.right = replace_in_exp(exp.right, args, cargs, unique_prefix)
166    elif isinstance(exp, ast.UnaryOp):
167        exp.operand = replace_in_exp(exp.operand, args, cargs, unique_prefix)
168    elif isinstance(exp, ast.Name):
169        exp = modify_name(exp, args, cargs, unique_prefix)
170    elif isinstance(exp, ast.Attribute):
171        exp = modify_name(exp, args, cargs, unique_prefix)
172    elif isinstance(exp, ast.Subscript):
173        exp = modify_name(exp, args, cargs, unique_prefix)
174    elif isinstance(exp, ast.Call):
175        for index, item in enumerate(exp.args):
176            exp.args[index] = replace_in_exp(item, args, cargs, unique_prefix)
177    elif isinstance(exp, ast.Tuple):
178        for index, elt in enumerate(exp.elts):
179            exp.elts[index] = replace_in_exp(elt, args, cargs, unique_prefix)
180    elif isinstance(exp, ast.Num):
181        pass
182    else:
183        assert(1==0)
184    return exp
185
186def get_all_calls(main, all_func):
187    call_list = []
188    for index, loc in enumerate(main.body):
189        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
190                func_name = translate_var(loc.value.func)
191                if func_name in all_func:
192                    call_list.append(index)
193        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
194                func_name = translate_var(loc.value.func)
195                if func_name in all_func:
196                    call_list.append(index)
197    return call_list
198
199def update_var_names(callee, args):
200    unique_prefix = '_'+callee.name.replace('.', '_')+'_'
201
202    assert (len(args)==len(callee.args.args))
203
204    for index, loc in enumerate(callee.body):
205        if isinstance(loc, ast.Assign):
206            for index, var in enumerate(loc.targets):
207                loc.targets[index] = modify_name(var, args, callee.args.args, unique_prefix)
208            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
209        elif isinstance(loc, ast.AugAssign):
210            loc.target = modify_name(loc.target, args, callee.args.args, unique_prefix)
211            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
212        elif isinstance(loc, ast.Return):
213            pass
214            #loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
215        else:
216            assert (1==0)
217
218    return callee, unique_prefix
219
220
221def get_name(loc, index):
222    if isinstance(loc.value, ast.Tuple):
223        return translate_var(loc.value.elts[index])
224    elif index == 0:
225        return translate_var(loc.value)
226    else:
227        assert(1==0)
228
229def modify_assignments_in_caller(main, line_no, return_list, prefix, callee):
230    if isinstance(main.body[line_no].targets[0], ast.Tuple):
231        all_vars = main.body[line_no].targets[0].elts
232    else:
233        all_vars = main.body[line_no].targets
234    new_code = []
235
236    if isinstance(callee.body[-1].value, ast.Tuple):
237        return_vars = [translate_var(x) for x in callee.body[-1].value.elts]
238    else:
239        return_vars = [translate_var(callee.body[-1].value)]
240
241    for index, item in enumerate(all_vars):
242        actual_name = translate_var(item)
243        if actual_name in [translate_var(x) for x in main.body[line_no].value.args]:
244            continue
245        temp_name = return_vars[index]
246
247        if temp_name in [translate_var(x) for x in callee.args.args]:
248            second_index = [translate_var(x) for x in callee.args.args].index(temp_name)
249            final_name = translate_var(main.body[line_no].value.args[second_index])
250        else:
251            final_name = prefix+temp_name
252
253        if return_list[temp_name][0] == "struct":
254            my_template = STRUCT_TEMPLATE
255        elif return_list[temp_name][0] == "array":
256            my_template = ARRAY_TEMPLATE
257 
258        if return_list[temp_name][0] == "bitblock":
259            new_code.append(ast.Assign( [ast.Name(id=item.id)] , ast.Name(id=final_name) ))
260        else:
261            for elem in return_list[temp_name][1]:
262                new_code.append(ast.Assign([ast.Name(id=my_template%(item.id, elem))], ast.Name(id=prefix+my_template%(temp_name, elem))))
263    main.body = main.body[:line_no]+new_code+main.body[line_no+1:]
264    return main
265
266def get_return_list(callee):
267    ret_list = []
268    loc = callee.body[-1]
269    if isinstance(loc, ast.Return):
270        if isinstance(loc.value, ast.Tuple):
271            all_rets = loc.value.elts
272        else:
273            all_rets = [loc.value]
274        for item in all_rets:
275            if isinstance(item, ast.Attribute) or isinstance(item, ast.Subscript):
276                ret_list.append(item.value.id)
277            if isinstance(item, ast.Name):
278                ret_list.append(item.id)
279    return ret_list
280
281def get_updated(loc):
282    if isinstance(loc, ast.AugAssign):
283        updated = [loc.target]
284    elif isinstance(loc, ast.Assign):
285        if isinstance(loc.targets[0], ast.Tuple):
286            updated = loc.targets[0].elts
287        else:
288            updated = loc.targets
289    else:
290        updated = []
291    return updated
292
293def get_update_details(callee, ret_list):
294    field_info = {}
295    for item in ret_list:
296        field_info[item] = ['simple', set([])]
297    for loc in callee.body:
298        updated = get_updated(loc)
299        for item in updated:
300            type, name, extra = parse_var(translate_var(item, True))
301            if name in field_info:
302                field_info[name][0] = type
303                if type=='struct':
304                    field_info[name][1].add(extra)
305                if type == 'array':
306                    field_info[name][1].add(extra)
307    return field_info
308
309def process_return(callee):
310    if not isinstance(callee.body[-1], ast.Return):
311        return {}
312    ret_list = get_return_list(callee)
313    field_info = get_update_details(callee, ret_list)
314    return field_info
315
316def generate_expanded_code(loc):
317    new_code = []
318    array_name = translate_var(loc.targets[0])
319    for ind, i in enumerate(range(len(loc.value.elts))):
320        new_loc = ast.Assign([ast.Subscript(value=ast.Name(id=array_name), slice=ast.Index(ast.Num(n=ind)) )], loc.value.elts[i])
321        new_code.append(new_loc)
322
323    return new_code
324
325def expand_list_init(module):
326    for item in module.body:
327        replace_code = {}
328        if isinstance(item, ast.FunctionDef):
329            for index,loc in enumerate(item.body):
330                if isinstance(loc, ast.Assign):
331                    if isinstance(loc.value, ast.List):
332                        equivalent_code = generate_expanded_code(loc)
333                        replace_code[index] = equivalent_code
334            keys = [key for key in replace_code]
335            keys.sort(reverse = True)
336            for line_no in keys:
337                item.body = item.body[:line_no]+replace_code[line_no]+item.body[line_no+1:]
338
339    return module
340
341def get_predefined_funcs():
342    advance = ast.Attribute(value=ast.Name(id="bitutil"), attr="Advance")
343    scan_thru = ast.Attribute(value=ast.Name(id="bitutil"), attr="ScanThru")
344    optimize = ast.Attribute(value=ast.Name(id="bitutil"), attr="Optimize")
345    return [translate_var(advance), translate_var(scan_thru), translate_var(optimize)]
346
347def do_inlining(module, file_path):
348    module = collect_functions(module, file_path)
349    func_dict = collect_names(module.body)
350    module = expand_list_init(module)
351
352    valid_fn = [name for name in func_dict] 
353    predef_fn = get_predefined_funcs()
354    module = clean_functions(module, valid_fn, predef_fn)
355
356    main = module.body[func_dict["main"]]
357    call_list = get_all_calls(main, [func for func in func_dict])
358    call_list.sort(reverse=True)
359    for line_no in call_list:
360        callee = copy.deepcopy(module.body[func_dict[translate_var(main.body[line_no].value.func)]])
361        args = main.body[line_no].value.args
362        return_list = process_return(callee)
363        callee, prefix = update_var_names(callee, args)
364
365        main = modify_assignments_in_caller(main, line_no, return_list, prefix, callee)
366
367        if isinstance(callee.body[-1], ast.Return):
368            del callee.body[-1]
369        main.body = main.body[:line_no]+callee.body+main.body[line_no:]
370       
371        #if callee.name == 'strct_byteclass__classify_bytes_':
372        #  for item in module.body[func_dict["main"]].body:
373        #      if isinstance(item, ast.Assign):
374        #       if isinstance(item.targets[0], ast.Name):
375        #           print item.targets[0].id
376         
377    #for item in module.body[func_dict["main"]].body:
378    #   if isinstance(item, ast.Assign):
379    #      if isinstance(item.targets[0], ast.Name):
380    #         print item.targets[0].id
381   
382    return module.body[func_dict["main"]].body
383
384#############################################################################################
385## Translation to bitExpr classes --- Pass 2
386#############################################################################################
387
388def translate_index(ast_value):
389        if isinstance(ast_value, ast.Str): return ast_value.s
390        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
391        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
392
393def is_Advance_Call(fncall):
394        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
395        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
396                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
397        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
398
399def is_ScanThru_Call(fncall):
400        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
401        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
402                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
403        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
404
405def translate(ast_expr):
406        if isinstance(ast_expr, ast.Name):
407                if ast_expr.id=="allzero":
408                        return bitexpr.FalseLiteral()
409                if ast_expr.id=="allone":
410                        return bitexpr.TrueLiteral()
411                return bitexpr.Var(translate_var(ast_expr))
412        elif isinstance(ast_expr, ast.Num):
413                if ast_expr.n == 0:
414                    return bitexpr.FalseLiteral()
415                else:
416                    return bitexpr.Const(ast_expr.n)
417        elif isinstance(ast_expr, ast.Attribute):
418                if isinstance(ast_expr.value, ast.Name):
419                        return bitexpr.Var(translate_var(ast_expr))
420                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
421                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
422        elif isinstance(ast_expr, ast.Subscript):
423                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice, ast.Index):
424                        return bitexpr.Var(translate_var(ast_expr))
425                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
426                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
427        elif isinstance(ast_expr, ast.UnaryOp):
428                e1 = translate(ast_expr.operand)
429                if isinstance(ast_expr.op, ast.Invert):
430                        return bitexpr.make_not(e1)
431                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
432        elif isinstance(ast_expr, ast.BinOp):
433                e1 = translate(ast_expr.left)
434                e2 = translate(ast_expr.right)
435                if isinstance(ast_expr.op, ast.BitOr):
436                        return bitexpr.make_or(e1, e2)
437                elif isinstance(ast_expr.op, ast.BitAnd):
438                        return bitexpr.make_and(e1, e2)
439                elif isinstance(ast_expr.op, ast.BitXor):
440                        return bitexpr.make_xor(e1, e2)
441                elif isinstance(ast_expr.op, ast.Add):
442                        return bitexpr.make_add(e1, e2)
443                elif isinstance(ast_expr.op, ast.Sub):
444                        return bitexpr.make_sub(e1, e2)
445                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
446        elif isinstance(ast_expr, ast.Call):
447                if is_Advance_Call(ast_expr):
448                        e = translate(ast_expr.args[0])
449                        temp = bitexpr.make_add(e,e)
450                        return temp
451                elif is_ScanThru_Call(ast_expr):
452                        e0 = translate(ast_expr.args[0])
453                        e1 = translate(ast_expr.args[1])
454                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
455                else:
456                    raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
457        elif isinstance(ast_expr, ast.Compare):
458                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
459                        e0 = translate(ast_expr.left)
460                        return bitexpr.isNoneZero(e0)
461                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
462               
463        else:
464            raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
465
466def translate_var(v, aggregate_type = False):
467        if isinstance(v, ast.Name):
468            agg_name = v.id
469            var_name = v.id
470        elif isinstance(v, ast.Attribute):
471           if isinstance(v.value, ast.Name):
472               var_name = STRUCT_TEMPLATE%(v.value.id, v.attr)
473               agg_name = "%s.%s" % (v.value.id, v.attr)
474           else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
475        elif isinstance(v, ast.Subscript):
476           if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
477               var_name = ARRAY_TEMPLATE%(v.value.id, translate_index(v.slice.value))
478               agg_name = "%s[%s]" % (v.value.id, translate_index(v.slice.value))
479               
480        else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
481
482        if aggregate_type:
483            return agg_name
484        else:
485            return var_name
486
487def translate_stmts(ast_stmts):
488        translated = []
489        for s in ast_stmts:
490                if isinstance(s, ast.Expr):
491                        if (translate_var(s.value.func)=='strct_bitutil__Optimize_'):
492                            target = translate_var(s.value.args[0])
493                            replace = translate_var(s.value.args[1])
494                            translated.append(bitexpr.Reduce(target, replace))
495                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
496                elif isinstance(s, ast.Assign):
497                        e = translate(s.value)
498                        for astv in s.targets: 
499                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
500                elif isinstance(s, ast.AugAssign):
501                        v = bitexpr.Var(translate_var(s.target))
502                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
503                elif isinstance(s, ast.While):
504                        e = translate(s.test)
505                        body = translate_stmts(s.body)
506                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body))
507                elif isinstance(s, ast.If):
508                        e = translate(s.test)
509                        body = translate_stmts(s.body)
510                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body, fake = True ))
511                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
512        return translated
513#############################################################################################
514## Conversion to SSA form --- Pass 3
515#############################################################################################
516
517def extract_vars(rhs):
518    if isinstance(rhs, bitexpr.Var):
519        return [rhs.varname]
520    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral) or isinstance(rhs, bitexpr.Const):
521        return []
522    if isinstance(rhs, bitexpr.Not):
523        return extract_vars(rhs.operand1)
524    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
525    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
526
527def update_lineno(inner_table, shift):
528    for key in inner_table:
529        for index in range(len(inner_table[key][0])):
530            inner_table[key][0][index] += shift
531        for index in range(len(inner_table[key][1])):
532            inner_table[key][1][index] += shift
533    return inner_table
534
535def merge_tables(table, inner_table):
536    for key in inner_table:
537        if key in table:
538            table[key][0] += inner_table[key][0]
539            table[key][1] += inner_table[key][1]
540        else:
541            table[key] = inner_table[key]
542    return table
543
544def gen_sym_table(code, goInside = False):
545    """Generate a simple symbol table for a three address code
546        each entry is of this form: var_name:[[defs][uses]]
547    """
548    table = {}
549    index = 0
550    for stmt in code:
551        if isinstance(stmt, bitexpr.Reduce):
552            if stmt.target in table:
553                table[stmt.target][1].append(index)
554            else:
555                table[stmt.target] = [[], [index]]
556            index += 1
557        elif isinstance(stmt, bitexpr.BitAssign):
558            current = stmt.LHS.varname
559            if current in table:
560                table[current][0].append(index)
561            else:
562                table[current] = [[index],[]]
563
564            varlist = extract_vars(stmt.RHS)
565            for var in varlist:
566                if var in table:
567                    table[var][1].append(index)
568                else:
569                    table[var] = [[], [index]]
570            index += 1
571        elif isinstance(stmt, bitexpr.WhileLoop):
572            cond_var = stmt.control_expr.var.varname
573            if cond_var in table:
574                table[cond_var][1].append(index)
575            else:
576                #while loop conditioned on an undefined variable
577                assert(1==0)
578            if goInside:
579                inner_table = gen_sym_table(stmt.stmts, True)
580                inner_table = update_lineno(inner_table, index+1)
581                table = merge_tables(table, inner_table)
582                index += (1+len(stmt.stmts))
583            else:
584                index += 1
585        else:
586            assert(1==0)
587    return table
588
589##################################################################################
590def get_line(code, line):
591    lineno = 0
592    for stmt in code:
593        if isinstance(stmt, bitexpr.BitAssign):
594            if lineno == line:
595                return stmt
596            lineno += 1
597        elif isinstance(stmt, bitexpr.WhileLoop):
598            if lineno == line:
599                return stmt
600            elif lineno+len(stmt.stmts) < line:
601                lineno += 1
602                line -= len(stmt.stmts)
603                continue
604            else:
605                return get_line(stmt.stmts, line-(lineno+1))
606
607        elif isinstance(stmt, bitexpr.Reduce):
608            if lineno == line:
609                return stmt
610            lineno += 1
611        else:
612            assert(1==0)
613
614def update_def(code, var, line, suffix_num):
615    loc = code[line]
616    if isinstance(loc, bitexpr.BitAssign):
617        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
618        return loc
619    else:
620        #either Reduce or While Loop in both cases it's a use
621        pass
622
623def update_rhs(rhs, varname, newname):
624    if isinstance(rhs, bitexpr.Var):
625        if rhs.varname == varname:
626            rhs.varname = newname
627        return rhs
628
629    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
630        return rhs
631
632    if isinstance(rhs, bitexpr.Not):
633        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
634        return rhs
635    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
636    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
637    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
638    return rhs
639
640def update_use(code, var, line, suffix_num):
641    loc = code[line]
642    if isinstance(loc, bitexpr.BitAssign):
643        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
644    elif isinstance(loc, bitexpr.Reduce):
645        pass
646    elif isinstance(loc, bitexpr.WhileLoop):
647        loc.control_expr.var.varname += "_%i"%suffix_num
648    return loc
649
650def parse_var(var):
651    index=var.find('.')
652    if index >= 0:
653        return ('struct', var[0:index], var[index+1:])
654               
655    index = var.find('[')
656    if index >= 0:
657        right_index = var.find(']')
658        return ('array', var[0:index], var[index+1:right_index])       
659   
660    if var.startswith("carry") or var.startswith("Carry"):
661        return('int', var, None)
662   
663    return ('bitblock', var, None)
664
665def simplify_name(var):
666    (vartype, name, extra) = parse_var(var)
667    if vartype == 'bitblock':
668        return name
669    if vartype == 'array':
670        return "a_%s_%s"%(name, extra)
671    if vartype == 'struct':
672        return "s_%s_%s"%(name, extra)
673    if vartype == 'int':
674        assert(1==0)
675    assert(1==0)
676
677def pairs(lst):
678    if lst == []:
679            return []
680    return zip(lst,lst[1:]+[lst[0]])
681
682
683def make_SSA(code, st):
684    new_vars = []
685    total_lines = len(code)
686    for var in st:
687        st[var][0].append(total_lines)
688        st[var][1].append(total_lines)
689
690    unique_number = 0
691    for var in st:
692        use_index = 0
693        for current, next in pairs(st[var][0])[0:-2]:
694            code[current] = update_def(code, var, current, unique_number)
695            uline = st[var][1][use_index]
696
697            while uline <= next and uline < total_lines:
698                if uline > current:
699                    code[uline] = update_use(code, var, uline, unique_number)
700                use_index += 1
701                uline = st[var][1][use_index]
702            unique_number += 1
703    return code
704
705#################################################################################################################
706## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
707#################################################################################################################
708def get_opt_list(s):
709    opt_list = []
710    remove_index = []
711    for line, stmt in enumerate(s):
712        if isinstance(stmt, bitexpr.Reduce):
713            opt_pair = (stmt.target, stmt.replace)
714            opt_list.append(opt_pair)
715            remove_index.append(line)
716    for ind in reversed(remove_index):
717        del s[ind]
718
719    return opt_list
720
721def break_to_segs(s, breaks):
722    segs = []
723    breaks = [-1]+breaks+[len(s)]
724    for start, stop in pairs(breaks)[:-1]:
725        next = s[start+1:stop]
726        if len(next)>0:
727            segs.append(next)
728    return segs
729
730def extract_pragmas(s):
731    targets = []
732    exprs = []
733    lines = []
734
735    #extracting pragmas and their targets
736    for line, loc in enumerate(s):
737        if isinstance(loc, bitexpr.Reduce):
738            targets.append(loc.target)
739            exprs.append(loc)
740            lines.append(line)
741
742    #removing pragmas from the source code
743    for i in reversed(lines):
744        del s[i]
745
746    return targets, exprs
747
748def get_defs(targets, s):
749    temp = [(x,-1) for x in targets]
750    targ_dic = dict(temp)
751    for line, loc in enumerate(s):
752        if loc.LHS.varname in targets:
753            targ_dic[loc.LHS.varname] = line
754
755    items = targ_dic.items()
756    items = sorted(items, key=(lambda x: x[1]))
757
758    lineno = [key for value, key in items]
759    sorted_targets = [value for value, key in items]
760
761    return sorted_targets, lineno
762
763def defines(varname, stmts):
764    is_defined = False
765    for loc in stmts:
766        if isinstance(loc, bitexpr.WhileLoop):
767            is_defined |= defines(varname, loc.stmts)
768        elif isinstance(loc, bitexpr.BitAssign): 
769            if loc.LHS.varname == varname:
770                is_defined = True
771
772    return is_defined
773
774def adjust_earliest(s, earliest, varname):
775    if earliest == -1:
776        return -1
777
778    line = earliest+1
779    for stmt in s[earliest+1:]:
780        if isinstance(stmt, bitexpr.WhileLoop):
781            if defines(varname, stmt.stmts):
782               earliest = line+1 
783        line+=1
784    return earliest
785
786def chop_code(s, start, stop):
787    line = 0
788
789    if start == -1:
790        cut1 = -1
791
792    if stop == len(s):
793        cut2 = len(s)
794
795    for index, stmt in enumerate(s):
796        if isinstance(stmt, bitexpr.BitAssign):
797            if line == start:
798                cut1 = index
799            if line == stop:
800                cut2 = index
801            line += 1
802
803        if isinstance(stmt, bitexpr.WhileLoop):
804            if line == start:
805                cut1 = index
806            if line  == stop:
807                cut2 = index
808            line+=1
809
810    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
811
812def count_lines(code):
813    cnt = len(code)
814    for stmt in code:
815        if isinstance(stmt, bitexpr.WhileLoop):
816            cnt += count_lines(stmt.stmts)
817    return cnt
818
819
820def gen_cond_obj(opt_pragma):
821    if opt_pragma[1] == 'AllZero' or opt_pragma[1] == 'allzero':
822        cond_obj = bitexpr.isAllZero(opt_pragma[0])
823    if opt_pragma[1] == 'AllOne' or opt_pragma[1] == 'allone':
824        cond_obj = bitexpr.isAllOne(opt_pragma[0])
825    return cond_obj
826
827def gen_bb(s, opt_list, def_line):
828    assert (len(opt_list) == len(def_line))
829    if opt_list==[]:
830        return s
831
832    earliest = min(def_line)
833    latest = len(s)
834    earliest_index = def_line.index(earliest)
835    early_varname = opt_list[earliest_index][0]
836   
837    earliest = adjust_earliest(s, earliest, early_varname)
838
839    first, second, third = chop_code(s, earliest, latest)
840
841    cond_obj = gen_cond_obj(opt_list[earliest_index])
842   
843    del opt_list[earliest_index]
844    del def_line[earliest_index]
845    def_line = [x-earliest-1 for x in def_line]
846   
847    second = gen_bb(second, opt_list, def_line)
848
849    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))] + third
850    return result
851
852def partition2bb(s):
853    basicblocks = []
854    lineno = []
855    def_line = []
856    opt_list = get_opt_list(s)
857
858    st2 = gen_sym_table(s)
859
860    for item in opt_list:
861        if len(st2[item[0]][0]) > 0:
862            def_line.append(st2[item[0]][0][-1])
863        else:
864            def_line.append(-1)
865
866
867    return gen_bb(s, opt_list, def_line)
868
869#################################################################################################################
870## Generating declarations for variables. This pass is based on the syntax of C programming language ---
871## Normalizing the code
872#################################################################################################################
873
874def normalize(s, predec = {}, ccelim=True):
875    if len(s)==0:
876        return []
877    if isinstance(s[0], bitexpr.If):
878        gc = basic_block.BasicBlock.gensym_counter
879        cc = basic_block.BasicBlock.carry_counter
880        bc = basic_block.BasicBlock.brw_counter
881        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
882        maxgc = basic_block.BasicBlock.gensym_counter
883        maxcc = basic_block.BasicBlock.carry_counter
884        maxbc = basic_block.BasicBlock.brw_counter
885        basic_block.BasicBlock.gensym_counter = gc
886        basic_block.BasicBlock.carry_counter = cc
887        basic_block.BasicBlock.brw_counter = bc
888        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
889        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
890        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
891        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
892       
893        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
894    if isinstance(s[0], bitexpr.WhileLoop):
895        orig = copy.deepcopy(predec)
896        s[0].stmts = normalize(s[0].stmts, predec, False)
897        return [s[0]]+normalize(s[1:], {})
898    if isinstance(s[0], bitexpr.BitAssign):
899        code, next = recurse_forward(s)
900        bb = basic_block.BasicBlock(predec)
901        predec.update(bb.normalize(code, ccelim))
902        return bb.get_code() + normalize(next, copy.deepcopy(predec))
903
904#################################################################################################################
905## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
906#################################################################################################################
907def get_vars(stmt):
908        ints = set([])
909        bitblocks = set(['AllOne', 'AllZero'])
910        arrays = {}
911        structs = {}
912
913        #for stmt in code:
914
915        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
916       
917
918
919        if isinstance(stmt.RHS, bitexpr.Add): 
920                ints.add(stmt.RHS.carry)
921        if isinstance(stmt.RHS, bitexpr.Sub):
922                ints.add(stmt.RHS.brw)
923
924        for var in all_vars:
925                if isinstance(var, bitexpr.Const):
926                    continue
927               
928                (var_type, name, extra) = parse_var(var.varname)
929
930                if (var_type == "bitblock"):
931                        bitblocks.add(name)
932
933                if var_type == "array":
934                        if not name in arrays:
935                                arrays[name]= extra
936                        else:
937                                arrays[name] = max(arrays[name], extra)
938
939                if var_type == "struct":
940                        if not name in structs:
941                                structs[name] = set([extra])
942                        else:
943                                structs[name].add(extra)
944                if var_type == "int":
945                    ints.add(name)
946                       
947        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
948
949def merge_var_dic(first, second):
950    res = {}
951    res['int'] = first['int'].union(second['int'])
952    res['bitblock'] = first['bitblock'].union(second['bitblock'])
953
954    res['array'] = first['array']
955    temp = dict({'array':copy.copy(second['array'])})
956    for i in res['array']:
957        if i in temp['array']:
958            res['array'][i] = max(res['array'][i], temp['array'][i])
959            del temp['array'][i]
960    res['array'].update(temp['array'])
961
962
963    res['struct'] = first['struct']
964    temp = dict({'struct':copy.copy(second['struct'])})
965    for i in res['struct']:
966        if i in temp['struct']:
967            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
968            del temp['struct'][i]
969    res['struct'].update(temp['struct'])
970    return res
971
972def gen_output(var_dic):
973        s = ''
974        for i in var_dic['bitblock']:
975                if i == AllOne:
976                        s += "BitBlock %s = simd_const_1(1);\n"%i
977                elif i == AllZero:
978                        s+="BitBlock %s = simd_const_1(0);\n"%i
979                elif i == "EOF_mask":
980                        s+="BitBlock EOF_mask = simd_const_1(1);\n"
981                else:
982                        s+="BitBlock %s;\n"%i
983
984        for i in var_dic['array']:
985                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
986
987        for i in var_dic['struct']:
988                s+="struct __%s__{\n"%i
989                for j in var_dic['struct'][i]:
990                        s+= "\tBitBlock %s;\n"%j
991                s+="};\n"
992                s+="struct __%s__ %s;\n"%(i,i)
993
994        for i in  var_dic['int']:
995                s+="CarryType %s = Carry0;\n"%i
996
997
998
999
1000        return s
1001
1002def gen_var_dic(s):
1003    if len(s)==0:
1004        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
1005    if isinstance(s[0], bitexpr.BitAssign):
1006        vd = get_vars(s[0])
1007        more = gen_var_dic(s[1:])
1008        vd = merge_var_dic(vd, more)
1009        return vd
1010
1011    if isinstance(s[0], bitexpr.If):
1012        vd1 = gen_var_dic(s[0].true_branch)
1013        vd2 = gen_var_dic(s[0].false_branch)
1014        vd = merge_var_dic(vd1, vd2)
1015        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1016        vd =  merge_var_dic(vd, vd3)
1017        more = gen_var_dic(s[1:])
1018        vd  = merge_var_dic(vd, more)
1019        return vd
1020
1021    if isinstance(s[0], bitexpr.WhileLoop):
1022        vd = gen_var_dic(s[0].stmts)
1023        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1024        vd =  merge_var_dic(vd, vd1)
1025        more = gen_var_dic(s[1:])
1026        vd  = merge_var_dic(vd, more)
1027        return vd
1028
1029def gen_declarations(s):
1030    vd = gen_var_dic(s)
1031    return gen_output(vd)
1032#################################################################################################################
1033## This class replaces all occurrences of a reduced variable to its value --- Pass 7
1034## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
1035## *** the new AST notation                                                                                   ***
1036#################################################################################################################
1037def replace_in_rhs(rhs, target, replace):
1038    if isinstance(rhs, bitexpr.Const):
1039        return rhs
1040    if isinstance(rhs, bitexpr.Var):
1041        if rhs.varname == target:
1042            rhs = replace
1043            return replace
1044
1045    if isinstance(rhs.operand1, bitexpr.Var):
1046        if rhs.operand1.varname == target:
1047            rhs.operand1 = replace
1048    else:
1049        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
1050            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
1051
1052    if isinstance(rhs.operand2, bitexpr.Var):
1053        if rhs.operand2.varname == target:
1054            rhs.operand2 = replace
1055    else:
1056        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
1057            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
1058    return rhs
1059
1060def apply_single_opt(code, target, replace):
1061    if len(code) == 0:
1062        return []
1063
1064    if replace=='AllZero':
1065        replace = bitexpr.FalseLiteral()
1066    if replace == 'AllOne':
1067        replace = bitexpr.TrueLiteral()
1068
1069    if isinstance(code[0], bitexpr.BitAssign):
1070        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
1071
1072    if isinstance(code[0], bitexpr.If):
1073        if code[0].control_expr.var.varname == target:
1074            code[0].control_expr.var = replace
1075        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
1076        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
1077
1078    if isinstance(code[0], bitexpr.WhileLoop):
1079        if code[0].control_expr.var.varname == target:
1080            code[0].control_expr.var = replace
1081        #code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
1082
1083    return [code[0]]+apply_single_opt(code[1:], target, replace)
1084
1085def apply_all_opt(s):
1086    if len(s) == 0:
1087        return []
1088    elif isinstance(s[0], bitexpr.If):
1089        target = s[0].control_expr.var.varname
1090        replace = s[0].control_expr.val
1091        apply_single_opt(s[0].true_branch, target, replace)
1092        apply_all_opt(s[0].true_branch)
1093        apply_all_opt(s[0].false_branch)
1094
1095    elif isinstance(s[0], bitexpr.WhileLoop):
1096        apply_all_opt(s[0].stmts)
1097
1098    return [s[0]]+apply_all_opt(s[1:])
1099 
1100#################################################################################################################
1101## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
1102## method prune is supposed to remove unreachable branches but it is incomplete
1103#################################################################################################################
1104
1105def prune(fixed, tree):
1106    """removes unreachable branches of the tree"""
1107    target = tree[0].target
1108    replace = tree[0].replace
1109    if target in fixed:
1110        if replace == 'allone':
1111            if isinstance(fixed[target], bitexpr.TrueLiteral):
1112                tree[0] = None
1113                tree[1].join(tree[2][1])
1114                del tree[3]
1115                del tree[2]
1116                fixed.update(tree[1].simplify(fixed))
1117            if isinstance(fixed[target], bitexpr.FalseLiteral):
1118                tree[0] = None
1119                tree[1].join(tree[3][1])
1120                del tree[3]
1121                del tree[2]
1122                fixed.update(tree[1].simplify(fixed))
1123        if replace == 'allzero':
1124            if isinstance(fixed[target], bitexpr.FalseLiteral):
1125                tree[0] = None
1126                tree[1].join(tree[2][1])
1127                del tree[3]
1128                del tree[2]
1129                fixed.update(tree[1].simplify(fixed))
1130            if isinstance(fixed[target], bitexpr.TrueLiteral):
1131                tree[0] = None
1132                tree[1].join(tree[3][1])
1133                del tree[3]
1134                del tree[2]
1135                fixed.update(tree[1].simplify(fixed))
1136
1137    empty_list = []
1138    for index, branch in enumerate(tree[2:]):
1139        if branch[0] is None:
1140            if branch[1].code == []:
1141                empty_list.append(2+index)
1142
1143    empty_list.reverse()
1144    for i in empty_list:
1145        del tree[i]
1146def filter_fixed(fixed, stmts):
1147    for loc in stmts:
1148        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
1149            del fixed[loc.LHS.varname]
1150        if isinstance(loc, bitexpr.If):
1151            fixed = filter_fixed(fixed, loc.true_branch)
1152            fixed = filter_fixed(fixed, loc.false_branch)
1153        if isinstance(loc, bitexpr.WhileLoop):
1154            fixed = filter_fixed(fixed, loc.stmts)
1155    return fixed
1156
1157def simplify_tree(code, fixed = {}, prev = []):
1158    if len(code) == 0:
1159        return []
1160
1161    assumptions = {}
1162
1163    if isinstance(code[0], bitexpr.BitAssign):
1164        this, next = recurse_forward(code)
1165        fixed.update(basic_block.simplify(this, fixed))
1166        return this+simplify_tree(next, fixed, this)
1167
1168    elif isinstance(code[0], bitexpr.If):
1169        fixed1 = copy.deepcopy(fixed)
1170        fixed2 = copy.deepcopy(fixed)
1171        if isinstance(code[0].control_expr, bitexpr.isAllOne):
1172            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
1173        if isinstance(code[0].control_expr, bitexpr.isAllZero):
1174            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
1175        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
1176        fixed1.update(assumptions)
1177        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
1178        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
1179        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1180
1181    elif isinstance(code[0], bitexpr.WhileLoop):
1182        fixed = filter_fixed(fixed, code[0].stmts)
1183        fixed1 = copy.deepcopy(fixed)
1184        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
1185        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1186
1187#################################################################################################################
1188## Dead Code Elimination
1189#################################################################################################################
1190def get_effective_name(varname):
1191    if varname.startswith("array") or varname.startswith("strct"):
1192        try:
1193            ind = varname.index("__")
1194            return varname[6:ind]
1195        except:
1196            return varname
1197    return varname
1198
1199def check_loc(loc, must_liv):
1200
1201    #effective_name = get_effective_name(loc.LHS.varname)
1202    effective_name = loc.LHS.varname
1203    if effective_name in must_liv or loc.LHS.varname in must_liv:
1204        if isinstance(loc.RHS, bitexpr.Not):
1205            return set([loc.RHS.operand1.varname]), False
1206        elif isinstance(loc.RHS, bitexpr.Var):
1207            return set([loc.RHS.varname]), False
1208        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1209            return set([]), False
1210        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1211            return set([]), False
1212        elif isinstance(loc.RHS, bitexpr.Const):
1213            return set([]), False
1214        else:
1215            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1216    else:
1217        return set([]), True
1218
1219def remove_copies(bb):
1220    """removes all copy statements e.g. var1 = var2"""
1221    lhs = [x.LHS.varname for x in bb]
1222    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
1223    for i in rhs:
1224        if i[1] in lhs:
1225            line = lhs.index(i[1])
1226            if i[0] > line:
1227                bb[i[0]].RHS = bb[line].RHS
1228    return bb
1229
1230def remove_dead(bb, must_live):
1231    bb = remove_copies(bb)
1232    my_alives = set([])
1233    dead = []
1234
1235    for line, loc in reversed(list(enumerate(bb))):
1236        new_lives, removable = check_loc(loc, my_alives.union(must_live))
1237
1238        if removable:
1239            dead.append(line)
1240        else:
1241            my_alives = my_alives.union(new_lives)
1242
1243    for i in dead:
1244        del bb[i]
1245
1246    return my_alives, bb
1247
1248def get_loop_variables(code):
1249    must_live = set([])
1250    for loc in code:
1251        if isinstance(loc, bitexpr.WhileLoop):
1252            if not loc.carry_expr is None:
1253                must_live.add(loc.carry_expr.var.varname)
1254        if isinstance(loc, bitexpr.If):
1255            true_lives = get_loop_variables(loc.true_branch)
1256            false_lives = get_loop_variables(loc.false_branch)
1257            must_live = must_live.union(true_lives)
1258            must_live = must_live.union(false_lives)
1259    return must_live
1260
1261def eliminate_dead_code(tree, must_live):
1262    if len(tree) == 0:
1263        return must_live, []
1264
1265    last = len(tree) - 1
1266    new_alives = set([])
1267    bb = []
1268
1269    must_live = must_live.union(get_loop_variables(tree))
1270
1271    if isinstance(tree[-1], bitexpr.BitAssign):
1272        first = 0
1273        for i in reversed(range(len(tree))):
1274            if not isinstance(tree[i], bitexpr.BitAssign):
1275                first = i+1
1276                break
1277
1278        new_alives, bb = remove_dead(tree[first:], must_live)
1279        last = first
1280    elif isinstance(tree[-1], bitexpr.If):
1281        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1282        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1283        bb = [tree[-1]]
1284        new_alives.add(tree[-1].control_expr.var.varname)
1285
1286    elif isinstance(tree[-1], bitexpr.WhileLoop):
1287        must_live.add(tree[-1].control_expr.var.varname)
1288        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1289        bb = [tree[-1]]
1290
1291    all_alives = new_alives.union(must_live)
1292    dummy, new_tree = eliminate_dead_code(tree[:last], all_alives)
1293    tree = new_tree+bb
1294    return all_alives.union(dummy), tree
1295
1296#################################################################################################################
1297## This pass processes the code in the while loop and adds extra code required for handling carry variables
1298#################################################################################################################
1299carry_suffix = "_i"
1300
1301def fix_the_loop(loop, carry_type):
1302    carries = []
1303    for loc in loop.stmts:
1304        if isinstance(loc.RHS, bitexpr.Add):
1305            carries.append(loc.RHS.carry)
1306        if isinstance(loc.RHS, bitexpr.Sub):
1307            carries.append(loc.RHS.brw)
1308    for item in carries:
1309        newvar = bitexpr.Var(item+carry_suffix)
1310        if (not loop.fake): loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), carry_type)))
1311    for item in carries:
1312        if (not loop.fake): loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral(carry_type) ))
1313
1314    return loop, carries
1315
1316def process_while_loops(code):
1317    all = []
1318    update = {}
1319   
1320    carry_type = "int"
1321
1322   
1323    for index, loc in enumerate(code):
1324        if isinstance(loc, bitexpr.WhileLoop):
1325            update[index] = fix_the_loop(loc, carry_type)
1326            for i in update[index][1]:
1327                all.append(i)
1328                all.append(i+carry_suffix)
1329
1330        if isinstance(loc, bitexpr.If):
1331            loc.true_branch, live_list = process_while_loops(loc.true_branch)
1332            loc.false_branch, live_list = process_while_loops(loc.false_branch)
1333
1334    keys = [k for k in update]
1335    keys.sort(reverse=True)
1336    for key in keys:
1337        code[key] = update[key][0]
1338        carry_variable = bitexpr.Var("CarryTemp"+str(key))
1339        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral(carry_type)))
1340        for item in update[key][1][2:]:
1341            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), carry_type)))
1342        if len(update[key][1]) == 1:
1343            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1344        elif len(update[key][1]) > 1:
1345            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), carry_type)))
1346        for item in update[key][1]:
1347            if (not code[key].fake): code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral(carry_type)))
1348        for item in update[key][1]:
1349            if (not code[key].fake): code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
1350
1351        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1352    return code, all
1353#################################################################################################################
1354## This pass factors out the code that is common between the true branch and false branch of an if
1355## statement.
1356#################################################################################################################
1357
1358def are_the_same(one, two):
1359    if one.__class__ != two.__class__:
1360        return False
1361   
1362    if isinstance(one, bitexpr.BitAssign):
1363        if (one.LHS.varname != two.LHS.varname):
1364            return False
1365        if one.RHS.__class__ != two.RHS.__class__:
1366            return False
1367        if isinstance(one.RHS, bitexpr.FalseLiteral):
1368            return True
1369        if isinstance(one.RHS, bitexpr.TrueLiteral):
1370            return True
1371        if isinstance(one.RHS, bitexpr.Var):
1372            return (one.RHS.varname == two.RHS.varname)
1373        if isinstance(one.RHS, bitexpr.Const):
1374            return one.RHS.n == two.RHS.n
1375        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1376
1377    if isinstance(one, bitexpr.If):
1378        if (one.control_expr.__class__ != two.control_expr.__class__):
1379            return False
1380        if (one.control_expr.var.varname != two.control_expr.var.varname):
1381            return False
1382        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1383            return False
1384        for ind in range(len(one.true_branch)):
1385            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1386                return False
1387        for ind in range(len(one.false_branch)):
1388            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1389                return False
1390        return True
1391
1392    if isinstance(one, bitexpr.WhileLoop):
1393        if (one.control_expr.__class__ != two.control_expr.__class__):
1394            return False
1395        if (one.control_expr.var.varname != two.control_expr.var.varname):
1396            return False
1397        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1398            return False
1399        #if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1400        #    return False
1401        if (len(one.stmts) != len(two.stmts)):
1402            return False
1403        for ind in range(len(one.stmts)):
1404            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1405                return False
1406        return True
1407
1408def get_factorable(cond):
1409    earliest = 0
1410    l = min(len(cond.true_branch), len(cond.false_branch))
1411    for index in reversed(range(-l,0)):
1412        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1413            earliest = index
1414        else:
1415            break
1416
1417    common_length = -earliest
1418    return common_length
1419
1420def get_common_code(cond, common):
1421    true_len = len(cond.true_branch)-common
1422    false_len = len(cond.false_branch)-common
1423
1424    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1425    common = cond.true_branch[true_len:]
1426
1427    return [new_cond], common
1428
1429def do_factorization(code, pos):
1430    indices = [key for key in pos]
1431    indices.sort()
1432
1433    for index in reversed(indices):
1434        new_cond, common = get_common_code(code[index], pos[index])
1435        code = code[:index]+new_cond+common+code[index+1:]
1436    return code
1437
1438def factor_out(code):
1439    for index, loc in enumerate(code):
1440        if isinstance(loc, bitexpr.If):
1441            code[index].true_branch = factor_out(code[index].true_branch)
1442            code[index].false_branch = factor_out(code[index].false_branch)
1443        if isinstance(loc, bitexpr.WhileLoop):
1444            code[index].stmts = factor_out(code[index].stmts)
1445
1446    pos = {}
1447    for index, loc in enumerate(code):
1448        if isinstance(loc, bitexpr.If):
1449            common_length = get_factorable(code[index])
1450            pos[index] = common_length
1451
1452    return do_factorization(code, pos)
1453
1454#################################################################################################################
1455## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
1456## ***Changes needed here are the same as Pass 7 and Pass 10***
1457#################################################################################################################
1458def generate_condition(expr):
1459    if isinstance(expr, bitexpr.isAllZero):
1460        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1461    elif isinstance(expr, bitexpr.isAllOne):
1462        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1463    elif isinstance(expr, bitexpr.isNoneZero):
1464        return "bitblock_has_bit(%s)"%(expr.var.varname)
1465    else:
1466        assert (1==0)
1467
1468def generate_statement(expr, indent, cond_stmt):
1469    if isinstance(expr, bitexpr.isAllZero):
1470        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1471    elif isinstance(expr, bitexpr.isAllOne):
1472        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1473    elif isinstance(expr, bitexpr.isNoneZero):
1474        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1475    else:
1476        assert (1==0)
1477
1478def print_block_stmts(s, indent = 0):
1479    indent_unit = 4
1480    code = ""
1481    if len(s) == 0:
1482        return ""
1483    if isinstance(s[0], bitexpr.If):
1484        code = generate_statement(s[0].control_expr, indent, "if")
1485        code += print_block_stmts(s[0].true_branch, indent+indent_unit)
1486        code += " "*indent+"}\n"
1487        code += " "*indent+"else {\n"
1488        code += print_block_stmts(s[0].false_branch, indent+indent_unit)
1489        code += " "*indent+"}\n"
1490
1491    elif isinstance(s[0], bitexpr.WhileLoop) and (not s[0].fake):
1492        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
1493        code += print_block_stmts(s[0].stmts, indent+indent_unit)
1494        code += " "*indent+"}\n"
1495    elif isinstance(s[0], bitexpr.WhileLoop) and (s[0].fake):
1496        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"if"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
1497        code += print_block_stmts(s[0].stmts, indent+indent_unit)
1498        code += " "*indent+"}\n"
1499    elif (isinstance(s[0].RHS, bitexpr.Const)):
1500        pass
1501 
1502    elif isinstance(s[0], bitexpr.BitAssign) and not (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1503
1504        code += " "*indent + s[0].LHS.varname
1505        code += " = "
1506
1507        if s[0].RHS.data_type == "vector":
1508            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1509                code += s[0].RHS.operand1.varname
1510                code += ";\n"
1511            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1512                code += s[0].RHS.operand1.varname
1513                code += ";\n"
1514            elif isinstance(s[0].RHS, bitexpr.Var):
1515                code += s[0].RHS.operand1.varname
1516                code += ";\n"
1517
1518            else:
1519                code += s[0].RHS.op_C + "("
1520                code += s[0].RHS.operand1.varname
1521                code += ','
1522                code += s[0].RHS.operand2.varname
1523                code += ");\n"
1524        elif s[0].RHS.data_type == "int":
1525            if isinstance(s[0].RHS, bitexpr.Or):
1526                code += "carry_or(%s, %s);\n" % (s[0].RHS.operand1.varname, s[0].RHS.operand2.varname)
1527            elif isinstance(s[0].RHS, bitexpr.FalseLiteral):
1528                code += "Carry0;\n"
1529    elif isinstance(s[0], bitexpr.BitAssign) and (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1530            if isinstance(s[0].RHS, bitexpr.Add) and (s[0].RHS.operand1.varname == s[0].RHS.operand2.varname):
1531                code += " "*indent + s[0].RHS.op_advance + "("
1532                code += s[0].RHS.operand1.varname
1533                code += ', '
1534                code += s[0].RHS.carry
1535                code += ", "
1536                code += s[0].LHS.varname
1537                code += ");\n"
1538            elif isinstance(s[0].RHS, bitexpr.Add):
1539                code += " "*indent + s[0].RHS.op_C + "("
1540                code += s[0].RHS.operand1.varname
1541                code += ', '
1542                code += s[0].RHS.operand2.varname
1543                code += ', '
1544                code += s[0].RHS.carry
1545                code += ", "
1546                code += s[0].LHS.varname
1547                code += ");\n"
1548            elif isinstance(s[0].RHS, bitexpr.Sub):
1549                code += " "*indent + s[0].RHS.op_C + "("
1550                code += s[0].RHS.operand1.varname
1551                code += ', '
1552                code += s[0].RHS.operand2.varname
1553                code += ', '
1554                code += s[0].RHS.brw
1555                code += ", "
1556                code += s[0].LHS.varname
1557                code += ");\n"
1558    s.pop(0)
1559    return code+print_block_stmts(s, indent)
1560
1561
1562def print_stream_stmts(s, indent = 0):
1563    indent_unit = 4
1564    code = ""
1565    if len(s) == 0:
1566        return ""
1567    elif isinstance(s[0], bitexpr.BitAssign):
1568        if (isinstance(s[0].RHS, bitexpr.Const)):
1569            code += " "*indent + s[0].LHS.varname
1570            code += " = "
1571            code += "sisd_from_int(%i);"%s[0].RHS.n
1572            code += "\n"
1573
1574    s.pop(0)
1575    return code+print_stream_stmts(s, indent)
1576
1577#################################################################################################################
1578## Auxiliary Functions
1579#################################################################################################################
1580
1581def get_bb_of(code, index):
1582    if index < 0 or index >= len(code):
1583        return None
1584
1585    if isinstance(code[index], bitexpr.If):
1586        return code[index]
1587    if isinstance(code[index], bitexpr.WhileLoop):
1588        return code[index]
1589    if isinstance(code[index], bitexpr.BitAssign):
1590        last = index
1591        for loc in code[index+1:]:
1592            if isinstance(loc, bitexpr.BitAssign):
1593                last += 1
1594            else:
1595                break
1596 
1597        first = index
1598        for loc in reversed(code[:index]):
1599            if isinstance(loc, bitexpr.BitAssign):
1600                first -= 1
1601            else:
1602                break
1603
1604        return code[first:last+1]
1605
1606def get_previous_bb(code, index):
1607    if index < 0:
1608        return None
1609    if index >= len(code):
1610        return get_bb_of(code, len(code)-1)
1611    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1612        return get_bb_of(code, index-1)
1613    if isinstance(code[index], bitexpr.BitAssign):
1614        for ind, loc in reversed(enumerate(code[:index])):
1615            if not isinstance(loc, bitexpr.BitAssign):
1616                return get_bb_of(code, ind)
1617        return get_bb_of(code, -1)
1618    assert(1==0)
1619
1620def recurse_forward(code):
1621    if isinstance(code[0], bitexpr.If):
1622        return code[0], code[1:]
1623    if isinstance(code[0], bitexpr.WhileLoop):
1624        return code[0], code[1:]
1625    if isinstance(code[0], bitexpr.BitAssign):
1626        last = 0
1627        for loc in code[1:]:
1628            if isinstance(loc, bitexpr.BitAssign):
1629                last += 1
1630            else:
1631                break
1632        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.