source: proto/Compiler/py2bitexpr.py @ 512

Last change on this file since 512 was 512, checked in by cameron, 9 years ago

Eliminate int_carry parameter

File size: 58.6 KB
RevLine 
[344]1import ast, bitexpr, copy, basic_block
[297]2
[363]3type_dict = {}
4structs = {}
5arrays = {}
6
[344]7AllOne = 'AllOne'
8AllZero = 'AllZero'
9
[365]10STRUCT_TEMPLATE="strct_%s__%s_"
11ARRAY_TEMPLATE="array_%s__%s_"
[363]12
13
[297]14class PyBitError(Exception):
15        pass
16
[320]17#############################################################################################
[363]18## Extracting list of live variables
19#############################################################################################
20def get_lives(s):
21    for item in s.body:
22        if isinstance(item, ast.FunctionDef):
23            if (item.name =="main"):
24                assert (isinstance(item.body[-1], ast.Return))
25                ret = item.body[-1].value
26                if isinstance(ret, ast.Tuple):
27                    del item.body[-1]
28                    return [translate_var(x) for x in ret.elts],s
29                else:
30                    assert(isinstance(ret, ast.Name))
31                    del item.body[-1]
32                    return [translate_var(ret)],s
33
34#############################################################################################
[349]35## Function Inlining
36#############################################################################################
37TEMP_VAR_TEMPLATE = "InlineTemp%i"
38
[363]39def get_var_name(var):
40    var_name = translate_var(var)
41    if isinstance(var, ast.Name):
42        type_name = "simple"
43        extra = None
44    if isinstance(var, ast.Attribute):
45        type_name = "structure"
46        extra = var.attr
47    if isinstance(var, ast.Index):
48        type_name = "array"
49        extra = translate_index(v.slice.value)
[349]50
[363]51    return type_name, var_name, extra
52
53def register_var(final_name, original_name, type_name, extra):
54    pass
55
[421]56def collect_functions(module, file_path):
[363]57    imps = []
58    to_remove = []
59    for index, item in enumerate(module.body):
60        if isinstance(item, ast.Import):
61            imps.append(item)
62            to_remove.append(index)
63
64    for n in reversed(to_remove):
65        del module.body[n]
66
67    for imp in imps:
68        for mod in imp.names:
[421]69            mod_file = open(file_path + mod.name + '.py')
[363]70            mod_code = ""
71            for line in mod_file:
72                mod_code += line
73            mod_file.close()
74            mod_parsed = ast.parse(mod_code)
[365]75            modified = []
[363]76            for item in mod_parsed.body:
77                if isinstance(item, ast.FunctionDef):
78                    item.name = translate_var(ast.Attribute(value=ast.Name(id=mod.name), attr=item.name))
79            module.body += mod_parsed.body
80    return module
81
[365]82
83def remove_unsupported_code(callee, fn_names, predefined, main=False):
84    to_remove = []
85    for index, loc in enumerate(callee.body):
[368]86        if isinstance(loc, ast.While):
87            if not main:
88                to_remove.append(index)
89            continue
90
91        if isinstance(loc, ast.Expr):
92            if isinstance(loc.value, ast.Call):
93                if main and not (translate_var(loc.value.func) in fn_names+predefined):
[365]94                    to_remove.append(index)
[368]95                if not main and not translate_var(loc.value.func) in predefined:
96                    to_remove.append(index)
[365]97
98            else:
99                to_remove.append(index)
[368]100
101            continue
102
103        if not (isinstance(loc, ast.Assign) or isinstance(loc, ast.AugAssign) or isinstance(loc, ast.Return)):
104            to_remove.append(index)
105
[365]106        #we do not support AugAssign with function calls
107        if isinstance(loc, ast.AugAssign):
108            if isinstance(loc.value, ast.Call):
109                assert (1==0)
110                to_remove.append(index)
111
112        #we do not support function calls in functions other than main
113        #these are OK if they are object construction, otherwise error.
114        if isinstance(loc, ast.Assign):
115            if isinstance(loc.value, ast.Call):
116                if main and not (translate_var(loc.value.func) in fn_names+predefined):
117                    to_remove.append(index)
118                if not main and not translate_var(loc.value.func) in predefined:
119                    to_remove.append(index)
120    to_remove.sort(reverse=True)
121    for line_no in to_remove:
122        del callee.body[line_no]
[368]123
[365]124    return callee
125
126
127def clean_functions(module, fn_names, predefined):
128    new = []
129    for item in module.body:
130        if isinstance(item, ast.FunctionDef):
131            new.append(remove_unsupported_code(item, fn_names, predefined, item.name == "main")) 
132        else:
[368]133            new.append(item)
[365]134    module.body = new
135    return module
136
[363]137def collect_names(body):
138    names = {}
139    for index, item in enumerate(body):
140        if isinstance(item, ast.FunctionDef):
141            names[item.name] = index
142
143    return names
144
145def modify_name(var, args, cargs, unique_prefix):
146    carg_names = [x.id for x in cargs]
147    arg_names = [translate_var(x) for x in args]
148
149    varname = translate_var(var)
150    if isinstance(var, ast.Name):
151        if varname in carg_names:
152            return ast.Name(id=arg_names[carg_names.index(varname)])
153        return ast.Name(id=unique_prefix+varname)
154    if isinstance(var, ast.Attribute) or isinstance(var, ast.Subscript):
155        if var.value.id in carg_names:
156            var.value.id = arg_names[carg_names.index(var.value.id)]
157            return ast.Name(id=translate_var(var))
158        return ast.Name(id=unique_prefix+translate_var(var))
159
160def replace_in_exp(exp, args, cargs, unique_prefix):
[349]161    if isinstance(exp, ast.BinOp):
[363]162        exp.left = replace_in_exp(exp.left, args, cargs, unique_prefix)
163        exp.right = replace_in_exp(exp.right, args, cargs, unique_prefix)
[349]164    elif isinstance(exp, ast.UnaryOp):
[363]165        exp.operand = replace_in_exp(exp.operand, args, cargs, unique_prefix)
[349]166    elif isinstance(exp, ast.Name):
[363]167        exp = modify_name(exp, args, cargs, unique_prefix)
168    elif isinstance(exp, ast.Attribute):
169        exp = modify_name(exp, args, cargs, unique_prefix)
[349]170    elif isinstance(exp, ast.Subscript):
[363]171        exp = modify_name(exp, args, cargs, unique_prefix)
172    elif isinstance(exp, ast.Call):
173        for index, item in enumerate(exp.args):
174            exp.args[index] = replace_in_exp(item, args, cargs, unique_prefix)
175    elif isinstance(exp, ast.Tuple):
176        for index, elt in enumerate(exp.elts):
177            exp.elts[index] = replace_in_exp(elt, args, cargs, unique_prefix)
[365]178    elif isinstance(exp, ast.Num):
179        pass
[349]180    else:
181        assert(1==0)
182    return exp
183
[363]184def get_all_calls(main, all_func):
[349]185    call_list = []
186    for index, loc in enumerate(main.body):
187        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
[363]188                func_name = translate_var(loc.value.func)
189                if func_name in all_func:
190                    call_list.append(index)
[349]191        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
[363]192                func_name = translate_var(loc.value.func)
193                if func_name in all_func:
194                    call_list.append(index)
[349]195    return call_list
196
[363]197def update_var_names(callee, args):
198    unique_prefix = '_'+callee.name.replace('.', '_')+'_'
[349]199
200    assert (len(args)==len(callee.args.args))
201
202    for index, loc in enumerate(callee.body):
203        if isinstance(loc, ast.Assign):
204            for index, var in enumerate(loc.targets):
[363]205                loc.targets[index] = modify_name(var, args, callee.args.args, unique_prefix)
206            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
[349]207        elif isinstance(loc, ast.AugAssign):
[363]208            loc.target = modify_name(loc.target, args, callee.args.args, unique_prefix)
209            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
[349]210        elif isinstance(loc, ast.Return):
[363]211            pass
212            #loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
[349]213        else:
[363]214            assert (1==0)
[349]215
[363]216    return callee, unique_prefix
[349]217
[363]218
219def get_name(loc, index):
220    if isinstance(loc.value, ast.Tuple):
221        return translate_var(loc.value.elts[index])
222    elif index == 0:
223        return translate_var(loc.value)
224    else:
225        assert(1==0)
226
227def modify_assignments_in_caller(main, line_no, return_list, prefix, callee):
228    if isinstance(main.body[line_no].targets[0], ast.Tuple):
229        all_vars = main.body[line_no].targets[0].elts
230    else:
231        all_vars = main.body[line_no].targets
232    new_code = []
233
234    if isinstance(callee.body[-1].value, ast.Tuple):
235        return_vars = [translate_var(x) for x in callee.body[-1].value.elts]
236    else:
237        return_vars = [translate_var(callee.body[-1].value)]
238
239    for index, item in enumerate(all_vars):
240        actual_name = translate_var(item)
241        if actual_name in [translate_var(x) for x in main.body[line_no].value.args]:
242            continue
243        temp_name = return_vars[index]
244
245        if temp_name in [translate_var(x) for x in callee.args.args]:
246            second_index = [translate_var(x) for x in callee.args.args].index(temp_name)
247            final_name = translate_var(main.body[line_no].value.args[second_index])
248        else:
249            final_name = prefix+temp_name
250
251        if return_list[temp_name][0] == "struct":
252            my_template = STRUCT_TEMPLATE
253        elif return_list[temp_name][0] == "array":
254            my_template = ARRAY_TEMPLATE
255 
256        if return_list[temp_name][0] == "bitblock":
257            new_code.append(ast.Assign( [ast.Name(id=item.id)] , ast.Name(id=final_name) ))
258        else:
259            for elem in return_list[temp_name][1]:
260                new_code.append(ast.Assign([ast.Name(id=my_template%(item.id, elem))], ast.Name(id=prefix+my_template%(temp_name, elem))))
261    main.body = main.body[:line_no]+new_code+main.body[line_no+1:]
[349]262    return main
[363]263
264def get_return_list(callee):
265    ret_list = []
266    loc = callee.body[-1]
267    if isinstance(loc, ast.Return):
268        if isinstance(loc.value, ast.Tuple):
269            all_rets = loc.value.elts
270        else:
271            all_rets = [loc.value]
272        for item in all_rets:
273            if isinstance(item, ast.Attribute) or isinstance(item, ast.Subscript):
274                ret_list.append(item.value.id)
275            if isinstance(item, ast.Name):
276                ret_list.append(item.id)
277    return ret_list
278
279def get_updated(loc):
280    if isinstance(loc, ast.AugAssign):
281        updated = [loc.target]
282    elif isinstance(loc, ast.Assign):
283        if isinstance(loc.targets[0], ast.Tuple):
284            updated = loc.targets[0].elts
285        else:
286            updated = loc.targets
287    else:
288        updated = []
289    return updated
290
291def get_update_details(callee, ret_list):
292    field_info = {}
293    for item in ret_list:
294        field_info[item] = ['simple', set([])]
295    for loc in callee.body:
296        updated = get_updated(loc)
297        for item in updated:
298            type, name, extra = parse_var(translate_var(item, True))
299            if name in field_info:
300                field_info[name][0] = type
301                if type=='struct':
302                    field_info[name][1].add(extra)
303                if type == 'array':
304                    field_info[name][1].add(extra)
305    return field_info
306
307def process_return(callee):
308    if not isinstance(callee.body[-1], ast.Return):
309        return {}
310    ret_list = get_return_list(callee)
311    field_info = get_update_details(callee, ret_list)
312    return field_info
313
[365]314
315def generate_expanded_code(loc):
316    new_code = []
317    array_name = translate_var(loc.targets[0])
318    for ind, i in enumerate(range(len(loc.value.elts))):
319        new_loc = ast.Assign([ast.Subscript(value=ast.Name(id=array_name), slice=ast.Index(ast.Num(n=ind)) )], loc.value.elts[i])
320        new_code.append(new_loc)
321
322    return new_code
323
324def expand_list_init(module):
325    for item in module.body:
326        replace_code = {}
327        if isinstance(item, ast.FunctionDef):
[368]328            for index,loc in enumerate(item.body):
[365]329                if isinstance(loc, ast.Assign):
330                    if isinstance(loc.value, ast.List):
331                        equivalent_code = generate_expanded_code(loc)
332                        replace_code[index] = equivalent_code
333            keys = [key for key in replace_code]
334            keys.sort(reverse = True)
335            for line_no in keys:
336                item.body = item.body[:line_no]+replace_code[line_no]+item.body[line_no+1:]
[368]337
[365]338    return module
[363]339
[365]340def get_predefined_funcs():
341    advance = ast.Attribute(value=ast.Name(id="bitutil"), attr="Advance")
342    scan_thru = ast.Attribute(value=ast.Name(id="bitutil"), attr="ScanThru")
[368]343    optimize = ast.Attribute(value=ast.Name(id="bitutil"), attr="Optimize")
344    return [translate_var(advance), translate_var(scan_thru), translate_var(optimize)]
[365]345
[421]346def do_inlining(module, file_path):
347    module = collect_functions(module, file_path)
[363]348    func_dict = collect_names(module.body)
[365]349    module = expand_list_init(module)
[368]350
[365]351    valid_fn = [name for name in func_dict] 
352    predef_fn = get_predefined_funcs()
353    module = clean_functions(module, valid_fn, predef_fn)
[368]354
[349]355    main = module.body[func_dict["main"]]
[363]356    call_list = get_all_calls(main, [func for func in func_dict])
357    call_list.sort(reverse=True)
[349]358    for line_no in call_list:
[363]359        callee = copy.deepcopy(module.body[func_dict[translate_var(main.body[line_no].value.func)]])
[349]360        args = main.body[line_no].value.args
[363]361        return_list = process_return(callee)
362        callee, prefix = update_var_names(callee, args)
[349]363
[363]364        main = modify_assignments_in_caller(main, line_no, return_list, prefix, callee)
[349]365
[363]366        if isinstance(callee.body[-1], ast.Return):
367            del callee.body[-1]
368        main.body = main.body[:line_no]+callee.body+main.body[line_no:]
369    return module.body[func_dict["main"]].body
370
[349]371#############################################################################################
[323]372## Translation to bitExpr classes --- Pass 2
[320]373#############################################################################################
374
[297]375def translate_index(ast_value):
376        if isinstance(ast_value, ast.Str): return ast_value.s
377        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
378        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
379
380def is_Advance_Call(fncall):
381        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
382        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
383                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
384        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
385
386def is_ScanThru_Call(fncall):
387        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
388        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
389                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
390        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
391
392def translate(ast_expr):
393        if isinstance(ast_expr, ast.Name):
[319]394                if ast_expr.id=="allzero":
395                        return bitexpr.FalseLiteral()
396                if ast_expr.id=="allone":
397                        return bitexpr.TrueLiteral()
[363]398                return bitexpr.Var(translate_var(ast_expr))
[353]399        elif isinstance(ast_expr, ast.Num):
[365]400                if ast_expr.n == 0:
401                    return bitexpr.FalseLiteral()
402                else:
403                    return bitexpr.Const(ast_expr.n)
[297]404        elif isinstance(ast_expr, ast.Attribute):
405                if isinstance(ast_expr.value, ast.Name):
[363]406                        return bitexpr.Var(translate_var(ast_expr))
[297]407                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
408                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
409        elif isinstance(ast_expr, ast.Subscript):
[363]410                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice, ast.Index):
411                        return bitexpr.Var(translate_var(ast_expr))
[297]412                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
413                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
414        elif isinstance(ast_expr, ast.UnaryOp):
415                e1 = translate(ast_expr.operand)
416                if isinstance(ast_expr.op, ast.Invert):
417                        return bitexpr.make_not(e1)
418                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
419        elif isinstance(ast_expr, ast.BinOp):
420                e1 = translate(ast_expr.left)
421                e2 = translate(ast_expr.right)
422                if isinstance(ast_expr.op, ast.BitOr):
423                        return bitexpr.make_or(e1, e2)
424                elif isinstance(ast_expr.op, ast.BitAnd):
425                        return bitexpr.make_and(e1, e2)
426                elif isinstance(ast_expr.op, ast.BitXor):
427                        return bitexpr.make_xor(e1, e2)
428                elif isinstance(ast_expr.op, ast.Add):
429                        return bitexpr.make_add(e1, e2)
430                elif isinstance(ast_expr.op, ast.Sub):
431                        return bitexpr.make_sub(e1, e2)
432                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
433        elif isinstance(ast_expr, ast.Call):
434                if is_Advance_Call(ast_expr):
435                        e = translate(ast_expr.args[0])
[319]436                        temp = bitexpr.make_add(e,e)
437                        return temp
[297]438                elif is_ScanThru_Call(ast_expr):
439                        e0 = translate(ast_expr.args[0])
440                        e1 = translate(ast_expr.args[1])
441                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
[363]442                else: 
443                    raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
[344]444        elif isinstance(ast_expr, ast.Compare):
445                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
446                        e0 = translate(ast_expr.left)
447                        return bitexpr.isNoneZero(e0)
448                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
[365]449               
450        else: 
451            raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
[297]452
[363]453def translate_var(v, aggregate_type = False):
[297]454        if isinstance(v, ast.Name):
[363]455            agg_name = v.id
456            var_name = v.id
[297]457        elif isinstance(v, ast.Attribute):
[363]458           if isinstance(v.value, ast.Name):
459               var_name = STRUCT_TEMPLATE%(v.value.id, v.attr)
460               agg_name = "%s.%s" % (v.value.id, v.attr)
461           else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
[297]462        elif isinstance(v, ast.Subscript):
[363]463           if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
464               var_name = ARRAY_TEMPLATE%(v.value.id, translate_index(v.slice.value))
465               agg_name = "%s[%s]" % (v.value.id, translate_index(v.slice.value))
[365]466               
[363]467        else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
[297]468
[363]469        if aggregate_type:
470            return agg_name
471        else:
472            return var_name
[297]473
474def translate_stmts(ast_stmts):
475        translated = []
476        for s in ast_stmts:
[319]477                if isinstance(s, ast.Expr):
[368]478                        if (translate_var(s.value.func)=='strct_bitutil__Optimize_'):
[344]479                            target = translate_var(s.value.args[0])
480                            replace = translate_var(s.value.args[1])
481                            translated.append(bitexpr.Reduce(target, replace))
482                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
[319]483                elif isinstance(s, ast.Assign):
[297]484                        e = translate(s.value)
485                        for astv in s.targets: 
[319]486                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
[297]487                elif isinstance(s, ast.AugAssign):
[319]488                        v = bitexpr.Var(translate_var(s.target))
[297]489                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
[322]490                elif isinstance(s, ast.While):
491                        e = translate(s.test)
492                        body = translate_stmts(s.body)
[363]493                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body))
[297]494                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
495        return translated
[320]496#############################################################################################
[323]497## Conversion to SSA form --- Pass 3
[320]498#############################################################################################
[297]499
[319]500def extract_vars(rhs):
501    if isinstance(rhs, bitexpr.Var):
502        return [rhs.varname]
[365]503    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral) or isinstance(rhs, bitexpr.Const):
[319]504        return []
505    if isinstance(rhs, bitexpr.Not):
[344]506        return extract_vars(rhs.operand1)
[319]507    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
508    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
509
[344]510def update_lineno(inner_table, shift):
511    for key in inner_table:
512        for index in range(len(inner_table[key][0])):
513            inner_table[key][0][index] += shift
514        for index in range(len(inner_table[key][1])):
515            inner_table[key][1][index] += shift
516    return inner_table
517
518def merge_tables(table, inner_table):
519    for key in inner_table:
520        if key in table:
521            table[key][0] += inner_table[key][0]
522            table[key][1] += inner_table[key][1]
523        else:
524            table[key] = inner_table[key]
525    return table
526
527def gen_sym_table(code, goInside = False):
528    """Generate a simple symbol table for a three address code
529        each entry is of this form: var_name:[[defs][uses]]
530    """
531    table = {}
532    index = 0
533    for stmt in code:
534        if isinstance(stmt, bitexpr.Reduce):
535            if stmt.target in table:
536                table[stmt.target][1].append(index)
537            else:
538                table[stmt.target] = [[], [index]]
539            index += 1
540        elif isinstance(stmt, bitexpr.BitAssign):
541            current = stmt.LHS.varname
542            if current in table:
543                table[current][0].append(index)
544            else:
545                table[current] = [[index],[]]
546
547            varlist = extract_vars(stmt.RHS)
548            for var in varlist:
549                if var in table:
550                    table[var][1].append(index)
[319]551                else:
[344]552                    table[var] = [[], [index]]
553            index += 1
554        elif isinstance(stmt, bitexpr.WhileLoop):
555            cond_var = stmt.control_expr.var.varname
556            if cond_var in table:
557                table[cond_var][1].append(index)
558            else:
559                #while loop conditioned on an undefined variable
560                assert(1==0)
561            if goInside:
562                inner_table = gen_sym_table(stmt.stmts, True)
563                inner_table = update_lineno(inner_table, index+1)
564                table = merge_tables(table, inner_table)
565                index += (1+len(stmt.stmts))
566            else:
567                index += 1
568        else:
569            assert(1==0)
570    return table
571##################################################################################
572def get_line(code, line):
573    lineno = 0
574    for stmt in code:
575        if isinstance(stmt, bitexpr.BitAssign):
576            if lineno == line:
577                return stmt
578            lineno += 1
579        elif isinstance(stmt, bitexpr.WhileLoop):
580            if lineno == line:
581                return stmt
582            elif lineno+len(stmt.stmts) < line:
583                lineno += 1
584                line -= len(stmt.stmts)
585                continue
586            else:
587                return get_line(stmt.stmts, line-(lineno+1))
[319]588
[344]589        elif isinstance(stmt, bitexpr.Reduce):
590            if lineno == line:
591                return stmt
592            lineno += 1
593        else:
594            assert(1==0)
595
596def update_def(code, var, line, suffix_num):
597    loc = code[line]
598    if isinstance(loc, bitexpr.BitAssign):
599        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
600        return loc
601    else:
602        #either Reduce or While Loop in both cases it's a use
603        pass
604
605def update_rhs(rhs, varname, newname):
606    if isinstance(rhs, bitexpr.Var):
607        if rhs.varname == varname:
608            rhs.varname = newname
609        return rhs
610
611    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
612        return rhs
613
614    if isinstance(rhs, bitexpr.Not):
615        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
616        return rhs
617    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
618    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
619    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
620    return rhs
621
622def update_use(code, var, line, suffix_num):
623    loc = code[line]
624    if isinstance(loc, bitexpr.BitAssign):
625        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
626    elif isinstance(loc, bitexpr.Reduce):
627        pass
628    elif isinstance(loc, bitexpr.WhileLoop):
629        loc.control_expr.var.varname += "_%i"%suffix_num
630    return loc
631
[319]632def parse_var(var):
633    index=var.find('.')
634    if index >= 0:
635        return ('struct', var[0:index], var[index+1:])
636               
637    index = var.find('[')
638    if index >= 0:
639        right_index = var.find(']')
[363]640        return ('array', var[0:index], var[index+1:right_index])       
[344]641   
642    if var.startswith("carry") or var.startswith("Carry"):
643        return('int', var, None)
644   
[319]645    return ('bitblock', var, None)
646
647def simplify_name(var):
648    (vartype, name, extra) = parse_var(var)
649    if vartype == 'bitblock':
650        return name
651    if vartype == 'array':
652        return "a_%s_%s"%(name, extra)
653    if vartype == 'struct':
654        return "s_%s_%s"%(name, extra)
[344]655    if vartype == 'int':
656        assert(1==0)
[319]657    assert(1==0)
658
659def pairs(lst):
[344]660    if lst == []:
661            return []
662    return zip(lst,lst[1:]+[lst[0]])
[319]663
[344]664
[319]665def make_SSA(code, st):
[344]666    new_vars = []
667    total_lines = len(code)
668    for var in st:
669        st[var][0].append(total_lines)
670        st[var][1].append(total_lines)
[319]671
[344]672    unique_number = 0
673    for var in st:
674        use_index = 0
675        for current, next in pairs(st[var][0])[0:-2]:
676            code[current] = update_def(code, var, current, unique_number)
677            uline = st[var][1][use_index]
[320]678
[344]679            while uline <= next and uline < total_lines:
680                if uline > current:
681                    code[uline] = update_use(code, var, uline, unique_number)
682                use_index += 1
683                uline = st[var][1][use_index]
684            unique_number += 1
685    return code
686
[320]687#################################################################################################################
[323]688## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
[320]689#################################################################################################################
[344]690def get_opt_list(s):
691    opt_list = []
692    remove_index = []
693    for line, stmt in enumerate(s):
694        if isinstance(stmt, bitexpr.Reduce):
695            opt_pair = (stmt.target, stmt.replace)
696            opt_list.append(opt_pair)
697            remove_index.append(line)
698    for ind in reversed(remove_index):
699        del s[ind]
700
701    return opt_list
702
[320]703def break_to_segs(s, breaks):
704    segs = []
705    breaks = [-1]+breaks+[len(s)]
706    for start, stop in pairs(breaks)[:-1]:
707        next = s[start+1:stop]
708        if len(next)>0:
709            segs.append(next)
710    return segs
[323]711
[344]712def extract_pragmas(s):
713    targets = []
[320]714    exprs = []
[344]715    lines = []
[323]716
[344]717    #extracting pragmas and their targets
[320]718    for line, loc in enumerate(s):
719        if isinstance(loc, bitexpr.Reduce):
[344]720            targets.append(loc.target)
[320]721            exprs.append(loc)
[344]722            lines.append(line)
[320]723
[344]724    #removing pragmas from the source code
725    for i in reversed(lines):
726        del s[i]
[320]727
[344]728    return targets, exprs
[320]729
[344]730def get_defs(targets, s):
731    temp = [(x,-1) for x in targets]
732    targ_dic = dict(temp)
733    for line, loc in enumerate(s):
734        if loc.LHS.varname in targets:
735            targ_dic[loc.LHS.varname] = line
736
737    items = targ_dic.items()
738    items = sorted(items, key=(lambda x: x[1]))
739
740    lineno = [key for value, key in items]
741    sorted_targets = [value for value, key in items]
742
743    return sorted_targets, lineno
744
[369]745def defines(varname, stmts):
746    is_defined = False
747    for loc in stmts:
748        if isinstance(loc, bitexpr.WhileLoop):
749            is_defined |= defines(varname, loc.stmts)
750        elif isinstance(loc, bitexpr.BitAssign): 
751            if loc.LHS.varname == varname:
752                is_defined = True
[344]753
[369]754    return is_defined
[344]755
[369]756def adjust_earliest(s, earliest, varname):
[344]757    if earliest == -1:
758        return -1
759
[369]760    line = earliest+1
761    for stmt in s[earliest+1:]:
[344]762        if isinstance(stmt, bitexpr.WhileLoop):
[369]763            if defines(varname, stmt.stmts):
764               earliest = line+1 
765        line+=1
766    return earliest
[344]767
768def chop_code(s, start, stop):
769    line = 0
[368]770
[344]771    if start == -1:
772        cut1 = -1
[368]773
[344]774    if stop == len(s):
775        cut2 = len(s)
[368]776
[344]777    for index, stmt in enumerate(s):
778        if isinstance(stmt, bitexpr.BitAssign):
779            if line == start:
780                cut1 = index
781            if line == stop:
782                cut2 = index
783            line += 1
784
785        if isinstance(stmt, bitexpr.WhileLoop):
786            if line == start:
787                cut1 = index
[369]788            if line  == stop:
[344]789                cut2 = index
[369]790            line+=1
[344]791
792    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
793
794def count_lines(code):
795    cnt = len(code)
796    for stmt in code:
797        if isinstance(stmt, bitexpr.WhileLoop):
798            cnt += count_lines(stmt.stmts)
799    return cnt
800
801
[369]802def gen_cond_obj(opt_pragma):
803    if opt_pragma[1] == 'AllZero' or opt_pragma[1] == 'allzero':
804        cond_obj = bitexpr.isAllZero(opt_pragma[0])
805    if opt_pragma[1] == 'AllOne' or opt_pragma[1] == 'allone':
806        cond_obj = bitexpr.isAllOne(opt_pragma[0])
807    return cond_obj
808
809def gen_bb(s, opt_list, def_line):
810    assert (len(opt_list) == len(def_line))
[344]811    if opt_list==[]:
812        return s
813
[369]814    earliest = min(def_line)
815    latest = len(s)
816    earliest_index = def_line.index(earliest)
817    early_varname = opt_list[earliest_index][0]
818   
819    earliest = adjust_earliest(s, earliest, early_varname)
[344]820
821    first, second, third = chop_code(s, earliest, latest)
[368]822
[369]823    cond_obj = gen_cond_obj(opt_list[earliest_index])
824   
825    del opt_list[earliest_index]
826    del def_line[earliest_index]
827    def_line = [x-earliest-1 for x in def_line]
828   
829    second = gen_bb(second, opt_list, def_line)
[344]830
[368]831    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))] + third
[344]832    return result
833
834def partition2bb(s):
835    basicblocks = []
836    lineno = []
837    def_line = []
838    opt_list = get_opt_list(s)
839
840    st2 = gen_sym_table(s)
841
842    for item in opt_list:
843        if len(st2[item[0]][0]) > 0:
844            def_line.append(st2[item[0]][0][-1])
845        else:
846            def_line.append(-1)
847
848
[369]849    return gen_bb(s, opt_list, def_line)
[344]850
[320]851#################################################################################################################
[344]852## Generating declarations for variables. This pass is based on the syntax of C programming language ---
853## Normalizing the code
854#################################################################################################################
855
856def normalize(s, predec = {}, ccelim=True):
857    if len(s)==0:
858        return []
859    if isinstance(s[0], bitexpr.If):
860        gc = basic_block.BasicBlock.gensym_counter
861        cc = basic_block.BasicBlock.carry_counter
862        bc = basic_block.BasicBlock.brw_counter
863        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
864        maxgc = basic_block.BasicBlock.gensym_counter
865        maxcc = basic_block.BasicBlock.carry_counter
866        maxbc = basic_block.BasicBlock.brw_counter
867        basic_block.BasicBlock.gensym_counter = gc
868        basic_block.BasicBlock.carry_counter = cc
869        basic_block.BasicBlock.brw_counter = bc
870        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
871        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
872        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
873        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
874       
875        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
876    if isinstance(s[0], bitexpr.WhileLoop):
877        orig = copy.deepcopy(predec)
878        s[0].stmts = normalize(s[0].stmts, predec, False)
[364]879        return [s[0]]+normalize(s[1:], {})
[344]880    if isinstance(s[0], bitexpr.BitAssign):
881        code, next = recurse_forward(s)
882        bb = basic_block.BasicBlock(predec)
883        predec.update(bb.normalize(code, ccelim))
884        return bb.get_code() + normalize(next, copy.deepcopy(predec))
885
886#################################################################################################################
[323]887## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
[320]888#################################################################################################################
[344]889def get_vars(stmt):
[320]890        ints = set([])
891        bitblocks = set(['AllOne', 'AllZero'])
892        arrays = {}
893        structs = {}
894
[344]895        #for stmt in code:
[320]896
[344]897        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
[365]898       
[320]899
[365]900
[344]901        if isinstance(stmt.RHS, bitexpr.Add): 
902                ints.add(stmt.RHS.carry)
903        if isinstance(stmt.RHS, bitexpr.Sub):
904                ints.add(stmt.RHS.brw)
[320]905
[344]906        for var in all_vars:
[365]907                if isinstance(var, bitexpr.Const):
908                    continue
909               
[344]910                (var_type, name, extra) = parse_var(var.varname)
[320]911
[344]912                if (var_type == "bitblock"):
913                        bitblocks.add(name)
[320]914
[344]915                if var_type == "array":
916                        if not name in arrays:
917                                arrays[name]= extra
918                        else:
919                                arrays[name] = max(arrays[name], extra)
[320]920
[344]921                if var_type == "struct":
922                        if not name in structs:
923                                structs[name] = set([extra])
924                        else:
925                                structs[name].add(extra)
926                if var_type == "int":
927                    ints.add(name)
[429]928                       
[320]929        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
930
931def merge_var_dic(first, second):
932    res = {}
933    res['int'] = first['int'].union(second['int'])
934    res['bitblock'] = first['bitblock'].union(second['bitblock'])
935
936    res['array'] = first['array']
937    temp = dict({'array':copy.copy(second['array'])})
938    for i in res['array']:
939        if i in temp['array']:
940            res['array'][i] = max(res['array'][i], temp['array'][i])
941            del temp['array'][i]
942    res['array'].update(temp['array'])
943
[323]944
[320]945    res['struct'] = first['struct']
946    temp = dict({'struct':copy.copy(second['struct'])})
947    for i in res['struct']:
948        if i in temp['struct']:
949            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
950            del temp['struct'][i]
951    res['struct'].update(temp['struct'])
952    return res
953
[512]954def gen_output(var_dic):
[320]955        s = ''
956        for i in var_dic['bitblock']:
[344]957                if i == AllOne:
[320]958                        s += "BitBlock %s = simd_const_1(1);\n"%i
[429]959                elif i == AllZero:
[320]960                        s+="BitBlock %s = simd_const_1(0);\n"%i
[429]961                elif i == "EOF_mask":
962                        s+="BitBlock EOF_mask = simd_const_1(1);\n"
[320]963                else:
964                        s+="BitBlock %s;\n"%i
965
966        for i in var_dic['array']:
967                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
968
969        for i in var_dic['struct']:
970                s+="struct __%s__{\n"%i
971                for j in var_dic['struct'][i]:
972                        s+= "\tBitBlock %s;\n"%j
973                s+="};\n"
974                s+="struct __%s__ %s;\n"%(i,i)
975
[429]976        for i in  var_dic['int']:
[512]977                s+="CarryType %s = Carry0;\n"%i
[429]978
979
980
[512]981
[320]982        return s
983
[344]984def gen_var_dic(s):
985    if len(s)==0:
986        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
987    if isinstance(s[0], bitexpr.BitAssign):
988        vd = get_vars(s[0])
989        more = gen_var_dic(s[1:])
990        vd = merge_var_dic(vd, more)
991        return vd
992
993    if isinstance(s[0], bitexpr.If):
994        vd1 = gen_var_dic(s[0].true_branch)
995        vd2 = gen_var_dic(s[0].false_branch)
996        vd = merge_var_dic(vd1, vd2)
997        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
998        vd =  merge_var_dic(vd, vd3)
999        more = gen_var_dic(s[1:])
1000        vd  = merge_var_dic(vd, more)
1001        return vd
1002
1003    if isinstance(s[0], bitexpr.WhileLoop):
1004        vd = gen_var_dic(s[0].stmts)
1005        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1006        vd =  merge_var_dic(vd, vd1)
1007        more = gen_var_dic(s[1:])
1008        vd  = merge_var_dic(vd, more)
1009        return vd
1010
[512]1011def gen_declarations(s):
[344]1012    vd = gen_var_dic(s)
[512]1013    return gen_output(vd)
[320]1014#################################################################################################################
[429]1015## This class replaces all occurrences of a reduced variable to its value --- Pass 7
[344]1016## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
1017## *** the new AST notation                                                                                   ***
[320]1018#################################################################################################################
[344]1019def replace_in_rhs(rhs, target, replace):
[369]1020    if isinstance(rhs, bitexpr.Const):
1021        return rhs
[344]1022    if isinstance(rhs, bitexpr.Var):
1023        if rhs.varname == target:
1024            rhs = replace
1025            return replace
[320]1026
[344]1027    if isinstance(rhs.operand1, bitexpr.Var):
1028        if rhs.operand1.varname == target:
1029            rhs.operand1 = replace
[320]1030    else:
[344]1031        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
1032            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
[320]1033
[344]1034    if isinstance(rhs.operand2, bitexpr.Var):
1035        if rhs.operand2.varname == target:
1036            rhs.operand2 = replace
[320]1037    else:
[344]1038        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
1039            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
1040    return rhs
[320]1041
[344]1042def apply_single_opt(code, target, replace):
1043    if len(code) == 0:
1044        return []
[320]1045
[344]1046    if replace=='AllZero':
1047        replace = bitexpr.FalseLiteral()
1048    if replace == 'AllOne':
1049        replace = bitexpr.TrueLiteral()
[320]1050
[344]1051    if isinstance(code[0], bitexpr.BitAssign):
1052        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
[320]1053
[344]1054    if isinstance(code[0], bitexpr.If):
1055        if code[0].control_expr.var.varname == target:
1056            code[0].control_expr.var = replace
1057        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
1058        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
[320]1059
[344]1060    if isinstance(code[0], bitexpr.WhileLoop):
1061        if code[0].control_expr.var.varname == target:
1062            code[0].control_expr.var = replace
[369]1063        #code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
[320]1064
[344]1065    return [code[0]]+apply_single_opt(code[1:], target, replace)
[320]1066
[344]1067def apply_all_opt(s):
1068    if len(s) == 0:
1069        return []
1070    elif isinstance(s[0], bitexpr.If):
1071        target = s[0].control_expr.var.varname
1072        replace = s[0].control_expr.val
1073        apply_single_opt(s[0].true_branch, target, replace)
1074        apply_all_opt(s[0].true_branch)
1075        apply_all_opt(s[0].false_branch)
[320]1076
[344]1077    elif isinstance(s[0], bitexpr.WhileLoop):
1078        apply_all_opt(s[0].stmts)
[320]1079
[344]1080    return [s[0]]+apply_all_opt(s[1:])
1081 
[320]1082#################################################################################################################
[323]1083## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
1084## method prune is supposed to remove unreachable branches but it is incomplete
[320]1085#################################################################################################################
[344]1086
[320]1087def prune(fixed, tree):
1088    """removes unreachable branches of the tree"""
1089    target = tree[0].target
1090    replace = tree[0].replace
1091    if target in fixed:
1092        if replace == 'allone':
1093            if isinstance(fixed[target], bitexpr.TrueLiteral):
1094                tree[0] = None
1095                tree[1].join(tree[2][1])
1096                del tree[3]
1097                del tree[2]
1098                fixed.update(tree[1].simplify(fixed))
1099            if isinstance(fixed[target], bitexpr.FalseLiteral):
1100                tree[0] = None
1101                tree[1].join(tree[3][1])
1102                del tree[3]
1103                del tree[2]
1104                fixed.update(tree[1].simplify(fixed))
1105        if replace == 'allzero':
1106            if isinstance(fixed[target], bitexpr.FalseLiteral):
1107                tree[0] = None
1108                tree[1].join(tree[2][1])
1109                del tree[3]
1110                del tree[2]
1111                fixed.update(tree[1].simplify(fixed))
1112            if isinstance(fixed[target], bitexpr.TrueLiteral):
1113                tree[0] = None
1114                tree[1].join(tree[3][1])
1115                del tree[3]
1116                del tree[2]
1117                fixed.update(tree[1].simplify(fixed))
1118
1119    empty_list = []
1120    for index, branch in enumerate(tree[2:]):
1121        if branch[0] is None:
1122            if branch[1].code == []:
1123                empty_list.append(2+index)
1124
1125    empty_list.reverse()
1126    for i in empty_list:
1127        del tree[i]
[344]1128def filter_fixed(fixed, stmts):
1129    for loc in stmts:
1130        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
1131            del fixed[loc.LHS.varname]
1132        if isinstance(loc, bitexpr.If):
1133            fixed = filter_fixed(fixed, loc.true_branch)
1134            fixed = filter_fixed(fixed, loc.false_branch)
1135        if isinstance(loc, bitexpr.WhileLoop):
1136            fixed = filter_fixed(fixed, loc.stmts)
1137    return fixed
[370]1138
[344]1139def simplify_tree(code, fixed = {}, prev = []):
1140    if len(code) == 0:
1141        return []
[320]1142
[323]1143    assumptions = {}
1144
[344]1145    if isinstance(code[0], bitexpr.BitAssign):
1146        this, next = recurse_forward(code)
1147        fixed.update(basic_block.simplify(this, fixed))
1148        return this+simplify_tree(next, fixed, this)
[323]1149
[344]1150    elif isinstance(code[0], bitexpr.If):
1151        fixed1 = copy.deepcopy(fixed)
1152        fixed2 = copy.deepcopy(fixed)
1153        if isinstance(code[0].control_expr, bitexpr.isAllOne):
1154            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
1155        if isinstance(code[0].control_expr, bitexpr.isAllZero):
1156            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
1157        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
1158        fixed1.update(assumptions)
1159        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
1160        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
1161        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
[320]1162
[344]1163    elif isinstance(code[0], bitexpr.WhileLoop):
1164        fixed = filter_fixed(fixed, code[0].stmts)
1165        fixed1 = copy.deepcopy(fixed)
1166        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
1167        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1168
[323]1169#################################################################################################################
[373]1170## Dead Code Elimination
[323]1171#################################################################################################################
[352]1172def get_effective_name(varname):
[365]1173    if varname.startswith("array") or varname.startswith("strct"):
1174        try:
1175            ind = varname.index("__")
1176            return varname[6:ind]
1177        except:
1178            return varname
[352]1179    return varname
[369]1180
[323]1181def check_loc(loc, must_liv):
[368]1182
[352]1183    effective_name = get_effective_name(loc.LHS.varname)
[365]1184    if effective_name in must_liv or loc.LHS.varname in must_liv:
[323]1185        if isinstance(loc.RHS, bitexpr.Not):
[344]1186            return set([loc.RHS.operand1.varname]), False
[323]1187        elif isinstance(loc.RHS, bitexpr.Var):
1188            return set([loc.RHS.varname]), False
1189        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1190            return set([]), False
1191        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1192            return set([]), False
[365]1193        elif isinstance(loc.RHS, bitexpr.Const):
1194            return set([]), False
[323]1195        else:
1196            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1197    else:
1198        return set([]), True
1199
1200def remove_copies(bb):
1201    """removes all copy statements e.g. var1 = var2"""
[344]1202    lhs = [x.LHS.varname for x in bb]
1203    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
[323]1204    for i in rhs:
1205        if i[1] in lhs:
1206            line = lhs.index(i[1])
[344]1207            if i[0] > line:
1208                bb[i[0]].RHS = bb[line].RHS
1209    return bb
[323]1210
1211def remove_dead(bb, must_live):
[373]1212    bb = remove_copies(bb)
[323]1213    my_alives = set([])
1214    dead = []
[344]1215
1216    for line, loc in reversed(list(enumerate(bb))):
[323]1217        new_lives, removable = check_loc(loc, my_alives.union(must_live))
[344]1218
[323]1219        if removable:
1220            dead.append(line)
1221        else:
1222            my_alives = my_alives.union(new_lives)
1223
1224    for i in dead:
[344]1225        del bb[i]
[323]1226
[344]1227    return my_alives, bb
[323]1228
[368]1229def get_loop_variables(code):
1230    must_live = set([])
1231    for loc in code:
1232        if isinstance(loc, bitexpr.WhileLoop):
1233            if not loc.carry_expr is None:
1234                must_live.add(loc.carry_expr.var.varname)
1235        if isinstance(loc, bitexpr.If):
1236            true_lives = get_loop_variables(loc.true_branch)
1237            false_lives = get_loop_variables(loc.false_branch)
1238            must_live = must_live.union(true_lives)
1239            must_live = must_live.union(false_lives)
1240    return must_live
1241
[344]1242def eliminate_dead_code(tree, must_live):
1243    if len(tree) == 0:
[368]1244        return must_live, []
[344]1245
1246    last = len(tree) - 1
1247    new_alives = set([])
1248    bb = []
1249
[368]1250    must_live = must_live.union(get_loop_variables(tree))
[344]1251
1252    if isinstance(tree[-1], bitexpr.BitAssign):
1253        first = 0
1254        for i in reversed(range(len(tree))):
1255            if not isinstance(tree[i], bitexpr.BitAssign):
1256                first = i+1
1257                break
[368]1258
[344]1259        new_alives, bb = remove_dead(tree[first:], must_live)
1260        last = first
1261
1262    elif isinstance(tree[-1], bitexpr.If):
1263        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1264        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1265        bb = [tree[-1]]
1266        new_alives.add(tree[-1].control_expr.var.varname)
[349]1267
[344]1268    elif isinstance(tree[-1], bitexpr.WhileLoop):
1269        must_live.add(tree[-1].control_expr.var.varname)
1270        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1271        bb = [tree[-1]]
[349]1272
[344]1273    all_alives = new_alives.union(must_live)
[368]1274    dummy, new_tree = eliminate_dead_code(tree[:last], all_alives)
[344]1275    tree = new_tree+bb
[368]1276    return all_alives.union(dummy), tree
[344]1277
1278#################################################################################################################
1279## This pass processes the code in the while loop and adds extra code required for handling carry variables
1280#################################################################################################################
1281carry_suffix = "_i"
1282
[429]1283def fix_the_loop(loop, carry_type):
[344]1284    carries = []
1285    for loc in loop.stmts:
1286        if isinstance(loc.RHS, bitexpr.Add):
1287            carries.append(loc.RHS.carry)
1288    for item in carries:
1289        newvar = bitexpr.Var(item+carry_suffix)
[429]1290        loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), carry_type)))
[344]1291    for item in carries:
[429]1292        loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral(carry_type) ))
[349]1293
[344]1294    return loop, carries
[323]1295
[512]1296def process_while_loops(code):
[344]1297    all = []
1298    update = {}
[429]1299   
[512]1300    carry_type = "int"
1301
[429]1302   
[344]1303    for index, loc in enumerate(code):
1304        if isinstance(loc, bitexpr.WhileLoop):
[429]1305            update[index] = fix_the_loop(loc, carry_type)
[344]1306            for i in update[index][1]:
1307                all.append(i)
1308                all.append(i+carry_suffix)
[349]1309
[368]1310        if isinstance(loc, bitexpr.If):
1311            loc.true_branch, live_list = process_while_loops(loc.true_branch)
1312            loc.false_branch, live_list = process_while_loops(loc.false_branch)
1313
[344]1314    keys = [k for k in update]
1315    keys.sort(reverse=True)
1316    for key in keys:
1317        code[key] = update[key][0]
1318        carry_variable = bitexpr.Var("CarryTemp"+str(key))
[429]1319        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral(carry_type)))
[344]1320        for item in update[key][1][2:]:
[429]1321            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), carry_type)))
[344]1322        if len(update[key][1]) == 1:
1323            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1324        elif len(update[key][1]) > 1:
[429]1325            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), carry_type)))
[344]1326        for item in update[key][1]:
[429]1327            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral(carry_type)))
[344]1328        for item in update[key][1]:
1329            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
[323]1330
[344]1331        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1332    return code, all
[320]1333#################################################################################################################
[344]1334## This pass factors out the code that is common between the true branch and false branch of an if
1335## statement.
[320]1336#################################################################################################################
[344]1337
1338def are_the_same(one, two):
1339    if one.__class__ != two.__class__:
1340        return False
[369]1341   
[344]1342    if isinstance(one, bitexpr.BitAssign):
1343        if (one.LHS.varname != two.LHS.varname):
1344            return False
1345        if one.RHS.__class__ != two.RHS.__class__:
1346            return False
1347        if isinstance(one.RHS, bitexpr.FalseLiteral):
1348            return True
1349        if isinstance(one.RHS, bitexpr.TrueLiteral):
1350            return True
1351        if isinstance(one.RHS, bitexpr.Var):
1352            return (one.RHS.varname == two.RHS.varname)
[369]1353        if isinstance(one.RHS, bitexpr.Const):
1354            return one.RHS.n == two.RHS.n
[344]1355        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1356
1357    if isinstance(one, bitexpr.If):
1358        if (one.control_expr.__class__ != two.control_expr.__class__):
1359            return False
1360        if (one.control_expr.var.varname != two.control_expr.var.varname):
1361            return False
1362        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1363            return False
1364        for ind in range(len(one.true_branch)):
1365            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1366                return False
1367        for ind in range(len(one.false_branch)):
1368            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1369                return False
1370        return True
1371
1372    if isinstance(one, bitexpr.WhileLoop):
1373        if (one.control_expr.__class__ != two.control_expr.__class__):
1374            return False
1375        if (one.control_expr.var.varname != two.control_expr.var.varname):
1376            return False
1377        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1378            return False
[369]1379        #if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1380        #    return False
[344]1381        if (len(one.stmts) != len(two.stmts)):
1382            return False
1383        for ind in range(len(one.stmts)):
1384            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1385                return False
1386        return True
1387
1388def get_factorable(cond):
1389    earliest = 0
1390    l = min(len(cond.true_branch), len(cond.false_branch))
1391    for index in reversed(range(-l,0)):
1392        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1393            earliest = index
1394        else:
1395            break
1396
1397    common_length = -earliest
1398    return common_length
1399
1400def get_common_code(cond, common):
1401    true_len = len(cond.true_branch)-common
1402    false_len = len(cond.false_branch)-common
1403
1404    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1405    common = cond.true_branch[true_len:]
1406
1407    return [new_cond], common
1408
1409def do_factorization(code, pos):
1410    indices = [key for key in pos]
1411    indices.sort()
1412
1413    for index in reversed(indices):
1414        new_cond, common = get_common_code(code[index], pos[index])
1415        code = code[:index]+new_cond+common+code[index+1:]
1416    return code
1417
1418def factor_out(code):
1419    for index, loc in enumerate(code):
1420        if isinstance(loc, bitexpr.If):
1421            code[index].true_branch = factor_out(code[index].true_branch)
1422            code[index].false_branch = factor_out(code[index].false_branch)
1423        if isinstance(loc, bitexpr.WhileLoop):
1424            code[index].stmts = factor_out(code[index].stmts)
1425
1426    pos = {}
1427    for index, loc in enumerate(code):
1428        if isinstance(loc, bitexpr.If):
1429            common_length = get_factorable(code[index])
1430            pos[index] = common_length
1431
1432    return do_factorization(code, pos)
1433
1434#################################################################################################################
1435## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
1436## ***Changes needed here are the same as Pass 7 and Pass 10***
1437#################################################################################################################
1438
1439def generate_condition(expr):
1440    if isinstance(expr, bitexpr.isAllZero):
1441        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1442    elif isinstance(expr, bitexpr.isAllOne):
1443        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1444    elif isinstance(expr, bitexpr.isNoneZero):
1445        return "bitblock_has_bit(%s)"%(expr.var.varname)
[320]1446    else:
1447        assert (1==0)
1448
[344]1449def generate_statement(expr, indent, cond_stmt):
1450    if isinstance(expr, bitexpr.isAllZero):
1451        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1452    elif isinstance(expr, bitexpr.isAllOne):
1453        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1454    elif isinstance(expr, bitexpr.isNoneZero):
1455        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1456    else:
1457        assert (1==0)
[320]1458
[512]1459def print_block_stmts(s, indent = 0):
[320]1460    indent_unit = 4
[344]1461    code = ""
1462    if len(s) == 0:
1463        return ""
1464    if isinstance(s[0], bitexpr.If):
1465        code = generate_statement(s[0].control_expr, indent, "if")
[512]1466        code += print_block_stmts(s[0].true_branch, indent+indent_unit)
[344]1467        code += " "*indent+"}\n"
1468        code += " "*indent+"else {\n"
[512]1469        code += print_block_stmts(s[0].false_branch, indent+indent_unit)
[344]1470        code += " "*indent+"}\n"
[323]1471
[429]1472    elif isinstance(s[0], bitexpr.WhileLoop):
[455]1473        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
[512]1474        code += print_block_stmts(s[0].stmts, indent+indent_unit)
[344]1475        code += " "*indent+"}\n"
[429]1476    elif (isinstance(s[0].RHS, bitexpr.Const)):
1477        pass
1478 
1479    elif isinstance(s[0], bitexpr.BitAssign) and not (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
[344]1480
1481        code += " "*indent + s[0].LHS.varname
1482        code += " = "
1483
1484        if s[0].RHS.data_type == "vector":
1485            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1486                code += s[0].RHS.operand1.varname
1487                code += ";\n"
1488            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1489                code += s[0].RHS.operand1.varname
1490                code += ";\n"
1491            elif isinstance(s[0].RHS, bitexpr.Var):
1492                code += s[0].RHS.operand1.varname
1493                code += ";\n"
1494
1495            else:
1496                code += s[0].RHS.op_C + "("
1497                code += s[0].RHS.operand1.varname
1498                code += ','
1499                code += s[0].RHS.operand2.varname
1500                code += ");\n"
1501        elif s[0].RHS.data_type == "int":
1502            if isinstance(s[0].RHS, bitexpr.Or):
[460]1503                code += "carry_or(%s, %s);\n" % (s[0].RHS.operand1.varname, s[0].RHS.operand2.varname)
[365]1504            elif isinstance(s[0].RHS, bitexpr.FalseLiteral):
[456]1505                code += "Carry0;\n"
[454]1506    elif isinstance(s[0], bitexpr.BitAssign) and (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1507            if isinstance(s[0].RHS, bitexpr.Add) and (s[0].RHS.operand1.varname == s[0].RHS.operand2.varname):
1508                code += " "*indent + s[0].RHS.op_advance + "("
1509                code += s[0].RHS.operand1.varname
1510                code += ', '
1511                code += s[0].RHS.carry
1512                code += ", "
1513                code += s[0].LHS.varname
1514                code += ");\n"
1515            elif isinstance(s[0].RHS, bitexpr.Add):
[373]1516                code += " "*indent + s[0].RHS.op_C + "("
1517                code += s[0].RHS.operand1.varname
1518                code += ', '
1519                code += s[0].RHS.operand2.varname
1520                code += ', '
1521                code += s[0].RHS.carry
1522                code += ", "
1523                code += s[0].LHS.varname
1524                code += ");\n"
1525            elif isinstance(s[0].RHS, bitexpr.Sub):
1526                code += " "*indent + s[0].RHS.op_C + "("
1527                code += s[0].RHS.operand1.varname
1528                code += ', '
1529                code += s[0].RHS.operand2.varname
1530                code += ', '
1531                code += s[0].RHS.brw
1532                code += ", "
1533                code += s[0].LHS.varname
1534                code += ");\n"
[344]1535    s.pop(0)
[512]1536    return code+print_block_stmts(s, indent)
[344]1537
[429]1538
1539def print_stream_stmts(s, indent = 0):
1540    indent_unit = 4
1541    code = ""
1542    if len(s) == 0:
1543        return ""
1544    elif isinstance(s[0], bitexpr.BitAssign):
1545        if (isinstance(s[0].RHS, bitexpr.Const)):
1546            code += " "*indent + s[0].LHS.varname
1547            code += " = "
1548            code += "sisd_from_int(%i);"%s[0].RHS.n
1549            code += "\n"
1550
1551    s.pop(0)
1552    return code+print_stream_stmts(s, indent)
1553
[344]1554#################################################################################################################
1555## Auxiliary Functions
1556#################################################################################################################
1557
1558def get_bb_of(code, index):
1559    if index < 0 or index >= len(code):
1560        return None
[363]1561
[344]1562    if isinstance(code[index], bitexpr.If):
1563        return code[index]
1564    if isinstance(code[index], bitexpr.WhileLoop):
1565        return code[index]
1566    if isinstance(code[index], bitexpr.BitAssign):
1567        last = index
1568        for loc in code[index+1:]:
1569            if isinstance(loc, bitexpr.BitAssign):
1570                last += 1
1571            else:
1572                break
1573 
1574        first = index
1575        for loc in reversed(code[:index]):
1576            if isinstance(loc, bitexpr.BitAssign):
1577                first -= 1
1578            else:
1579                break
1580
1581        return code[first:last+1]
1582
1583def get_previous_bb(code, index):
1584    if index < 0:
1585        return None
1586    if index >= len(code):
1587        return get_bb_of(code, len(code)-1)
1588    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1589        return get_bb_of(code, index-1)
1590    if isinstance(code[index], bitexpr.BitAssign):
1591        for ind, loc in reversed(enumerate(code[:index])):
1592            if not isinstance(loc, bitexpr.BitAssign):
1593                return get_bb_of(code, ind)
1594        return get_bb_of(code, -1)
1595    assert(1==0)
1596
1597def recurse_forward(code):
1598    if isinstance(code[0], bitexpr.If):
1599        return code[0], code[1:]
1600    if isinstance(code[0], bitexpr.WhileLoop):
1601        return code[0], code[1:]
1602    if isinstance(code[0], bitexpr.BitAssign):
1603        last = 0
1604        for loc in code[1:]:
1605            if isinstance(loc, bitexpr.BitAssign):
1606                last += 1
1607            else:
1608                break
1609        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.