source: proto/Compiler/py2bitexpr.py @ 365

Last change on this file since 365 was 365, checked in by eamiri, 9 years ago

Bit stream initialization support added.
Improvements in inlining.

File size: 59.5 KB
Line 
1import ast, bitexpr, copy, basic_block
2
3type_dict = {}
4structs = {}
5arrays = {}
6
7AllOne = 'AllOne'
8AllZero = 'AllZero'
9
10STRUCT_TEMPLATE="strct_%s__%s_"
11ARRAY_TEMPLATE="array_%s__%s_"
12
13
14class PyBitError(Exception):
15        pass
16
17#############################################################################################
18## Extracting list of live variables
19#############################################################################################
20def get_lives(s):
21    for item in s.body:
22        if isinstance(item, ast.FunctionDef):
23            if (item.name =="main"):
24                assert (isinstance(item.body[-1], ast.Return))
25                ret = item.body[-1].value
26                if isinstance(ret, ast.Tuple):
27                    del item.body[-1]
28                    return [translate_var(x) for x in ret.elts],s
29                else:
30                    assert(isinstance(ret, ast.Name))
31                    del item.body[-1]
32                    return [translate_var(ret)],s
33
34#############################################################################################
35## Function Inlining
36#############################################################################################
37TEMP_VAR_TEMPLATE = "InlineTemp%i"
38
39def get_var_name(var):
40    var_name = translate_var(var)
41    if isinstance(var, ast.Name):
42        type_name = "simple"
43        extra = None
44    if isinstance(var, ast.Attribute):
45        type_name = "structure"
46        extra = var.attr
47    if isinstance(var, ast.Index):
48        type_name = "array"
49        extra = translate_index(v.slice.value)
50
51    return type_name, var_name, extra
52
53def register_var(final_name, original_name, type_name, extra):
54    pass
55
56def collect_functions(module):
57    imps = []
58    to_remove = []
59    for index, item in enumerate(module.body):
60        if isinstance(item, ast.Import):
61            imps.append(item)
62            to_remove.append(index)
63
64    for n in reversed(to_remove):
65        del module.body[n]
66
67    for imp in imps:
68        for mod in imp.names:
69            mod_file = open(mod.name + '.py')
70            mod_code = ""
71            for line in mod_file:
72                mod_code += line
73            mod_file.close()
74            mod_parsed = ast.parse(mod_code)
75            modified = []
76            for item in mod_parsed.body:
77                if isinstance(item, ast.FunctionDef):
78                    item.name = translate_var(ast.Attribute(value=ast.Name(id=mod.name), attr=item.name))
79            module.body += mod_parsed.body
80    return module
81
82
83def remove_unsupported_code(callee, fn_names, predefined, main=False):
84    to_remove = []
85   
86    for index, loc in enumerate(callee.body):
87        if not (isinstance(loc, ast.Assign) or isinstance(loc, ast.AugAssign) or isinstance(loc, ast.Return)):
88            if isinstance(loc, ast.While):
89                if not main:
90                    to_remove.append(index)
91
92            else:
93                to_remove.append(index)
94       
95        #we do not support AugAssign with function calls
96        if isinstance(loc, ast.AugAssign):
97            if isinstance(loc.value, ast.Call):
98                assert (1==0)
99                to_remove.append(index)
100
101        #we do not support function calls in functions other than main
102        #these are OK if they are object construction, otherwise error.
103        if isinstance(loc, ast.Assign):
104            if isinstance(loc.value, ast.Call):
105                if main and not (translate_var(loc.value.func) in fn_names+predefined):
106                    to_remove.append(index)
107                if not main and not translate_var(loc.value.func) in predefined:
108                    to_remove.append(index)
109    to_remove.sort(reverse=True)
110    for line_no in to_remove:
111        del callee.body[line_no]
112    return callee
113
114
115def clean_functions(module, fn_names, predefined):
116    new = []
117    for item in module.body:
118        if isinstance(item, ast.FunctionDef):
119            new.append(remove_unsupported_code(item, fn_names, predefined, item.name == "main")) 
120        else:
121             new.append(item)
122    module.body = new
123    return module
124
125def collect_names(body):
126    names = {}
127    for index, item in enumerate(body):
128        if isinstance(item, ast.FunctionDef):
129            names[item.name] = index
130
131    return names
132
133def modify_name(var, args, cargs, unique_prefix):
134    carg_names = [x.id for x in cargs]
135    arg_names = [translate_var(x) for x in args]
136
137    varname = translate_var(var)
138    if isinstance(var, ast.Name):
139        if varname in carg_names:
140            return ast.Name(id=arg_names[carg_names.index(varname)])
141        return ast.Name(id=unique_prefix+varname)
142    if isinstance(var, ast.Attribute) or isinstance(var, ast.Subscript):
143        if var.value.id in carg_names:
144            var.value.id = arg_names[carg_names.index(var.value.id)]
145            return ast.Name(id=translate_var(var))
146        return ast.Name(id=unique_prefix+translate_var(var))
147
148def replace_in_exp(exp, args, cargs, unique_prefix):
149    if isinstance(exp, ast.BinOp):
150        exp.left = replace_in_exp(exp.left, args, cargs, unique_prefix)
151        exp.right = replace_in_exp(exp.right, args, cargs, unique_prefix)
152    elif isinstance(exp, ast.UnaryOp):
153        exp.operand = replace_in_exp(exp.operand, args, cargs, unique_prefix)
154    elif isinstance(exp, ast.Name):
155        exp = modify_name(exp, args, cargs, unique_prefix)
156    elif isinstance(exp, ast.Attribute):
157        exp = modify_name(exp, args, cargs, unique_prefix)
158    elif isinstance(exp, ast.Subscript):
159        exp = modify_name(exp, args, cargs, unique_prefix)
160    elif isinstance(exp, ast.Call):
161        for index, item in enumerate(exp.args):
162            exp.args[index] = replace_in_exp(item, args, cargs, unique_prefix)
163    elif isinstance(exp, ast.Tuple):
164        for index, elt in enumerate(exp.elts):
165            exp.elts[index] = replace_in_exp(elt, args, cargs, unique_prefix)
166    elif isinstance(exp, ast.Num):
167        pass
168    else:
169        print exp
170        assert(1==0)
171    return exp
172
173def get_all_calls(main, all_func):
174    call_list = []
175    for index, loc in enumerate(main.body):
176        #print loc.value
177        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
178                func_name = translate_var(loc.value.func)
179                if func_name in all_func:
180                    call_list.append(index)
181        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
182                func_name = translate_var(loc.value.func)
183                if func_name in all_func:
184                    call_list.append(index)
185    return call_list
186
187def update_var_names(callee, args):
188    unique_prefix = '_'+callee.name.replace('.', '_')+'_'
189
190    assert (len(args)==len(callee.args.args))
191
192    for index, loc in enumerate(callee.body):
193        if isinstance(loc, ast.Assign):
194            for index, var in enumerate(loc.targets):
195                loc.targets[index] = modify_name(var, args, callee.args.args, unique_prefix)
196            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
197        elif isinstance(loc, ast.AugAssign):
198            loc.target = modify_name(loc.target, args, callee.args.args, unique_prefix)
199            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
200        elif isinstance(loc, ast.Return):
201            pass
202            #loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
203        else:
204            assert (1==0)
205
206    return callee, unique_prefix
207
208
209def get_name(loc, index):
210    if isinstance(loc.value, ast.Tuple):
211        return translate_var(loc.value.elts[index])
212    elif index == 0:
213        return translate_var(loc.value)
214    else:
215        assert(1==0)
216
217def modify_assignments_in_caller(main, line_no, return_list, prefix, callee):
218    #print "*********", prefix, "************"
219    if isinstance(main.body[line_no].targets[0], ast.Tuple):
220        all_vars = main.body[line_no].targets[0].elts
221    else:
222        all_vars = main.body[line_no].targets
223    new_code = []
224
225    if isinstance(callee.body[-1].value, ast.Tuple):
226        return_vars = [translate_var(x) for x in callee.body[-1].value.elts]
227    else:
228        return_vars = [translate_var(callee.body[-1].value)]
229
230    for index, item in enumerate(all_vars):
231        actual_name = translate_var(item)
232        if actual_name in [translate_var(x) for x in main.body[line_no].value.args]:
233            continue
234        temp_name = return_vars[index]
235
236        if temp_name in [translate_var(x) for x in callee.args.args]:
237            second_index = [translate_var(x) for x in callee.args.args].index(temp_name)
238            final_name = translate_var(main.body[line_no].value.args[second_index])
239            #print "1111"
240        else:
241            final_name = prefix+temp_name
242            #print "2222", final_name
243
244        if return_list[temp_name][0] == "struct":
245            my_template = STRUCT_TEMPLATE
246        elif return_list[temp_name][0] == "array":
247            my_template = ARRAY_TEMPLATE
248 
249        if return_list[temp_name][0] == "bitblock":
250            new_code.append(ast.Assign( [ast.Name(id=item.id)] , ast.Name(id=final_name) ))
251        else:
252            for elem in return_list[temp_name][1]:
253                #print final_name, elem
254                new_code.append(ast.Assign([ast.Name(id=my_template%(item.id, elem))], ast.Name(id=prefix+my_template%(temp_name, elem))))
255    main.body = main.body[:line_no]+new_code+main.body[line_no+1:]
256    return main
257
258def get_return_list(callee):
259    ret_list = []
260    loc = callee.body[-1]
261    if isinstance(loc, ast.Return):
262        if isinstance(loc.value, ast.Tuple):
263            all_rets = loc.value.elts
264        else:
265            all_rets = [loc.value]
266        for item in all_rets:
267            if isinstance(item, ast.Attribute) or isinstance(item, ast.Subscript):
268                ret_list.append(item.value.id)
269            if isinstance(item, ast.Name):
270                ret_list.append(item.id)
271    return ret_list
272
273def get_updated(loc):
274    if isinstance(loc, ast.AugAssign):
275        updated = [loc.target]
276    elif isinstance(loc, ast.Assign):
277        if isinstance(loc.targets[0], ast.Tuple):
278            updated = loc.targets[0].elts
279        else:
280            updated = loc.targets
281    else:
282        updated = []
283    return updated
284
285def get_update_details(callee, ret_list):
286    field_info = {}
287    for item in ret_list:
288        field_info[item] = ['simple', set([])]
289    for loc in callee.body:
290        updated = get_updated(loc)
291        for item in updated:
292            type, name, extra = parse_var(translate_var(item, True))
293            if name in field_info:
294                field_info[name][0] = type
295                if type=='struct':
296                    field_info[name][1].add(extra)
297                if type == 'array':
298                    field_info[name][1].add(extra)
299    return field_info
300
301def process_return(callee):
302    if not isinstance(callee.body[-1], ast.Return):
303        return {}
304    ret_list = get_return_list(callee)
305    field_info = get_update_details(callee, ret_list)
306    return field_info
307
308
309def generate_expanded_code(loc):
310    new_code = []
311    array_name = translate_var(loc.targets[0])
312    for ind, i in enumerate(range(len(loc.value.elts))):
313        new_loc = ast.Assign([ast.Subscript(value=ast.Name(id=array_name), slice=ast.Index(ast.Num(n=ind)) )], loc.value.elts[i])
314        new_code.append(new_loc)
315
316    return new_code
317
318def expand_list_init(module):
319    for item in module.body:
320        replace_code = {}
321        if isinstance(item, ast.FunctionDef):
322            for index, loc in enumerate(item.body):
323                if isinstance(loc, ast.Assign):
324                    if isinstance(loc.value, ast.List):
325                        equivalent_code = generate_expanded_code(loc)
326                        replace_code[index] = equivalent_code
327            keys = [key for key in replace_code]
328            keys.sort(reverse = True)
329            for line_no in keys:
330                item.body = item.body[:line_no]+replace_code[line_no]+item.body[line_no+1:]
331   
332    return module
333
334def get_predefined_funcs():
335    advance = ast.Attribute(value=ast.Name(id="bitutil"), attr="Advance")
336    scan_thru = ast.Attribute(value=ast.Name(id="bitutil"), attr="ScanThru")
337    return [translate_var(advance), translate_var(scan_thru)]
338
339def do_inlining(module):
340    module = collect_functions(module)
341    func_dict = collect_names(module.body)
342    module = expand_list_init(module)
343   
344   
345    valid_fn = [name for name in func_dict] 
346    predef_fn = get_predefined_funcs()
347    module = clean_functions(module, valid_fn, predef_fn)
348    main = module.body[func_dict["main"]]
349    call_list = get_all_calls(main, [func for func in func_dict])
350    call_list.sort(reverse=True)
351    for line_no in call_list:
352        callee = copy.deepcopy(module.body[func_dict[translate_var(main.body[line_no].value.func)]])
353        args = main.body[line_no].value.args
354        return_list = process_return(callee)
355        callee, prefix = update_var_names(callee, args)
356
357        main = modify_assignments_in_caller(main, line_no, return_list, prefix, callee)
358
359        if isinstance(callee.body[-1], ast.Return):
360            del callee.body[-1]
361        main.body = main.body[:line_no]+callee.body+main.body[line_no:]
362    return module.body[func_dict["main"]].body
363
364#############################################################################################
365## Translation to bitExpr classes --- Pass 2
366#############################################################################################
367
368def translate_index(ast_value):
369        if isinstance(ast_value, ast.Str): return ast_value.s
370        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
371        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
372
373def is_Advance_Call(fncall):
374        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
375        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
376                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
377        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
378
379def is_ScanThru_Call(fncall):
380        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
381        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
382                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
383        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
384
385def translate(ast_expr):
386        if isinstance(ast_expr, ast.Name):
387                if ast_expr.id=="allzero":
388                        return bitexpr.FalseLiteral()
389                if ast_expr.id=="allone":
390                        return bitexpr.TrueLiteral()
391                return bitexpr.Var(translate_var(ast_expr))
392        elif isinstance(ast_expr, ast.Num):
393                if ast_expr.n == 0:
394                    return bitexpr.FalseLiteral()
395                else:
396                    return bitexpr.Const(ast_expr.n)
397        elif isinstance(ast_expr, ast.Attribute):
398                if isinstance(ast_expr.value, ast.Name):
399                        return bitexpr.Var(translate_var(ast_expr))
400                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
401                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
402        elif isinstance(ast_expr, ast.Subscript):
403                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice, ast.Index):
404                        return bitexpr.Var(translate_var(ast_expr))
405                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
406                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
407        elif isinstance(ast_expr, ast.UnaryOp):
408                e1 = translate(ast_expr.operand)
409                if isinstance(ast_expr.op, ast.Invert):
410                        return bitexpr.make_not(e1)
411                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
412        elif isinstance(ast_expr, ast.BinOp):
413                e1 = translate(ast_expr.left)
414                e2 = translate(ast_expr.right)
415                if isinstance(ast_expr.op, ast.BitOr):
416                        return bitexpr.make_or(e1, e2)
417                elif isinstance(ast_expr.op, ast.BitAnd):
418                        return bitexpr.make_and(e1, e2)
419                elif isinstance(ast_expr.op, ast.BitXor):
420                        return bitexpr.make_xor(e1, e2)
421                elif isinstance(ast_expr.op, ast.Add):
422                        return bitexpr.make_add(e1, e2)
423                elif isinstance(ast_expr.op, ast.Sub):
424                        return bitexpr.make_sub(e1, e2)
425                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
426        elif isinstance(ast_expr, ast.Call):
427                if is_Advance_Call(ast_expr):
428                        e = translate(ast_expr.args[0])
429                        temp = bitexpr.make_add(e,e)
430                        return temp
431                elif is_ScanThru_Call(ast_expr):
432                        e0 = translate(ast_expr.args[0])
433                        e1 = translate(ast_expr.args[1])
434                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
435                else: 
436                    raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
437        elif isinstance(ast_expr, ast.Compare):
438                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
439                        e0 = translate(ast_expr.left)
440                        return bitexpr.isNoneZero(e0)
441                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
442               
443        else: 
444            raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
445
446def translate_var(v, aggregate_type = False):
447        if isinstance(v, ast.Name):
448            agg_name = v.id
449            var_name = v.id
450        elif isinstance(v, ast.Attribute):
451           if isinstance(v.value, ast.Name):
452               var_name = STRUCT_TEMPLATE%(v.value.id, v.attr)
453               agg_name = "%s.%s" % (v.value.id, v.attr)
454           else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
455        elif isinstance(v, ast.Subscript):
456           if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
457               var_name = ARRAY_TEMPLATE%(v.value.id, translate_index(v.slice.value))
458               agg_name = "%s[%s]" % (v.value.id, translate_index(v.slice.value))
459               
460        else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
461
462        if aggregate_type:
463            return agg_name
464        else:
465            return var_name
466
467def translate_stmts(ast_stmts):
468        translated = []
469        for s in ast_stmts:
470                if isinstance(s, ast.Expr):
471                        if (s.value.func.id=='optimize'):
472                            target = translate_var(s.value.args[0])
473                            replace = translate_var(s.value.args[1])
474                            translated.append(bitexpr.Reduce(target, replace))
475                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
476                elif isinstance(s, ast.Assign):
477                        e = translate(s.value)
478                        for astv in s.targets: 
479                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
480                elif isinstance(s, ast.AugAssign):
481                        v = bitexpr.Var(translate_var(s.target))
482                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
483                elif isinstance(s, ast.While):
484                        e = translate(s.test)
485                        body = translate_stmts(s.body)
486                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body))
487                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
488        return translated
489#############################################################################################
490## Conversion to SSA form --- Pass 3
491#############################################################################################
492
493def extract_vars(rhs):
494    if isinstance(rhs, bitexpr.Var):
495        return [rhs.varname]
496    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral) or isinstance(rhs, bitexpr.Const):
497        return []
498    if isinstance(rhs, bitexpr.Not):
499        return extract_vars(rhs.operand1)
500    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
501    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
502
503def update_lineno(inner_table, shift):
504    for key in inner_table:
505        for index in range(len(inner_table[key][0])):
506            inner_table[key][0][index] += shift
507        for index in range(len(inner_table[key][1])):
508            inner_table[key][1][index] += shift
509    return inner_table
510
511def merge_tables(table, inner_table):
512    for key in inner_table:
513        if key in table:
514            table[key][0] += inner_table[key][0]
515            table[key][1] += inner_table[key][1]
516        else:
517            table[key] = inner_table[key]
518    return table
519
520def gen_sym_table(code, goInside = False):
521    """Generate a simple symbol table for a three address code
522        each entry is of this form: var_name:[[defs][uses]]
523    """
524    table = {}
525    index = 0
526    for stmt in code:
527        if isinstance(stmt, bitexpr.Reduce):
528            if stmt.target in table:
529                table[stmt.target][1].append(index)
530            else:
531                table[stmt.target] = [[], [index]]
532            index += 1
533        elif isinstance(stmt, bitexpr.BitAssign):
534            current = stmt.LHS.varname
535            if current in table:
536                table[current][0].append(index)
537            else:
538                table[current] = [[index],[]]
539
540            varlist = extract_vars(stmt.RHS)
541            for var in varlist:
542                if var in table:
543                    table[var][1].append(index)
544                else:
545                    table[var] = [[], [index]]
546            index += 1
547        elif isinstance(stmt, bitexpr.WhileLoop):
548            cond_var = stmt.control_expr.var.varname
549            if cond_var in table:
550                table[cond_var][1].append(index)
551            else:
552                #while loop conditioned on an undefined variable
553                assert(1==0)
554            if goInside:
555                inner_table = gen_sym_table(stmt.stmts, True)
556                inner_table = update_lineno(inner_table, index+1)
557                table = merge_tables(table, inner_table)
558                index += (1+len(stmt.stmts))
559            else:
560                index += 1
561        else:
562            assert(1==0)
563    return table
564##################################################################################
565def get_line(code, line):
566    lineno = 0
567    for stmt in code:
568        if isinstance(stmt, bitexpr.BitAssign):
569            if lineno == line:
570                return stmt
571            lineno += 1
572        elif isinstance(stmt, bitexpr.WhileLoop):
573            if lineno == line:
574                return stmt
575            elif lineno+len(stmt.stmts) < line:
576                lineno += 1
577                line -= len(stmt.stmts)
578                continue
579            else:
580                return get_line(stmt.stmts, line-(lineno+1))
581
582        elif isinstance(stmt, bitexpr.Reduce):
583            if lineno == line:
584                return stmt
585            lineno += 1
586        else:
587            assert(1==0)
588
589def update_def(code, var, line, suffix_num):
590    #for i in code:
591    #    print i
592    #print "------------------------", len(code), line, var
593    loc = code[line]
594    if isinstance(loc, bitexpr.BitAssign):
595        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
596        return loc
597    else:
598        #either Reduce or While Loop in both cases it's a use
599        pass
600
601def update_rhs(rhs, varname, newname):
602    if isinstance(rhs, bitexpr.Var):
603        if rhs.varname == varname:
604            rhs.varname = newname
605        return rhs
606
607    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
608        return rhs
609
610    if isinstance(rhs, bitexpr.Not):
611        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
612        return rhs
613    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
614    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
615    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
616    return rhs
617
618def update_use(code, var, line, suffix_num):
619    loc = code[line]
620    if isinstance(loc, bitexpr.BitAssign):
621        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
622    elif isinstance(loc, bitexpr.Reduce):
623        pass
624    elif isinstance(loc, bitexpr.WhileLoop):
625        loc.control_expr.var.varname += "_%i"%suffix_num
626    return loc
627
628def parse_var(var):
629    index=var.find('.')
630    if index >= 0:
631        return ('struct', var[0:index], var[index+1:])
632               
633    index = var.find('[')
634    if index >= 0:
635        right_index = var.find(']')
636        return ('array', var[0:index], var[index+1:right_index])       
637   
638    if var.startswith("carry") or var.startswith("Carry"):
639        return('int', var, None)
640   
641    return ('bitblock', var, None)
642
643def simplify_name(var):
644    (vartype, name, extra) = parse_var(var)
645    if vartype == 'bitblock':
646        return name
647    if vartype == 'array':
648        return "a_%s_%s"%(name, extra)
649    if vartype == 'struct':
650        return "s_%s_%s"%(name, extra)
651    if vartype == 'int':
652        assert(1==0)
653    assert(1==0)
654
655def pairs(lst):
656    if lst == []:
657            return []
658    return zip(lst,lst[1:]+[lst[0]])
659
660
661def make_SSA(code, st):
662    new_vars = []
663    total_lines = len(code)
664    for var in st:
665        st[var][0].append(total_lines)
666        st[var][1].append(total_lines)
667
668    unique_number = 0
669    for var in st:
670        use_index = 0
671        for current, next in pairs(st[var][0])[0:-2]:
672            code[current] = update_def(code, var, current, unique_number)
673            uline = st[var][1][use_index]
674
675            while uline <= next and uline < total_lines:
676                if uline > current:
677                    code[uline] = update_use(code, var, uline, unique_number)
678                use_index += 1
679                uline = st[var][1][use_index]
680            unique_number += 1
681    return code
682
683#################################################################################################################
684## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
685#################################################################################################################
686def get_opt_list(s):
687    opt_list = []
688    remove_index = []
689    for line, stmt in enumerate(s):
690        if isinstance(stmt, bitexpr.Reduce):
691            opt_pair = (stmt.target, stmt.replace)
692            opt_list.append(opt_pair)
693            remove_index.append(line)
694    for ind in reversed(remove_index):
695        del s[ind]
696
697    return opt_list
698
699def break_to_segs(s, breaks):
700    segs = []
701    breaks = [-1]+breaks+[len(s)]
702    for start, stop in pairs(breaks)[:-1]:
703        #print start, stop
704        next = s[start+1:stop]
705        if len(next)>0:
706            segs.append(next)
707    return segs
708
709def extract_pragmas(s):
710    targets = []
711    exprs = []
712    lines = []
713
714    #extracting pragmas and their targets
715    for line, loc in enumerate(s):
716        if isinstance(loc, bitexpr.Reduce):
717            targets.append(loc.target)
718            exprs.append(loc)
719            lines.append(line)
720
721    #removing pragmas from the source code
722    for i in reversed(lines):
723        del s[i]
724
725    return targets, exprs
726
727def get_defs(targets, s):
728    temp = [(x,-1) for x in targets]
729    targ_dic = dict(temp)
730    for line, loc in enumerate(s):
731        if loc.LHS.varname in targets:
732            targ_dic[loc.LHS.varname] = line
733
734    items = targ_dic.items()
735    items = sorted(items, key=(lambda x: x[1]))
736
737    lineno = [key for value, key in items]
738    sorted_targets = [value for value, key in items]
739
740    return sorted_targets, lineno
741
742def get_boundary(def_line, lastuse_line):
743    earliest = min(def_line)
744
745    # The indices of all variables defined in the same line of code and earlier than all other variables
746    # There is more than one such variable only if these variables are input variables
747    all_early = [] 
748    all_early_cnt = 0
749    j = 0
750
751    while earliest in def_line[j:]:
752        m = def_line[j:].index(earliest)
753        all_early_cnt += 1
754        all_early.append(m+j)
755        j += m+1
756
757    #all_early contains the indices of the all earliest defined vars
758    lasts = []
759    map(lambda y: lasts.append(lastuse_line[y]), all_early)
760
761    latest = max(lasts)
762    return earliest, latest
763
764def adjust_latest(s, latest):
765    line = 0
766    for stmt in s:
767        if isinstance(stmt, bitexpr.BitAssign):
768            if latest == line:
769                return latest
770            line += 1
771        if isinstance(stmt, bitexpr.WhileLoop):
772            end_of_loop = line+len(stmt.stmts)
773            if latest <= end_of_loop:
774                return end_of_loop
775            else:
776                line = end_of_loop+1
777    return line
778
779#There is slight difference between the function below and adjust_latest (look at the second return)
780def adjust_earliest(s, earliest):
781    line = 0
782    if earliest == -1:
783        return -1
784
785    for stmt in s:
786        if isinstance(stmt, bitexpr.BitAssign):
787            if earliest == line:
788                return earliest
789            line += 1
790        if isinstance(stmt, bitexpr.WhileLoop):
791            end_of_loop = line+len(stmt.stmts)
792            if earliest <= end_of_loop:
793                return end_of_loop+1
794            else:
795                line = end_of_loop+1
796
797def chop_code(s, start, stop):
798    line = 0
799    if start == -1:
800        cut1 = -1
801   
802    if stop == len(s):
803        cut2 = len(s)
804   
805    for index, stmt in enumerate(s):
806        if isinstance(stmt, bitexpr.BitAssign):
807            if line == start:
808                cut1 = index
809            if line == stop:
810                cut2 = index
811            line += 1
812
813        if isinstance(stmt, bitexpr.WhileLoop):
814            end_of_loop = line+len(stmt.stmts)
815            if line == start:
816                cut1 = index
817            if end_of_loop == stop:
818                cut2 = index
819            line = end_of_loop+1
820
821    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
822
823def count_lines(code):
824    cnt = len(code)
825    for stmt in code:
826        if isinstance(stmt, bitexpr.WhileLoop):
827            cnt += count_lines(stmt.stmts)
828    return cnt
829
830
831def gen_bb(s, opt_list, def_line, lastuse_line):
832    assert (len(opt_list) == len(def_line) == len(lastuse_line))
833    if opt_list==[]:
834        return s
835
836    earliest, latest = get_boundary(def_line, lastuse_line)
837
838    indices = []
839    for index, i in enumerate(def_line): #was enumerate(lastuse_line)
840        if i <= latest:
841            indices.append(index)
842    #if latest or earliest are inside a while loop they are pushed to the end of while loop or beginning of the next block (respectively)
843
844    e = earliest
845    l = latest
846
847    opt1 = {}
848    def1 = {}
849    use1 = {}
850    for index, i in enumerate(opt_list):
851        opt1.setdefault(index in indices, []).append(i)
852    for index, i in enumerate(def_line):
853        def1.setdefault(index in indices, []).append(i)
854    for index, i in enumerate(lastuse_line):
855        use1.setdefault(index in indices, []).append(i)
856
857    latest = max(use1[True])
858    latest = adjust_latest(s, latest)
859    earliest = adjust_earliest(s, earliest)
860
861    #use earliest to construct an initial block with no if-then-else
862    #use use1[True], def1[True], opt1[True] to construct if then else and recurse on inner blocks
863    #use use1[False], def1[False], opt1[False] to construct the block after if-then-else and recurse on that block
864    first, second, third = chop_code(s, earliest, latest)
865    the_opt = None
866    the_index = None
867    for index, value in enumerate(opt1[True]):
868        if e==def1[True][index] and l==use1[True][index]:
869            the_opt = value
870            the_index = index
871    del opt1[True][the_index]
872    del use1[True][the_index]
873    del def1[True][the_index]
874
875    second = gen_bb(second, opt1[True], def1[True],use1[True])
876
877    if the_opt[1] == 'allzero':
878        cond_obj = bitexpr.isAllZero(the_opt[0])
879    if the_opt[1] == 'allone':
880        cond_obj = bitexpr.isAllOne(the_opt[0])
881
882    change = count_lines(first) + count_lines(second)
883    new_def = []
884    new_use = []
885
886    if (False in opt1):
887        for i in def1[False]:
888            new_def.append(max(i-change, -1))
889        for i in use1[False]:
890            new_use.append(i-change)
891
892        third = gen_bb(third, opt1[False], new_def, new_use)
893
894    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))]+third
895
896    return result
897
898def partition2bb(s):
899    basicblocks = []
900    lineno = []
901    def_line = []
902    lastuse_line = []
903    opt_list = get_opt_list(s)
904
905    st = gen_sym_table(s, True)
906    st2 = gen_sym_table(s)
907    total_lines = count_lines(s)
908
909    for item in opt_list:
910        if len(st2[item[0]][0]) > 0:
911            def_line.append(st2[item[0]][0][-1])
912        else:
913            def_line.append(-1)
914
915        lastuse_line.append(total_lines)
916
917    return gen_bb(s, opt_list, def_line, lastuse_line)
918
919#################################################################################################################
920## Generating declarations for variables. This pass is based on the syntax of C programming language ---
921## Normalizing the code
922#################################################################################################################
923
924def normalize(s, predec = {}, ccelim=True):
925    if len(s)==0:
926        return []
927    if isinstance(s[0], bitexpr.If):
928        gc = basic_block.BasicBlock.gensym_counter
929        cc = basic_block.BasicBlock.carry_counter
930        bc = basic_block.BasicBlock.brw_counter
931        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
932        maxgc = basic_block.BasicBlock.gensym_counter
933        maxcc = basic_block.BasicBlock.carry_counter
934        maxbc = basic_block.BasicBlock.brw_counter
935        basic_block.BasicBlock.gensym_counter = gc
936        basic_block.BasicBlock.carry_counter = cc
937        basic_block.BasicBlock.brw_counter = bc
938        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
939        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
940        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
941        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
942       
943        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
944    if isinstance(s[0], bitexpr.WhileLoop):
945        orig = copy.deepcopy(predec)
946        s[0].stmts = normalize(s[0].stmts, predec, False)
947        return [s[0]]+normalize(s[1:], {})
948    if isinstance(s[0], bitexpr.BitAssign):
949        code, next = recurse_forward(s)
950        bb = basic_block.BasicBlock(predec)
951        predec.update(bb.normalize(code, ccelim))
952        return bb.get_code() + normalize(next, copy.deepcopy(predec))
953
954#################################################################################################################
955## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
956#################################################################################################################
957def get_vars(stmt):
958        ints = set([])
959        bitblocks = set(['AllOne', 'AllZero'])
960        arrays = {}
961        structs = {}
962
963        #for stmt in code:
964
965        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
966       
967
968
969        if isinstance(stmt.RHS, bitexpr.Add): 
970                ints.add(stmt.RHS.carry)
971        if isinstance(stmt.RHS, bitexpr.Sub):
972                ints.add(stmt.RHS.brw)
973
974        for var in all_vars:
975                if isinstance(var, bitexpr.Const):
976                    continue
977               
978                (var_type, name, extra) = parse_var(var.varname)
979
980                if (var_type == "bitblock"):
981                        bitblocks.add(name)
982
983                if var_type == "array":
984                        if not name in arrays:
985                                arrays[name]= extra
986                        else:
987                                arrays[name] = max(arrays[name], extra)
988
989                if var_type == "struct":
990                        if not name in structs:
991                                structs[name] = set([extra])
992                        else:
993                                structs[name].add(extra)
994                if var_type == "int":
995                    ints.add(name)
996        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
997
998def merge_var_dic(first, second):
999    res = {}
1000    #print first
1001    #print second
1002    res['int'] = first['int'].union(second['int'])
1003    res['bitblock'] = first['bitblock'].union(second['bitblock'])
1004
1005    res['array'] = first['array']
1006    temp = dict({'array':copy.copy(second['array'])})
1007    for i in res['array']:
1008        if i in temp['array']:
1009            res['array'][i] = max(res['array'][i], temp['array'][i])
1010            del temp['array'][i]
1011    res['array'].update(temp['array'])
1012
1013
1014    res['struct'] = first['struct']
1015    temp = dict({'struct':copy.copy(second['struct'])})
1016    for i in res['struct']:
1017        if i in temp['struct']:
1018            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
1019            del temp['struct'][i]
1020    res['struct'].update(temp['struct'])
1021    return res
1022
1023def gen_output(var_dic):
1024        s = ''
1025        for i in  var_dic['int']:
1026                s+="int %s=0;\n"%i
1027
1028        for i in var_dic['bitblock']:
1029                if i == AllOne:
1030                        s += "BitBlock %s = simd_const_1(1);\n"%i
1031                elif i== AllZero:
1032                        s+="BitBlock %s = simd_const_1(0);\n"%i
1033                else:
1034                        s+="BitBlock %s;\n"%i
1035
1036        for i in var_dic['array']:
1037                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
1038
1039        for i in var_dic['struct']:
1040                s+="struct __%s__{\n"%i
1041                for j in var_dic['struct'][i]:
1042                        s+= "\tBitBlock %s;\n"%j
1043                s+="};\n"
1044                s+="struct __%s__ %s;\n"%(i,i)
1045
1046        return s
1047
1048def gen_var_dic(s):
1049    if len(s)==0:
1050        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
1051    if isinstance(s[0], bitexpr.BitAssign):
1052        vd = get_vars(s[0])
1053        more = gen_var_dic(s[1:])
1054        vd = merge_var_dic(vd, more)
1055        return vd
1056
1057    if isinstance(s[0], bitexpr.If):
1058        vd1 = gen_var_dic(s[0].true_branch)
1059        vd2 = gen_var_dic(s[0].false_branch)
1060        vd = merge_var_dic(vd1, vd2)
1061        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1062        vd =  merge_var_dic(vd, vd3)
1063        more = gen_var_dic(s[1:])
1064        vd  = merge_var_dic(vd, more)
1065        return vd
1066
1067    if isinstance(s[0], bitexpr.WhileLoop):
1068        vd = gen_var_dic(s[0].stmts)
1069        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1070        vd =  merge_var_dic(vd, vd1)
1071        more = gen_var_dic(s[1:])
1072        vd  = merge_var_dic(vd, more)
1073        return vd
1074
1075    """for loc in s[1:]:
1076        more = get_vars(loc)
1077
1078    vd = get_vars(bb[0].code)
1079    for block in bb[1:]:
1080        more = get_vars(block.code)
1081        vd = merge_var_dic(vd, more)
1082    declarations = gen_output(vd)
1083    return declarations
1084    """
1085
1086def gen_declarations(s):
1087    vd = gen_var_dic(s)
1088    return gen_output(vd)
1089#################################################################################################################
1090## This class replaces all occurences of a reduced variable to its value --- Pass 7
1091## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
1092## *** the new AST notation                                                                                   ***
1093#################################################################################################################
1094def replace_in_rhs(rhs, target, replace):
1095    if isinstance(rhs, bitexpr.Var):
1096        if rhs.varname == target:
1097            rhs = replace
1098            return replace
1099
1100    if isinstance(rhs.operand1, bitexpr.Var):
1101        if rhs.operand1.varname == target:
1102            rhs.operand1 = replace
1103    else:
1104        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
1105            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
1106
1107    if isinstance(rhs.operand2, bitexpr.Var):
1108        if rhs.operand2.varname == target:
1109            rhs.operand2 = replace
1110    else:
1111        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
1112            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
1113    return rhs
1114
1115def apply_single_opt(code, target, replace):
1116    if len(code) == 0:
1117        return []
1118
1119    if replace=='AllZero':
1120        replace = bitexpr.FalseLiteral()
1121    if replace == 'AllOne':
1122        replace = bitexpr.TrueLiteral()
1123
1124    if isinstance(code[0], bitexpr.BitAssign):
1125        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
1126
1127    if isinstance(code[0], bitexpr.If):
1128        if code[0].control_expr.var.varname == target:
1129            code[0].control_expr.var = replace
1130        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
1131        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
1132
1133    if isinstance(code[0], bitexpr.WhileLoop):
1134        if code[0].control_expr.var.varname == target:
1135            code[0].control_expr.var = replace
1136        code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
1137
1138    return [code[0]]+apply_single_opt(code[1:], target, replace)
1139
1140def apply_all_opt(s):
1141    if len(s) == 0:
1142        return []
1143    elif isinstance(s[0], bitexpr.If):
1144        target = s[0].control_expr.var.varname
1145        replace = s[0].control_expr.val
1146        apply_single_opt(s[0].true_branch, target, replace)
1147        apply_all_opt(s[0].true_branch)
1148        apply_all_opt(s[0].false_branch)
1149
1150    elif isinstance(s[0], bitexpr.WhileLoop):
1151        apply_all_opt(s[0].stmts)
1152
1153    return [s[0]]+apply_all_opt(s[1:])
1154 
1155#################################################################################################################
1156## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
1157## method prune is supposed to remove unreachable branches but it is incomplete
1158#################################################################################################################
1159
1160def prune(fixed, tree):
1161    """removes unreachable branches of the tree"""
1162    target = tree[0].target
1163    replace = tree[0].replace
1164    if target in fixed:
1165        if replace == 'allone':
1166            if isinstance(fixed[target], bitexpr.TrueLiteral):
1167                tree[0] = None
1168                tree[1].join(tree[2][1])
1169                del tree[3]
1170                del tree[2]
1171                fixed.update(tree[1].simplify(fixed))
1172            if isinstance(fixed[target], bitexpr.FalseLiteral):
1173                tree[0] = None
1174                tree[1].join(tree[3][1])
1175                del tree[3]
1176                del tree[2]
1177                fixed.update(tree[1].simplify(fixed))
1178        if replace == 'allzero':
1179            if isinstance(fixed[target], bitexpr.FalseLiteral):
1180                tree[0] = None
1181                tree[1].join(tree[2][1])
1182                del tree[3]
1183                del tree[2]
1184                fixed.update(tree[1].simplify(fixed))
1185            if isinstance(fixed[target], bitexpr.TrueLiteral):
1186                tree[0] = None
1187                tree[1].join(tree[3][1])
1188                del tree[3]
1189                del tree[2]
1190                fixed.update(tree[1].simplify(fixed))
1191
1192    empty_list = []
1193    for index, branch in enumerate(tree[2:]):
1194        if branch[0] is None:
1195            if branch[1].code == []:
1196                empty_list.append(2+index)
1197
1198    empty_list.reverse()
1199    for i in empty_list:
1200        del tree[i]
1201def filter_fixed(fixed, stmts):
1202    for loc in stmts:
1203        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
1204            del fixed[loc.LHS.varname]
1205        if isinstance(loc, bitexpr.If):
1206            fixed = filter_fixed(fixed, loc.true_branch)
1207            fixed = filter_fixed(fixed, loc.false_branch)
1208        if isinstance(loc, bitexpr.WhileLoop):
1209            fixed = filter_fixed(fixed, loc.stmts)
1210    return fixed
1211           
1212def simplify_tree(code, fixed = {}, prev = []):
1213    if len(code) == 0:
1214        return []
1215
1216    assumptions = {}
1217
1218    if isinstance(code[0], bitexpr.BitAssign):
1219        this, next = recurse_forward(code)
1220        fixed.update(basic_block.simplify(this, fixed))
1221        return this+simplify_tree(next, fixed, this)
1222
1223    elif isinstance(code[0], bitexpr.If):
1224        fixed1 = copy.deepcopy(fixed)
1225        fixed2 = copy.deepcopy(fixed)
1226        if isinstance(code[0].control_expr, bitexpr.isAllOne):
1227            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
1228        if isinstance(code[0].control_expr, bitexpr.isAllZero):
1229            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
1230        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
1231        fixed1.update(assumptions)
1232        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
1233        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
1234        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1235
1236    elif isinstance(code[0], bitexpr.WhileLoop):
1237        fixed = filter_fixed(fixed, code[0].stmts)
1238        fixed1 = copy.deepcopy(fixed)
1239        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
1240        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1241
1242#################################################################################################################
1243## Dead Code Elimination ---- Pass 9
1244## *** Changes required here is the same as Pass 7
1245#################################################################################################################
1246def get_effective_name(varname):
1247    if varname.startswith("array") or varname.startswith("strct"):
1248        try:
1249            ind = varname.index("__")
1250           
1251            return varname[6:ind]
1252           
1253        except:
1254            return varname
1255   
1256    return varname
1257   
1258   
1259def check_loc(loc, must_liv):
1260   
1261    effective_name = get_effective_name(loc.LHS.varname)
1262    #if effective_name == "lex":
1263           
1264    if effective_name in must_liv or loc.LHS.varname in must_liv:
1265        if isinstance(loc.RHS, bitexpr.Not):
1266            return set([loc.RHS.operand1.varname]), False
1267        elif isinstance(loc.RHS, bitexpr.Var):
1268            return set([loc.RHS.varname]), False
1269        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1270            return set([]), False
1271        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1272            return set([]), False
1273        elif isinstance(loc.RHS, bitexpr.Const):
1274            return set([]), False
1275        else:
1276            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1277    else:
1278        return set([]), True
1279
1280def remove_copies(bb):
1281    """removes all copy statements e.g. var1 = var2"""
1282    lhs = [x.LHS.varname for x in bb]
1283    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
1284    for i in rhs:
1285        if i[1] in lhs:
1286            line = lhs.index(i[1])
1287            if i[0] > line:
1288                bb[i[0]].RHS = bb[line].RHS
1289    return bb
1290
1291def remove_dead(bb, must_live):
1292    #eliminates dead code from a basic block
1293    bb = remove_copies(bb)
1294    my_alives = set([])
1295    dead = []
1296
1297    for line, loc in reversed(list(enumerate(bb))):
1298        #print line, my_alives.union(must_live)
1299        new_lives, removable = check_loc(loc, my_alives.union(must_live))
1300
1301        if removable:
1302            dead.append(line)
1303        else:
1304            my_alives = my_alives.union(new_lives)
1305
1306    for i in dead:
1307        del bb[i]
1308
1309    return my_alives, bb
1310
1311def eliminate_dead_code(tree, must_live):
1312
1313    if len(tree) == 0:
1314        return [], []
1315
1316    last = len(tree) - 1
1317    new_alives = set([])
1318    bb = []
1319
1320    for loc in tree:
1321        if isinstance(loc, bitexpr.WhileLoop):
1322            if not loc.carry_expr is None:
1323                must_live.add(loc.carry_expr.var.varname)
1324
1325    if isinstance(tree[-1], bitexpr.BitAssign):
1326        first = 0
1327        for i in reversed(range(len(tree))):
1328            if not isinstance(tree[i], bitexpr.BitAssign):
1329                first = i+1
1330                break
1331        new_alives, bb = remove_dead(tree[first:], must_live)
1332        last = first
1333
1334    elif isinstance(tree[-1], bitexpr.If):
1335        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1336
1337        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1338        bb = [tree[-1]]
1339        new_alives.add(tree[-1].control_expr.var.varname)
1340
1341    elif isinstance(tree[-1], bitexpr.WhileLoop):
1342        must_live.add(tree[-1].control_expr.var.varname)
1343        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1344        bb = [tree[-1]]
1345
1346    all_alives = new_alives.union(must_live)
1347    all_lives, new_tree = eliminate_dead_code(tree[:last], all_alives)
1348    tree = new_tree+bb
1349    return all_alives, tree
1350
1351#################################################################################################################
1352## This pass processes the code in the while loop and adds extra code required for handling carry variables
1353#################################################################################################################
1354carry_suffix = "_i"
1355
1356def fix_the_loop(loop):
1357    carries = []
1358    for loc in loop.stmts:
1359        if isinstance(loc.RHS, bitexpr.Add):
1360            carries.append(loc.RHS.carry)
1361    for item in carries:
1362        newvar = bitexpr.Var(item+carry_suffix)
1363        loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), "int")))
1364    for item in carries:
1365        loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral("int") ))
1366
1367    return loop, carries
1368
1369def process_while_loops(code):
1370    all = []
1371    update = {}
1372    for index, loc in enumerate(code):
1373        if isinstance(loc, bitexpr.WhileLoop):
1374            update[index] = fix_the_loop(loc)
1375            for i in update[index][1]:
1376                all.append(i)
1377                all.append(i+carry_suffix)
1378
1379    keys = [k for k in update]
1380    keys.sort(reverse=True)
1381    for key in keys:
1382        code[key] = update[key][0]
1383        carry_variable = bitexpr.Var("CarryTemp"+str(key))
1384        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral("int")))
1385        for item in update[key][1][2:]:
1386            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), "int")))
1387        if len(update[key][1]) == 1:
1388            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1389        elif len(update[key][1]) > 1:
1390            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), "int")))
1391        for item in update[key][1]:
1392            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral("int")))
1393        for item in update[key][1]:
1394            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
1395
1396        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1397    return code, all
1398#################################################################################################################
1399## This pass factors out the code that is common between the true branch and false branch of an if
1400## statement.
1401#################################################################################################################
1402
1403def are_the_same(one, two):
1404    if one.__class__ != two.__class__:
1405        return False
1406
1407    if isinstance(one, bitexpr.BitAssign):
1408        if (one.LHS.varname != two.LHS.varname):
1409            return False
1410        if one.RHS.__class__ != two.RHS.__class__:
1411            return False
1412        if isinstance(one.RHS, bitexpr.FalseLiteral):
1413            return True
1414        if isinstance(one.RHS, bitexpr.TrueLiteral):
1415            return True
1416        if isinstance(one.RHS, bitexpr.Var):
1417            return (one.RHS.varname == two.RHS.varname)
1418
1419        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1420
1421    if isinstance(one, bitexpr.If):
1422        if (one.control_expr.__class__ != two.control_expr.__class__):
1423            return False
1424        if (one.control_expr.var.varname != two.control_expr.var.varname):
1425            return False
1426        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1427            return False
1428        for ind in range(len(one.true_branch)):
1429            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1430                return False
1431        for ind in range(len(one.false_branch)):
1432            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1433                return False
1434        return True
1435
1436    if isinstance(one, bitexpr.WhileLoop):
1437        if (one.control_expr.__class__ != two.control_expr.__class__):
1438            return False
1439        if (one.control_expr.var.varname != two.control_expr.var.varname):
1440            return False
1441        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1442            return False
1443        if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1444            return False
1445        if (len(one.stmts) != len(two.stmts)):
1446            return False
1447        for ind in range(len(one.stmts)):
1448            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1449                return False
1450        return True
1451
1452def get_factorable(cond):
1453    earliest = 0
1454    #print cond.control_expr, cond.control_expr.var.varname
1455    l = min(len(cond.true_branch), len(cond.false_branch))
1456    for index in reversed(range(-l,0)):
1457        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1458            earliest = index
1459        else:
1460            break
1461
1462    common_length = -earliest
1463    return common_length
1464
1465def get_common_code(cond, common):
1466    true_len = len(cond.true_branch)-common
1467    false_len = len(cond.false_branch)-common
1468
1469    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1470    common = cond.true_branch[true_len:]
1471
1472    return [new_cond], common
1473
1474def do_factorization(code, pos):
1475    indices = [key for key in pos]
1476    indices.sort()
1477
1478    for index in reversed(indices):
1479        new_cond, common = get_common_code(code[index], pos[index])
1480        code = code[:index]+new_cond+common+code[index+1:]
1481    return code
1482
1483def factor_out(code):
1484    for index, loc in enumerate(code):
1485        if isinstance(loc, bitexpr.If):
1486            code[index].true_branch = factor_out(code[index].true_branch)
1487            code[index].false_branch = factor_out(code[index].false_branch)
1488        if isinstance(loc, bitexpr.WhileLoop):
1489            code[index].stmts = factor_out(code[index].stmts)
1490
1491    pos = {}
1492    for index, loc in enumerate(code):
1493        if isinstance(loc, bitexpr.If):
1494            common_length = get_factorable(code[index])
1495            pos[index] = common_length
1496
1497    return do_factorization(code, pos)
1498
1499#################################################################################################################
1500## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
1501## ***Changes needed here are the same as Pass 7 and Pass 10***
1502#################################################################################################################
1503
1504def generate_condition(expr):
1505    if isinstance(expr, bitexpr.isAllZero):
1506        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1507    elif isinstance(expr, bitexpr.isAllOne):
1508        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1509    elif isinstance(expr, bitexpr.isNoneZero):
1510        return "bitblock_has_bit(%s)"%(expr.var.varname)
1511    else:
1512        #print expr
1513        assert (1==0)
1514
1515def generate_statement(expr, indent, cond_stmt):
1516    if isinstance(expr, bitexpr.isAllZero):
1517        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1518    elif isinstance(expr, bitexpr.isAllOne):
1519        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1520    elif isinstance(expr, bitexpr.isNoneZero):
1521        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1522    else:
1523        #print expr
1524        assert (1==0)
1525
1526def print_prog(s, indent = 0):
1527    indent_unit = 4
1528    code = ""
1529    if len(s) == 0:
1530        return ""
1531    if isinstance(s[0], bitexpr.If):
1532        code = generate_statement(s[0].control_expr, indent, "if")
1533        code += print_prog(s[0].true_branch, indent+indent_unit)
1534        code += " "*indent+"}\n"
1535        code += " "*indent+"else {\n"
1536        code += print_prog(s[0].false_branch, indent+indent_unit)
1537        code += " "*indent+"}\n"
1538
1539    if isinstance(s[0], bitexpr.WhileLoop):
1540        #code = generate_statement(s[0].control_expr, indent, "while")
1541        code = "\n%s(%s|%s) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname+">0")
1542        code += print_prog(s[0].stmts, indent+indent_unit)
1543        code += " "*indent+"}\n"
1544
1545    if isinstance(s[0], bitexpr.BitAssign):
1546        code += " "*indent + s[0].LHS.varname
1547        code += " = "
1548
1549        if s[0].RHS.data_type == "vector":
1550            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1551                code += s[0].RHS.operand1.varname
1552                code += ";\n"
1553            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1554                code += s[0].RHS.operand1.varname
1555                code += ";\n"
1556            elif isinstance(s[0].RHS, bitexpr.Var):
1557                code += s[0].RHS.operand1.varname
1558                code += ";\n"
1559            elif isinstance(s[0].RHS, bitexpr.Add):
1560                code += s[0].RHS.op_C + "("
1561                code += s[0].RHS.operand1.varname
1562                code += ','
1563                code += s[0].RHS.operand2.varname
1564                code += ','
1565                code += s[0].RHS.carry
1566                code += ");\n"
1567            elif isinstance(s[0].RHS, bitexpr.Sub):
1568                code += s[0].RHS.op_C + "("
1569                code += s[0].RHS.operand1.varname
1570                code += ','
1571                code += s[0].RHS.operand2.varname
1572                code += ','
1573                code += s[0].RHS.brw
1574                code += ");\n"
1575
1576            else:
1577                code += s[0].RHS.op_C + "("
1578                code += s[0].RHS.operand1.varname
1579                code += ','
1580                code += s[0].RHS.operand2.varname
1581                code += ");\n"
1582        elif s[0].RHS.data_type == "int":
1583            if isinstance(s[0].RHS, bitexpr.Or):
1584                code += s[0].RHS.operand1.varname
1585                code += '|'
1586                code += s[0].RHS.operand2.varname
1587                code += ";\n"
1588            elif isinstance(s[0].RHS, bitexpr.Const):
1589                code += "sisd_from_int(%i);"%s[0].RHS.n
1590                code += "\n"
1591            elif isinstance(s[0].RHS, bitexpr.FalseLiteral):
1592                code += "0;\n"
1593
1594    s.pop(0)
1595    return code+print_prog(s, indent)
1596
1597#################################################################################################################
1598## Auxiliary Functions
1599#################################################################################################################
1600
1601def get_bb_of(code, index):
1602    if index < 0 or index >= len(code):
1603        return None
1604
1605    if isinstance(code[index], bitexpr.If):
1606        return code[index]
1607    if isinstance(code[index], bitexpr.WhileLoop):
1608        return code[index]
1609    if isinstance(code[index], bitexpr.BitAssign):
1610        last = index
1611        for loc in code[index+1:]:
1612            if isinstance(loc, bitexpr.BitAssign):
1613                last += 1
1614            else:
1615                break
1616 
1617        first = index
1618        for loc in reversed(code[:index]):
1619            if isinstance(loc, bitexpr.BitAssign):
1620                first -= 1
1621            else:
1622                break
1623
1624        return code[first:last+1]
1625
1626def get_previous_bb(code, index):
1627    if index < 0:
1628        return None
1629    if index >= len(code):
1630        return get_bb_of(code, len(code)-1)
1631    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1632        return get_bb_of(code, index-1)
1633    if isinstance(code[index], bitexpr.BitAssign):
1634        for ind, loc in reversed(enumerate(code[:index])):
1635            if not isinstance(loc, bitexpr.BitAssign):
1636                return get_bb_of(code, ind)
1637        return get_bb_of(code, -1)
1638    assert(1==0)
1639
1640def recurse_forward(code):
1641    if isinstance(code[0], bitexpr.If):
1642        return code[0], code[1:]
1643    if isinstance(code[0], bitexpr.WhileLoop):
1644        return code[0], code[1:]
1645    if isinstance(code[0], bitexpr.BitAssign):
1646        last = 0
1647        for loc in code[1:]:
1648            if isinstance(loc, bitexpr.BitAssign):
1649                last += 1
1650            else:
1651                break
1652        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.