source: proto/Compiler/py2bitexpr.py @ 817

Last change on this file since 817 was 576, checked in by eamiri, 9 years ago

some comments added

File size: 62.6 KB
Line 
1# -*- coding: utf-8 -*-
2import ast, bitexpr, copy, basic_block
3
4type_dict = {}
5structs = {}
6arrays = {}
7
8AllOne = 'AllOne'
9AllZero = 'AllZero'
10
11STRUCT_TEMPLATE="strct_%s__%s_"
12ARRAY_TEMPLATE="array_%s__%s_"
13
14class PyBitError(Exception):
15        pass
16
17#############################################################################################
18## Extracting list of live variables
19#############################################################################################
20
21# Extracs the list of live variables from the very last line of function "main"
22# This must be a "return" statement
23def get_lives(s):
24    for item in s.body:
25        if isinstance(item, ast.FunctionDef):
26            if (item.name =="main"):
27                assert (isinstance(item.body[-1], ast.Return))
28                ret = item.body[-1].value
29                if isinstance(ret, ast.Tuple):
30                    del item.body[-1]
31                    return [translate_var(x) for x in ret.elts],s
32                else:
33                    assert(isinstance(ret, ast.Name))
34                    del item.body[-1]
35                    return [translate_var(ret)],s
36
37#############################################################################################
38## Function Inlining
39#############################################################################################
40
41TEMP_VAR_TEMPLATE = "InlineTemp%i"
42
43def get_var_name(var):
44    var_name = translate_var(var)
45    if isinstance(var, ast.Name):
46        type_name = "simple"
47        extra = None
48    if isinstance(var, ast.Attribute):
49        type_name = "structure"
50        extra = var.attr
51    if isinstance(var, ast.Index):
52        type_name = "array"
53        extra = translate_index(v.slice.value)
54
55    return type_name, var_name, extra
56
57def register_var(final_name, original_name, type_name, extra):
58    pass
59
60# collects all functions in this module and in imported modules.
61# It starts by collecting a list of imported modules
62def collect_functions(module, file_path):
63    imps = []
64    to_remove = []
65    for index, item in enumerate(module.body):
66        if isinstance(item, ast.Import):
67            imps.append(item)
68            to_remove.append(index)
69
70    for n in reversed(to_remove):
71        del module.body[n]
72
73    for imp in imps:
74        for mod in imp.names:
75            mod_file = open(file_path + mod.name + '.py')
76            mod_code = ""
77            for line in mod_file:
78                mod_code += line
79            mod_file.close()
80            mod_parsed = ast.parse(mod_code)
81            modified = []
82            for item in mod_parsed.body:
83                if isinstance(item, ast.FunctionDef):
84                    item.name = translate_var(ast.Attribute(value=ast.Name(id=mod.name), attr=item.name))
85            module.body += mod_parsed.body
86    return module
87
88
89#Given a function it removes any line of code in that function that is not supported by the compiler
90def remove_unsupported_code(callee, fn_names, predefined, main=False):
91    to_remove = []
92    for index, loc in enumerate(callee.body):
93        if isinstance(loc, ast.While) or isinstance(loc, ast.If):
94            if not main:
95                to_remove.append(index)
96            continue
97
98        if isinstance(loc, ast.Expr):
99            if isinstance(loc.value, ast.Call):
100                if main and not (translate_var(loc.value.func) in fn_names+predefined):
101                    to_remove.append(index)
102                if not main and not translate_var(loc.value.func) in predefined:
103                    to_remove.append(index)
104
105            else:
106                to_remove.append(index)
107
108            continue
109
110        if not (isinstance(loc, ast.Assign) or isinstance(loc, ast.AugAssign) or isinstance(loc, ast.Return)):
111            to_remove.append(index)
112
113        #we do not support AugAssign with function calls
114        if isinstance(loc, ast.AugAssign):
115            if isinstance(loc.value, ast.Call):
116                assert (1==0)
117                to_remove.append(index)
118
119        #we do not support function calls in functions other than main
120        #these are OK if they are object construction, otherwise error.
121        if isinstance(loc, ast.Assign):
122            if isinstance(loc.value, ast.Call):
123                if main and not (translate_var(loc.value.func) in fn_names+predefined):
124                    to_remove.append(index)
125                if not main and not translate_var(loc.value.func) in predefined:
126                    to_remove.append(index)
127    to_remove.sort(reverse=True)
128    for line_no in to_remove:
129        del callee.body[line_no]
130
131    return callee
132
133
134# Iterates over all functions and cleans them up from code that is not supported by the compiler
135def clean_functions(module, fn_names, predefined):
136    new = []
137    for item in module.body:
138        if isinstance(item, ast.FunctionDef):
139            new.append(remove_unsupported_code(item, fn_names, predefined, item.name == "main")) 
140        else:
141            new.append(item)
142    module.body = new
143    return module
144
145#creates a dictionary of functions and their index in the list of statements.
146def collect_names(body):
147    names = {}
148    for index, item in enumerate(body):
149        if isinstance(item, ast.FunctionDef):
150            names[item.name] = index
151
152    return names
153
154
155def modify_name(var, args, cargs, unique_prefix):
156    carg_names = [x.id for x in cargs]
157    arg_names = [translate_var(x) for x in args]
158
159    varname = translate_var(var)
160    if isinstance(var, ast.Name):
161        if varname in carg_names:
162            return ast.Name(id=arg_names[carg_names.index(varname)])
163        return ast.Name(id=unique_prefix+varname)
164    if isinstance(var, ast.Attribute) or isinstance(var, ast.Subscript):
165        if var.value.id in carg_names:
166            var.value.id = arg_names[carg_names.index(var.value.id)]
167            return ast.Name(id=translate_var(var))
168        return ast.Name(id=unique_prefix+translate_var(var))
169
170def replace_in_exp(exp, args, cargs, unique_prefix):
171    if isinstance(exp, ast.BinOp):
172        exp.left = replace_in_exp(exp.left, args, cargs, unique_prefix)
173        exp.right = replace_in_exp(exp.right, args, cargs, unique_prefix)
174    elif isinstance(exp, ast.UnaryOp):
175        exp.operand = replace_in_exp(exp.operand, args, cargs, unique_prefix)
176    elif isinstance(exp, ast.Name):
177        exp = modify_name(exp, args, cargs, unique_prefix)
178    elif isinstance(exp, ast.Attribute):
179        exp = modify_name(exp, args, cargs, unique_prefix)
180    elif isinstance(exp, ast.Subscript):
181        exp = modify_name(exp, args, cargs, unique_prefix)
182    elif isinstance(exp, ast.Call):
183        for index, item in enumerate(exp.args):
184            exp.args[index] = replace_in_exp(item, args, cargs, unique_prefix)
185    elif isinstance(exp, ast.Tuple):
186        for index, elt in enumerate(exp.elts):
187            exp.elts[index] = replace_in_exp(elt, args, cargs, unique_prefix)
188    elif isinstance(exp, ast.Num):
189        pass
190    else:
191        assert(1==0)
192    return exp
193
194#generates a list of all function calls within function "main".
195def get_all_calls(main, all_func):
196    call_list = []
197    for index, loc in enumerate(main.body):
198        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
199                func_name = translate_var(loc.value.func)
200                if func_name in all_func:
201                    call_list.append(index)
202        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
203                func_name = translate_var(loc.value.func)
204                if func_name in all_func:
205                    call_list.append(index)
206    return call_list
207
208#modifies names of variables in "callee" by adding a prefix. This is a necessary part of inlining.
209def update_var_names(callee, args):
210    unique_prefix = '_'+callee.name.replace('.', '_')+'_'
211
212    assert (len(args)==len(callee.args.args))
213
214    for index, loc in enumerate(callee.body):
215        if isinstance(loc, ast.Assign):
216            for index, var in enumerate(loc.targets):
217                loc.targets[index] = modify_name(var, args, callee.args.args, unique_prefix)
218            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
219        elif isinstance(loc, ast.AugAssign):
220            loc.target = modify_name(loc.target, args, callee.args.args, unique_prefix)
221            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
222        elif isinstance(loc, ast.Return):
223            pass
224            #loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
225        else:
226            assert (1==0)
227
228    return callee, unique_prefix
229
230
231def get_name(loc, index):
232    if isinstance(loc.value, ast.Tuple):
233        return translate_var(loc.value.elts[index])
234    elif index == 0:
235        return translate_var(loc.value)
236    else:
237        assert(1==0)
238
239
240def modify_assignments_in_caller(main, line_no, return_list, prefix, callee):
241    if isinstance(main.body[line_no].targets[0], ast.Tuple):
242        all_vars = main.body[line_no].targets[0].elts
243    else:
244        all_vars = main.body[line_no].targets
245    new_code = []
246
247    if isinstance(callee.body[-1].value, ast.Tuple):
248        return_vars = [translate_var(x) for x in callee.body[-1].value.elts]
249    else:
250        return_vars = [translate_var(callee.body[-1].value)]
251
252    for index, item in enumerate(all_vars):
253        actual_name = translate_var(item)
254        if actual_name in [translate_var(x) for x in main.body[line_no].value.args]:
255            continue
256        temp_name = return_vars[index]
257
258        if temp_name in [translate_var(x) for x in callee.args.args]:
259            second_index = [translate_var(x) for x in callee.args.args].index(temp_name)
260            final_name = translate_var(main.body[line_no].value.args[second_index])
261        else:
262            final_name = prefix+temp_name
263
264        if return_list[temp_name][0] == "struct":
265            my_template = STRUCT_TEMPLATE
266        elif return_list[temp_name][0] == "array":
267            my_template = ARRAY_TEMPLATE
268 
269        if return_list[temp_name][0] == "bitblock":
270            new_code.append(ast.Assign( [ast.Name(id=item.id)] , ast.Name(id=final_name) ))
271        else:
272            for elem in return_list[temp_name][1]:
273                new_code.append(ast.Assign([ast.Name(id=my_template%(item.id, elem))], ast.Name(id=prefix+my_template%(temp_name, elem))))
274    main.body = main.body[:line_no]+new_code+main.body[line_no+1:]
275    return main
276
277#returns a list of all variables returned by the given function
278def get_return_list(callee):
279    ret_list = []
280    loc = callee.body[-1]
281    if isinstance(loc, ast.Return):
282        if isinstance(loc.value, ast.Tuple):
283            all_rets = loc.value.elts
284        else:
285            all_rets = [loc.value]
286        for item in all_rets:
287            if isinstance(item, ast.Attribute) or isinstance(item, ast.Subscript):
288                ret_list.append(item.value.id)
289            if isinstance(item, ast.Name):
290                ret_list.append(item.id)
291    return ret_list
292
293# returns a list of variables that are updated when the given line of code(loc) is run.
294def get_updated(loc):
295    if isinstance(loc, ast.AugAssign):
296        updated = [loc.target]
297    elif isinstance(loc, ast.Assign):
298        if isinstance(loc.targets[0], ast.Tuple):
299            updated = loc.targets[0].elts
300        else:
301            updated = loc.targets
302    else:
303        updated = []
304    return updated
305
306#Extracts information about variables returned by the given function. The second argument ret_list
307# is a list of variables that function "callee" reutrns.
308def get_update_details(callee, ret_list):
309    field_info = {}
310    for item in ret_list:
311        field_info[item] = ['simple', set([])]
312    for loc in callee.body:
313        updated = get_updated(loc)
314        for item in updated:
315            type, name, extra = parse_var(translate_var(item, True))
316            if name in field_info:
317                field_info[name][0] = type
318                if type=='struct':
319                    field_info[name][1].add(extra)
320                if type == 'array':
321                    field_info[name][1].add(extra)
322    return field_info
323
324#Extracts some information about variables that "callee" returns. This is necessary for inlining.
325def process_return(callee):
326    if not isinstance(callee.body[-1], ast.Return):
327        return {}
328    ret_list = get_return_list(callee)
329    field_info = get_update_details(callee, ret_list)
330    return field_info
331
332#Generates the code equivalent to initialization of a list.
333def generate_expanded_code(loc):
334    new_code = []
335    array_name = translate_var(loc.targets[0])
336    for ind, i in enumerate(range(len(loc.value.elts))):
337        new_loc = ast.Assign([ast.Subscript(value=ast.Name(id=array_name), slice=ast.Index(ast.Num(n=ind)) )], loc.value.elts[i])
338        new_code.append(new_loc)
339
340    return new_code
341
342#List initializations are expanded to a number of single variable initializations
343def expand_list_init(module):
344    for item in module.body:
345        replace_code = {}
346        if isinstance(item, ast.FunctionDef):
347            for index,loc in enumerate(item.body):
348                if isinstance(loc, ast.Assign):
349                    if isinstance(loc.value, ast.List):
350                        equivalent_code = generate_expanded_code(loc)
351                        replace_code[index] = equivalent_code
352            keys = [key for key in replace_code]
353            keys.sort(reverse = True)
354            for line_no in keys:
355                item.body = item.body[:line_no]+replace_code[line_no]+item.body[line_no+1:]
356
357    return module
358
359#Returns a list of predefined functions. These are usually primitives on bit streams.
360def get_predefined_funcs():
361    advance = ast.Attribute(value=ast.Name(id="bitutil"), attr="Advance")
362    scan_thru = ast.Attribute(value=ast.Name(id="bitutil"), attr="ScanThru")
363    optimize = ast.Attribute(value=ast.Name(id="bitutil"), attr="Optimize")
364    return [translate_var(advance), translate_var(scan_thru), translate_var(optimize)]
365
366
367#This function is called to do inlining.
368def do_inlining(module, file_path):
369    module = collect_functions(module, file_path)
370    func_dict = collect_names(module.body)
371    module = expand_list_init(module)
372
373    valid_fn = [name for name in func_dict] 
374    predef_fn = get_predefined_funcs()
375    module = clean_functions(module, valid_fn, predef_fn)
376
377    main = module.body[func_dict["main"]]
378    call_list = get_all_calls(main, [func for func in func_dict])
379    call_list.sort(reverse=True)
380    for line_no in call_list:
381        callee = copy.deepcopy(module.body[func_dict[translate_var(main.body[line_no].value.func)]])
382        args = main.body[line_no].value.args
383        return_list = process_return(callee)
384        callee, prefix = update_var_names(callee, args)
385
386        main = modify_assignments_in_caller(main, line_no, return_list, prefix, callee)
387
388        if isinstance(callee.body[-1], ast.Return):
389            del callee.body[-1]
390        main.body = main.body[:line_no]+callee.body+main.body[line_no:]
391       
392        #if callee.name == 'strct_byteclass__classify_bytes_':
393        #  for item in module.body[func_dict["main"]].body:
394        #      if isinstance(item, ast.Assign):
395        #       if isinstance(item.targets[0], ast.Name):
396        #           print item.targets[0].id
397         
398    #for item in module.body[func_dict["main"]].body:
399    #   if isinstance(item, ast.Assign):
400    #      if isinstance(item.targets[0], ast.Name):
401    #         print item.targets[0].id
402   
403    return module.body[func_dict["main"]].body
404
405#############################################################################################
406## Translation to bitExpr classes --- Pass 2
407#############################################################################################
408
409def translate_index(ast_value):
410        if isinstance(ast_value, ast.Str): return ast_value.s
411        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
412        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
413
414def is_Advance_Call(fncall):
415        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
416        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
417                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
418        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
419
420def is_ScanThru_Call(fncall):
421        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
422        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
423                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
424        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
425
426#Convertnig Python AST to our own AST
427def translate(ast_expr):
428        if isinstance(ast_expr, ast.Name):
429                if ast_expr.id=="allzero":
430                        return bitexpr.FalseLiteral()
431                if ast_expr.id=="allone":
432                        return bitexpr.TrueLiteral()
433                return bitexpr.Var(translate_var(ast_expr))
434        elif isinstance(ast_expr, ast.Num):
435                if ast_expr.n == 0:
436                    return bitexpr.FalseLiteral()
437                else:
438                    return bitexpr.Const(ast_expr.n)
439        elif isinstance(ast_expr, ast.Attribute):
440                if isinstance(ast_expr.value, ast.Name):
441                        return bitexpr.Var(translate_var(ast_expr))
442                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
443                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
444        elif isinstance(ast_expr, ast.Subscript):
445                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice, ast.Index):
446                        return bitexpr.Var(translate_var(ast_expr))
447                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
448                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
449        elif isinstance(ast_expr, ast.UnaryOp):
450                e1 = translate(ast_expr.operand)
451                if isinstance(ast_expr.op, ast.Invert):
452                        return bitexpr.make_not(e1)
453                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
454        elif isinstance(ast_expr, ast.BinOp):
455                e1 = translate(ast_expr.left)
456                e2 = translate(ast_expr.right)
457                if isinstance(ast_expr.op, ast.BitOr):
458                        return bitexpr.make_or(e1, e2)
459                elif isinstance(ast_expr.op, ast.BitAnd):
460                        return bitexpr.make_and(e1, e2)
461                elif isinstance(ast_expr.op, ast.BitXor):
462                        return bitexpr.make_xor(e1, e2)
463                elif isinstance(ast_expr.op, ast.Add):
464                        return bitexpr.make_add(e1, e2)
465                elif isinstance(ast_expr.op, ast.Sub):
466                        return bitexpr.make_sub(e1, e2)
467                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
468        elif isinstance(ast_expr, ast.Call):
469                if is_Advance_Call(ast_expr):
470                        e = translate(ast_expr.args[0])
471                        temp = bitexpr.make_add(e,e)
472                        return temp
473                elif is_ScanThru_Call(ast_expr):
474                        e0 = translate(ast_expr.args[0])
475                        e1 = translate(ast_expr.args[1])
476                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
477                else:
478                    raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
479        elif isinstance(ast_expr, ast.Compare):
480                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
481                        e0 = translate(ast_expr.left)
482                        return bitexpr.isNoneZero(e0)
483                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
484               
485        else:
486            raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
487
488def translate_var(v, aggregate_type = False):
489        if isinstance(v, ast.Name):
490            agg_name = v.id
491            var_name = v.id
492        elif isinstance(v, ast.Attribute):
493           if isinstance(v.value, ast.Name):
494               var_name = STRUCT_TEMPLATE%(v.value.id, v.attr)
495               agg_name = "%s.%s" % (v.value.id, v.attr)
496           else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
497        elif isinstance(v, ast.Subscript):
498           if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
499               var_name = ARRAY_TEMPLATE%(v.value.id, translate_index(v.slice.value))
500               agg_name = "%s[%s]" % (v.value.id, translate_index(v.slice.value))
501               
502        else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
503
504        if aggregate_type:
505            return agg_name
506        else:
507            return var_name
508
509def translate_stmts(ast_stmts):
510        translated = []
511        for s in ast_stmts:
512                if isinstance(s, ast.Expr):
513                        if (translate_var(s.value.func)=='strct_bitutil__Optimize_'):
514                            target = translate_var(s.value.args[0])
515                            replace = translate_var(s.value.args[1])
516                            translated.append(bitexpr.Reduce(target, replace))
517                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
518                elif isinstance(s, ast.Assign):
519                        e = translate(s.value)
520                        for astv in s.targets: 
521                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
522                elif isinstance(s, ast.AugAssign):
523                        v = bitexpr.Var(translate_var(s.target))
524                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
525                elif isinstance(s, ast.While):
526                        e = translate(s.test)
527                        body = translate_stmts(s.body)
528                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body))
529                elif isinstance(s, ast.If):
530                        e = translate(s.test)
531                        body = translate_stmts(s.body)
532                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body, fake = True ))
533                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
534        return translated
535#############################################################################################
536## Conversion to SSA form
537#############################################################################################
538
539def extract_vars(rhs):
540    if isinstance(rhs, bitexpr.Var):
541        return [rhs.varname]
542    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral) or isinstance(rhs, bitexpr.Const):
543        return []
544    if isinstance(rhs, bitexpr.Not):
545        return extract_vars(rhs.operand1)
546    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
547    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
548
549def update_lineno(inner_table, shift):
550    for key in inner_table:
551        for index in range(len(inner_table[key][0])):
552            inner_table[key][0][index] += shift
553        for index in range(len(inner_table[key][1])):
554            inner_table[key][1][index] += shift
555    return inner_table
556
557def merge_tables(table, inner_table):
558    for key in inner_table:
559        if key in table:
560            table[key][0] += inner_table[key][0]
561            table[key][1] += inner_table[key][1]
562        else:
563            table[key] = inner_table[key]
564    return table
565
566#generating symbol table. goInside is a flag that determines whether or not
567#the symbol table cover body of a while loop
568def gen_sym_table(code, goInside = False):
569    """Generate a simple symbol table for a three address code
570        each entry is of this form: var_name:[[defs][uses]]
571    """
572    table = {}
573    index = 0
574    for stmt in code:
575        if isinstance(stmt, bitexpr.Reduce):
576            if stmt.target in table:
577                table[stmt.target][1].append(index)
578            else:
579                table[stmt.target] = [[], [index]]
580            index += 1
581        elif isinstance(stmt, bitexpr.BitAssign):
582            current = stmt.LHS.varname
583            if current in table:
584                table[current][0].append(index)
585            else:
586                table[current] = [[index],[]]
587
588            varlist = extract_vars(stmt.RHS)
589            for var in varlist:
590                if var in table:
591                    table[var][1].append(index)
592                else:
593                    table[var] = [[], [index]]
594            index += 1
595        elif isinstance(stmt, bitexpr.WhileLoop):
596            cond_var = stmt.control_expr.var.varname
597            if cond_var in table:
598                table[cond_var][1].append(index)
599            else:
600                #while loop conditioned on an undefined variable
601                assert(1==0)
602            if goInside:
603                inner_table = gen_sym_table(stmt.stmts, True)
604                inner_table = update_lineno(inner_table, index+1)
605                table = merge_tables(table, inner_table)
606                index += (1+len(stmt.stmts))
607            else:
608                index += 1
609        else:
610            assert(1==0)
611    return table
612
613##################################################################################
614def get_line(code, line):
615    lineno = 0
616    for stmt in code:
617        if isinstance(stmt, bitexpr.BitAssign):
618            if lineno == line:
619                return stmt
620            lineno += 1
621        elif isinstance(stmt, bitexpr.WhileLoop):
622            if lineno == line:
623                return stmt
624            elif lineno+len(stmt.stmts) < line:
625                lineno += 1
626                line -= len(stmt.stmts)
627                continue
628            else:
629                return get_line(stmt.stmts, line-(lineno+1))
630
631        elif isinstance(stmt, bitexpr.Reduce):
632            if lineno == line:
633                return stmt
634            lineno += 1
635        else:
636            assert(1==0)
637
638#Variable name changes in order to convert the code to SSA
639#This function applies the change to the line that defines a variable/
640def update_def(code, var, line, suffix_num):
641    loc = code[line]
642    if isinstance(loc, bitexpr.BitAssign):
643        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
644        return loc
645    else:
646        #either Reduce or While Loop in both cases it's a use
647        pass
648
649# variable name  (for a variable that appears on the right side of an assignment)
650# is replaced by its new name
651def update_rhs(rhs, varname, newname):
652    if isinstance(rhs, bitexpr.Var):
653        if rhs.varname == varname:
654            rhs.varname = newname
655        return rhs
656
657    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
658        return rhs
659
660    if isinstance(rhs, bitexpr.Not):
661        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
662        return rhs
663       
664    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
665    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
666    return rhs
667
668#Changing the name of a variable
669def update_use(code, var, line, suffix_num):
670    loc = code[line]
671    if isinstance(loc, bitexpr.BitAssign):
672        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
673    elif isinstance(loc, bitexpr.Reduce):
674        pass
675    elif isinstance(loc, bitexpr.WhileLoop):
676        loc.control_expr.var.varname += "_%i"%suffix_num
677    return loc
678
679def parse_var(var):
680    index=var.find('.')
681    if index >= 0:
682        return ('struct', var[0:index], var[index+1:])
683               
684    index = var.find('[')
685    if index >= 0:
686        right_index = var.find(']')
687        return ('array', var[0:index], var[index+1:right_index])       
688   
689    if var.startswith("carry") or var.startswith("Carry"):
690        return('int', var, None)
691   
692    return ('bitblock', var, None)
693
694def simplify_name(var):
695    (vartype, name, extra) = parse_var(var)
696    if vartype == 'bitblock':
697        return name
698    if vartype == 'array':
699        return "a_%s_%s"%(name, extra)
700    if vartype == 'struct':
701        return "s_%s_%s"%(name, extra)
702    if vartype == 'int':
703        assert(1==0)
704    assert(1==0)
705
706#pairs([1,2,3,4,5]) ==> [(1,2), (2,3), (3,4), (4,5), (5,1)]
707def pairs(lst):
708    if lst == []:
709            return []
710    return zip(lst,lst[1:]+[lst[0]])
711
712
713def make_SSA(code, st):
714    new_vars = []
715    total_lines = len(code)
716    for var in st:
717        st[var][0].append(total_lines)
718        st[var][1].append(total_lines)
719
720    unique_number = 0
721    for var in st:
722        use_index = 0
723        for current, next in pairs(st[var][0])[0:-2]:
724            code[current] = update_def(code, var, current, unique_number)
725            uline = st[var][1][use_index]
726
727            while uline <= next and uline < total_lines:
728                if uline > current:
729                    code[uline] = update_use(code, var, uline, unique_number)
730                use_index += 1
731                uline = st[var][1][use_index]
732            unique_number += 1
733    return code
734
735#################################################################################################################
736## Breaking the code into basic blocks depending on where optimize pragma appears
737#################################################################################################################
738def get_opt_list(s):
739    opt_list = []
740    remove_index = []
741    for line, stmt in enumerate(s):
742        if isinstance(stmt, bitexpr.Reduce):
743            opt_pair = (stmt.target, stmt.replace)
744            opt_list.append(opt_pair)
745            remove_index.append(line)
746    for ind in reversed(remove_index):
747        del s[ind]
748
749    return opt_list
750
751def break_to_segs(s, breaks):
752    segs = []
753    breaks = [-1]+breaks+[len(s)]
754    for start, stop in pairs(breaks)[:-1]:
755        next = s[start+1:stop]
756        if len(next)>0:
757            segs.append(next)
758    return segs
759
760def extract_pragmas(s):
761    targets = []
762    exprs = []
763    lines = []
764
765    #extracting pragmas and their targets
766    for line, loc in enumerate(s):
767        if isinstance(loc, bitexpr.Reduce):
768            targets.append(loc.target)
769            exprs.append(loc)
770            lines.append(line)
771
772    #removing pragmas from the source code
773    for i in reversed(lines):
774        del s[i]
775
776    return targets, exprs
777
778def get_defs(targets, s):
779    temp = [(x,-1) for x in targets]
780    targ_dic = dict(temp)
781    for line, loc in enumerate(s):
782        if loc.LHS.varname in targets:
783            targ_dic[loc.LHS.varname] = line
784
785    items = targ_dic.items()
786    items = sorted(items, key=(lambda x: x[1]))
787
788    lineno = [key for value, key in items]
789    sorted_targets = [value for value, key in items]
790
791    return sorted_targets, lineno
792
793def defines(varname, stmts):
794    is_defined = False
795    for loc in stmts:
796        if isinstance(loc, bitexpr.WhileLoop):
797            is_defined |= defines(varname, loc.stmts)
798        elif isinstance(loc, bitexpr.BitAssign): 
799            if loc.LHS.varname == varname:
800                is_defined = True
801
802    return is_defined
803
804#
805def adjust_earliest(s, earliest, varname):
806    if earliest == -1:
807        return -1
808
809    line = earliest+1
810    for stmt in s[earliest+1:]:
811        if isinstance(stmt, bitexpr.WhileLoop):
812            if defines(varname, stmt.stmts):
813               earliest = line+1 
814        line+=1
815    return earliest
816
817#breaking the code into three segments
818#currently the third segment should be empty because stop is
819#chosen to be the last line of the code
820def chop_code(s, start, stop):
821    line = 0
822
823    if start == -1:
824        cut1 = -1
825
826    if stop == len(s):
827        cut2 = len(s)
828
829    for index, stmt in enumerate(s):
830        if isinstance(stmt, bitexpr.BitAssign):
831            if line == start:
832                cut1 = index
833            if line == stop:
834                cut2 = index
835            line += 1
836
837        if isinstance(stmt, bitexpr.WhileLoop):
838            if line == start:
839                cut1 = index
840            if line  == stop:
841                cut2 = index
842            line+=1
843
844    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
845
846#counting lines of code
847def count_lines(code):
848    cnt = len(code)
849    for stmt in code:
850        if isinstance(stmt, bitexpr.WhileLoop):
851            cnt += count_lines(stmt.stmts)
852    return cnt
853
854#generates the condition of the if statement
855def gen_cond_obj(opt_pragma):
856    if opt_pragma[1] == 'AllZero' or opt_pragma[1] == 'allzero':
857        cond_obj = bitexpr.isAllZero(opt_pragma[0])
858    if opt_pragma[1] == 'AllOne' or opt_pragma[1] == 'allone':
859        cond_obj = bitexpr.isAllOne(opt_pragma[0])
860    return cond_obj
861
862
863
864def gen_bb(s, opt_list, def_line):
865    assert (len(opt_list) == len(def_line))
866    if opt_list==[]:
867        return s
868
869    earliest = min(def_line)
870   
871    #latest is set to length of the code
872    #This means up to the end of the code is considered for optimization
873    latest = len(s)
874    earliest_index = def_line.index(earliest)
875    early_varname = opt_list[earliest_index][0]
876   
877    earliest = adjust_earliest(s, earliest, early_varname)
878
879    first, second, third = chop_code(s, earliest, latest)
880
881    cond_obj = gen_cond_obj(opt_list[earliest_index])
882   
883    del opt_list[earliest_index]
884    del def_line[earliest_index]
885    def_line = [x-earliest-1 for x in def_line]
886   
887    second = gen_bb(second, opt_list, def_line)
888
889    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))] + third
890    return result
891
892#Creating basic blocks required for implementation of optimize pragma (if..then..else statements)
893def partition2bb(s):
894    basicblocks = []
895    lineno = []
896    def_line = []
897    opt_list = get_opt_list(s)
898
899    st2 = gen_sym_table(s)
900
901    # This loop iterates over variables that has appeared in an optimize pragma.
902    # The line in which each of these variables is defined is found and appended to "def_line"
903    for item in opt_list:
904        if len(st2[item[0]][0]) > 0:
905            def_line.append(st2[item[0]][0][-1])
906        else:
907            def_line.append(-1)
908
909
910    return gen_bb(s, opt_list, def_line)
911
912#################################################################################################################
913## Generating declarations for variables. This pass is based on the syntax of C programming language ---
914## Normalizing the code
915#################################################################################################################
916
917def normalize(s, predec = {}, ccelim=True):
918    if len(s)==0:
919        return []
920    if isinstance(s[0], bitexpr.If):
921        gc = basic_block.BasicBlock.gensym_counter
922        cc = basic_block.BasicBlock.carry_counter
923        bc = basic_block.BasicBlock.brw_counter
924        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
925        maxgc = basic_block.BasicBlock.gensym_counter
926        maxcc = basic_block.BasicBlock.carry_counter
927        maxbc = basic_block.BasicBlock.brw_counter
928        basic_block.BasicBlock.gensym_counter = gc
929        basic_block.BasicBlock.carry_counter = cc
930        basic_block.BasicBlock.brw_counter = bc
931        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
932        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
933        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
934        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
935       
936        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
937    if isinstance(s[0], bitexpr.WhileLoop):
938        orig = copy.deepcopy(predec)
939        s[0].stmts = normalize(s[0].stmts, predec, False)
940        return [s[0]]+normalize(s[1:], {})
941    if isinstance(s[0], bitexpr.BitAssign):
942        code, next = recurse_forward(s)
943        bb = basic_block.BasicBlock(predec)
944        predec.update(bb.normalize(code, ccelim))
945        return bb.get_code() + normalize(next, copy.deepcopy(predec))
946
947#################################################################################################################
948## Generating declarations for variables. This pass is based on the syntax of C programming language
949#################################################################################################################
950def get_vars(stmt):
951        ints = set([])
952        bitblocks = set(['AllOne', 'AllZero'])
953        arrays = {}
954        structs = {}
955
956        #for stmt in code:
957
958        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
959       
960
961
962        if isinstance(stmt.RHS, bitexpr.Add): 
963                ints.add(stmt.RHS.carry)
964        if isinstance(stmt.RHS, bitexpr.Sub):
965                ints.add(stmt.RHS.brw)
966
967        for var in all_vars:
968                if isinstance(var, bitexpr.Const):
969                    continue
970               
971                (var_type, name, extra) = parse_var(var.varname)
972
973                if (var_type == "bitblock"):
974                        bitblocks.add(name)
975
976                if var_type == "array":
977                        if not name in arrays:
978                                arrays[name]= extra
979                        else:
980                                arrays[name] = max(arrays[name], extra)
981
982                if var_type == "struct":
983                        if not name in structs:
984                                structs[name] = set([extra])
985                        else:
986                                structs[name].add(extra)
987                if var_type == "int":
988                    ints.add(name)
989                       
990        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
991
992def merge_var_dic(first, second):
993    res = {}
994    res['int'] = first['int'].union(second['int'])
995    res['bitblock'] = first['bitblock'].union(second['bitblock'])
996
997    res['array'] = first['array']
998    temp = dict({'array':copy.copy(second['array'])})
999    for i in res['array']:
1000        if i in temp['array']:
1001            res['array'][i] = max(res['array'][i], temp['array'][i])
1002            del temp['array'][i]
1003    res['array'].update(temp['array'])
1004
1005
1006    res['struct'] = first['struct']
1007    temp = dict({'struct':copy.copy(second['struct'])})
1008    for i in res['struct']:
1009        if i in temp['struct']:
1010            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
1011            del temp['struct'][i]
1012    res['struct'].update(temp['struct'])
1013    return res
1014
1015def gen_output(var_dic):
1016        s = ''
1017        for i in var_dic['bitblock']:
1018                if i == AllOne:
1019                        s += "BitBlock %s = simd_const_1(1);\n"%i
1020                elif i == AllZero:
1021                        s+="BitBlock %s = simd_const_1(0);\n"%i
1022                elif i == "EOF_mask":
1023                        s+="BitBlock EOF_mask = simd_const_1(1);\n"
1024                else:
1025                        s+="BitBlock %s;\n"%i
1026
1027        for i in var_dic['array']:
1028                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
1029
1030        for i in var_dic['struct']:
1031                s+="struct __%s__{\n"%i
1032                for j in var_dic['struct'][i]:
1033                        s+= "\tBitBlock %s;\n"%j
1034                s+="};\n"
1035                s+="struct __%s__ %s;\n"%(i,i)
1036
1037        for i in  var_dic['int']:
1038                s+="CarryType %s = Carry0;\n"%i
1039
1040
1041
1042
1043        return s
1044
1045def gen_var_dic(s):
1046    if len(s)==0:
1047        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
1048    if isinstance(s[0], bitexpr.BitAssign):
1049        vd = get_vars(s[0])
1050        more = gen_var_dic(s[1:])
1051        vd = merge_var_dic(vd, more)
1052        return vd
1053
1054    if isinstance(s[0], bitexpr.If):
1055        vd1 = gen_var_dic(s[0].true_branch)
1056        vd2 = gen_var_dic(s[0].false_branch)
1057        vd = merge_var_dic(vd1, vd2)
1058        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1059        vd =  merge_var_dic(vd, vd3)
1060        more = gen_var_dic(s[1:])
1061        vd  = merge_var_dic(vd, more)
1062        return vd
1063
1064    if isinstance(s[0], bitexpr.WhileLoop):
1065        vd = gen_var_dic(s[0].stmts)
1066        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1067        vd =  merge_var_dic(vd, vd1)
1068        more = gen_var_dic(s[1:])
1069        vd  = merge_var_dic(vd, more)
1070        return vd
1071
1072def gen_declarations(s):
1073    vd = gen_var_dic(s)
1074    return gen_output(vd)
1075
1076#################################################################################################################
1077## This class replaces all occurrences of a reduced variable to its value --- Pass 7
1078## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
1079## *** the new AST notation                                                                                   ***
1080#################################################################################################################
1081def replace_in_rhs(rhs, target, replace):
1082    if isinstance(rhs, bitexpr.Const):
1083        return rhs
1084    if isinstance(rhs, bitexpr.Var):
1085        if rhs.varname == target:
1086            rhs = replace
1087            return replace
1088
1089    if isinstance(rhs.operand1, bitexpr.Var):
1090        if rhs.operand1.varname == target:
1091            rhs.operand1 = replace
1092    else:
1093        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
1094            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
1095
1096    if isinstance(rhs.operand2, bitexpr.Var):
1097        if rhs.operand2.varname == target:
1098            rhs.operand2 = replace
1099    else:
1100        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
1101            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
1102    return rhs
1103
1104def apply_single_opt(code, target, replace):
1105    if len(code) == 0:
1106        return []
1107
1108    if replace=='AllZero':
1109        replace = bitexpr.FalseLiteral()
1110    if replace == 'AllOne':
1111        replace = bitexpr.TrueLiteral()
1112
1113    if isinstance(code[0], bitexpr.BitAssign):
1114        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
1115
1116    if isinstance(code[0], bitexpr.If):
1117        if code[0].control_expr.var.varname == target:
1118            code[0].control_expr.var = replace
1119        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
1120        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
1121
1122    if isinstance(code[0], bitexpr.WhileLoop):
1123        if code[0].control_expr.var.varname == target:
1124            code[0].control_expr.var = replace
1125        #code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
1126
1127    return [code[0]]+apply_single_opt(code[1:], target, replace)
1128
1129def apply_all_opt(s):
1130    if len(s) == 0:
1131        return []
1132    elif isinstance(s[0], bitexpr.If):
1133        target = s[0].control_expr.var.varname
1134        replace = s[0].control_expr.val
1135        apply_single_opt(s[0].true_branch, target, replace)
1136        apply_all_opt(s[0].true_branch)
1137        apply_all_opt(s[0].false_branch)
1138
1139    elif isinstance(s[0], bitexpr.WhileLoop):
1140        apply_all_opt(s[0].stmts)
1141
1142    return [s[0]]+apply_all_opt(s[1:])
1143 
1144#################################################################################################################
1145## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation.
1146## method prune is supposed to remove unreachable branches but it is incomplete
1147#################################################################################################################
1148
1149#NOT USED
1150def prune(fixed, tree):
1151    """removes unreachable branches of the tree"""
1152    target = tree[0].target
1153    replace = tree[0].replace
1154    if target in fixed:
1155        if replace == 'allone':
1156            if isinstance(fixed[target], bitexpr.TrueLiteral):
1157                tree[0] = None
1158                tree[1].join(tree[2][1])
1159                del tree[3]
1160                del tree[2]
1161                fixed.update(tree[1].simplify(fixed))
1162            if isinstance(fixed[target], bitexpr.FalseLiteral):
1163                tree[0] = None
1164                tree[1].join(tree[3][1])
1165                del tree[3]
1166                del tree[2]
1167                fixed.update(tree[1].simplify(fixed))
1168        if replace == 'allzero':
1169            if isinstance(fixed[target], bitexpr.FalseLiteral):
1170                tree[0] = None
1171                tree[1].join(tree[2][1])
1172                del tree[3]
1173                del tree[2]
1174                fixed.update(tree[1].simplify(fixed))
1175            if isinstance(fixed[target], bitexpr.TrueLiteral):
1176                tree[0] = None
1177                tree[1].join(tree[3][1])
1178                del tree[3]
1179                del tree[2]
1180                fixed.update(tree[1].simplify(fixed))
1181
1182    empty_list = []
1183    for index, branch in enumerate(tree[2:]):
1184        if branch[0] is None:
1185            if branch[1].code == []:
1186                empty_list.append(2+index)
1187
1188    empty_list.reverse()
1189    for i in empty_list:
1190        del tree[i]
1191       
1192
1193def filter_fixed(fixed, stmts):
1194    for loc in stmts:
1195        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
1196            del fixed[loc.LHS.varname]
1197        if isinstance(loc, bitexpr.If):
1198            fixed = filter_fixed(fixed, loc.true_branch)
1199            fixed = filter_fixed(fixed, loc.false_branch)
1200        if isinstance(loc, bitexpr.WhileLoop):
1201            fixed = filter_fixed(fixed, loc.stmts)
1202    return fixed
1203
1204#Optimizing the code using constant and copy propagation and rules of boolean algebra.
1205def simplify_tree(code, fixed = {}, prev = []):
1206    if len(code) == 0:
1207        return []
1208
1209    assumptions = {}
1210
1211    if isinstance(code[0], bitexpr.BitAssign):
1212        this, next = recurse_forward(code)
1213        fixed.update(basic_block.simplify(this, fixed))
1214        return this+simplify_tree(next, fixed, this)
1215
1216    elif isinstance(code[0], bitexpr.If):
1217        fixed1 = copy.deepcopy(fixed)
1218        fixed2 = copy.deepcopy(fixed)
1219        if isinstance(code[0].control_expr, bitexpr.isAllOne):
1220            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
1221        if isinstance(code[0].control_expr, bitexpr.isAllZero):
1222            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
1223        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
1224        fixed1.update(assumptions)
1225        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
1226        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
1227        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1228
1229    elif isinstance(code[0], bitexpr.WhileLoop):
1230        fixed = filter_fixed(fixed, code[0].stmts)
1231        fixed1 = copy.deepcopy(fixed)
1232        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
1233        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1234
1235#################################################################################################################
1236## Dead Code Elimination
1237#################################################################################################################
1238#NOT USED
1239def get_effective_name(varname):
1240    if varname.startswith("array") or varname.startswith("strct"):
1241        try:
1242            ind = varname.index("__")
1243            return varname[6:ind]
1244        except:
1245            return varname
1246    return varname
1247
1248#Checks whether or not this line of code can be removed.
1249#It may also return new variables that must be kept and not removed by
1250#dead code elimination
1251def check_loc(loc, must_liv):
1252    effective_name = loc.LHS.varname
1253    if effective_name in must_liv:
1254        if isinstance(loc.RHS, bitexpr.Not):
1255            return set([loc.RHS.operand1.varname]), False
1256        elif isinstance(loc.RHS, bitexpr.Var):
1257            return set([loc.RHS.varname]), False
1258        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1259            return set([]), False
1260        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1261            return set([]), False
1262        elif isinstance(loc.RHS, bitexpr.Const):
1263            return set([]), False
1264        else:
1265            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1266    else:
1267        return set([]), True
1268
1269def remove_copies(bb):
1270    """removes all copy statements e.g. var1 = var2"""
1271    lhs = [x.LHS.varname for x in bb]
1272    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
1273    for i in rhs:
1274        if i[1] in lhs:
1275            line = lhs.index(i[1])
1276            if i[0] > line:
1277                bb[i[0]].RHS = bb[line].RHS
1278    return bb
1279
1280#Actually removes the dead  code  from a basic block
1281def remove_dead(bb, must_live):
1282    bb = remove_copies(bb)
1283    my_alives = set([])
1284    dead = []
1285
1286    for line, loc in reversed(list(enumerate(bb))):
1287        new_lives, removable = check_loc(loc, my_alives.union(must_live))
1288
1289        if removable:
1290            dead.append(line)
1291        else:
1292            my_alives = my_alives.union(new_lives)
1293
1294    for i in dead:
1295        del bb[i]
1296
1297    return my_alives, bb
1298
1299#Get variables that are used in the loop and must not be removed. For example
1300#variables that are used in the condition of a loop.
1301def get_loop_variables(code):
1302    must_live = set([])
1303    for loc in code:
1304        if isinstance(loc, bitexpr.WhileLoop):
1305            if not loc.carry_expr is None:
1306                must_live.add(loc.carry_expr.var.varname)
1307        if isinstance(loc, bitexpr.If):
1308            true_lives = get_loop_variables(loc.true_branch)
1309            false_lives = get_loop_variables(loc.false_branch)
1310            must_live = must_live.union(true_lives)
1311            must_live = must_live.union(false_lives)
1312    return must_live
1313
1314
1315#Traverses code from the end to the begniing, one basic block at a time, to detect dead code and remove it.
1316def eliminate_dead_code(tree, must_live):
1317    if len(tree) == 0:
1318        return must_live, []
1319
1320    last = len(tree) - 1
1321    new_alives = set([])
1322    bb = []
1323
1324    must_live = must_live.union(get_loop_variables(tree))
1325
1326    if isinstance(tree[-1], bitexpr.BitAssign):
1327        first = 0
1328        for i in reversed(range(len(tree))):
1329            if not isinstance(tree[i], bitexpr.BitAssign):
1330                first = i+1
1331                break
1332
1333        new_alives, bb = remove_dead(tree[first:], must_live)
1334        last = first
1335    elif isinstance(tree[-1], bitexpr.If):
1336        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1337        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1338        bb = [tree[-1]]
1339        new_alives.add(tree[-1].control_expr.var.varname)
1340
1341    elif isinstance(tree[-1], bitexpr.WhileLoop):
1342        must_live.add(tree[-1].control_expr.var.varname)
1343        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1344        bb = [tree[-1]]
1345
1346    all_alives = new_alives.union(must_live)
1347    dummy, new_tree = eliminate_dead_code(tree[:last], all_alives)
1348    tree = new_tree+bb
1349    return all_alives.union(dummy), tree
1350
1351#################################################################################################################
1352## This pass processes the code in the while loop and adds extra code required for handling carry variables
1353## It also updates condition of the loop.
1354#################################################################################################################
1355carry_suffix = "_i"
1356
1357def fix_the_loop(loop, carry_type):
1358    carries = []
1359    for loc in loop.stmts:
1360        if isinstance(loc.RHS, bitexpr.Add):
1361            carries.append(loc.RHS.carry)
1362        if isinstance(loc.RHS, bitexpr.Sub):
1363            carries.append(loc.RHS.brw)
1364    for item in carries:
1365        newvar = bitexpr.Var(item+carry_suffix)
1366        if (not loop.fake): loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), carry_type)))
1367    for item in carries:
1368        if (not loop.fake): loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral(carry_type) ))
1369
1370    return loop, carries
1371
1372def process_while_loops(code):
1373    all = []
1374    update = {}
1375   
1376    carry_type = "int"
1377
1378   
1379    for index, loc in enumerate(code):
1380        if isinstance(loc, bitexpr.WhileLoop):
1381            update[index] = fix_the_loop(loc, carry_type)
1382            for i in update[index][1]:
1383                all.append(i)
1384                all.append(i+carry_suffix)
1385
1386        if isinstance(loc, bitexpr.If):
1387            loc.true_branch, live_list = process_while_loops(loc.true_branch)
1388            loc.false_branch, live_list = process_while_loops(loc.false_branch)
1389
1390    keys = [k for k in update]
1391    keys.sort(reverse=True)
1392    for key in keys:
1393        code[key] = update[key][0]
1394        carry_variable = bitexpr.Var("CarryTemp"+str(key))
1395        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral(carry_type)))
1396        for item in update[key][1][2:]:
1397            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), carry_type)))
1398        if len(update[key][1]) == 1:
1399            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1400        elif len(update[key][1]) > 1:
1401            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), carry_type)))
1402        for item in update[key][1]:
1403            if (not code[key].fake): code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral(carry_type)))
1404        for item in update[key][1]:
1405            if (not code[key].fake): code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
1406
1407        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1408    return code, all
1409#################################################################################################################
1410## This pass factors out the code that is common between the true branch and false branch of an if
1411## statement.
1412#################################################################################################################
1413
1414def are_the_same(one, two):
1415    if one.__class__ != two.__class__:
1416        return False
1417   
1418    if isinstance(one, bitexpr.BitAssign):
1419        if (one.LHS.varname != two.LHS.varname):
1420            return False
1421        if one.RHS.__class__ != two.RHS.__class__:
1422            return False
1423        if isinstance(one.RHS, bitexpr.FalseLiteral):
1424            return True
1425        if isinstance(one.RHS, bitexpr.TrueLiteral):
1426            return True
1427        if isinstance(one.RHS, bitexpr.Var):
1428            return (one.RHS.varname == two.RHS.varname)
1429        if isinstance(one.RHS, bitexpr.Const):
1430            return one.RHS.n == two.RHS.n
1431        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1432
1433    if isinstance(one, bitexpr.If):
1434        if (one.control_expr.__class__ != two.control_expr.__class__):
1435            return False
1436        if (one.control_expr.var.varname != two.control_expr.var.varname):
1437            return False
1438        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1439            return False
1440        for ind in range(len(one.true_branch)):
1441            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1442                return False
1443        for ind in range(len(one.false_branch)):
1444            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1445                return False
1446        return True
1447
1448    if isinstance(one, bitexpr.WhileLoop):
1449        if (one.control_expr.__class__ != two.control_expr.__class__):
1450            return False
1451        if (one.control_expr.var.varname != two.control_expr.var.varname):
1452            return False
1453        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1454            return False
1455        #if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1456        #    return False
1457        if (len(one.stmts) != len(two.stmts)):
1458            return False
1459        for ind in range(len(one.stmts)):
1460            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1461                return False
1462        return True
1463
1464def get_factorable(cond):
1465    earliest = 0
1466    l = min(len(cond.true_branch), len(cond.false_branch))
1467    for index in reversed(range(-l,0)):
1468        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1469            earliest = index
1470        else:
1471            break
1472
1473    common_length = -earliest
1474    return common_length
1475
1476def get_common_code(cond, common):
1477    true_len = len(cond.true_branch)-common
1478    false_len = len(cond.false_branch)-common
1479
1480    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1481    common = cond.true_branch[true_len:]
1482
1483    return [new_cond], common
1484
1485def do_factorization(code, pos):
1486    indices = [key for key in pos]
1487    indices.sort()
1488
1489    for index in reversed(indices):
1490        new_cond, common = get_common_code(code[index], pos[index])
1491        code = code[:index]+new_cond+common+code[index+1:]
1492    return code
1493
1494#Factors out the common code at the end of two basic blocksof two branches of an if..then..else statement
1495def factor_out(code):
1496    for index, loc in enumerate(code):
1497        if isinstance(loc, bitexpr.If):
1498            code[index].true_branch = factor_out(code[index].true_branch)
1499            code[index].false_branch = factor_out(code[index].false_branch)
1500        if isinstance(loc, bitexpr.WhileLoop):
1501            code[index].stmts = factor_out(code[index].stmts)
1502
1503    pos = {}
1504    for index, loc in enumerate(code):
1505        if isinstance(loc, bitexpr.If):
1506            common_length = get_factorable(code[index])
1507            pos[index] = common_length
1508
1509    return do_factorization(code, pos)
1510
1511#################################################################################################################
1512## Generates C code given conditions-tree, by a recursive traversal of the tree
1513#################################################################################################################
1514def generate_condition(expr):
1515    if isinstance(expr, bitexpr.isAllZero):
1516        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1517    elif isinstance(expr, bitexpr.isAllOne):
1518        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1519    elif isinstance(expr, bitexpr.isNoneZero):
1520        return "bitblock_has_bit(%s)"%(expr.var.varname)
1521    else:
1522        assert (1==0)
1523
1524def generate_statement(expr, indent, cond_stmt):
1525    if isinstance(expr, bitexpr.isAllZero):
1526        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1527    elif isinstance(expr, bitexpr.isAllOne):
1528        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1529    elif isinstance(expr, bitexpr.isNoneZero):
1530        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1531    else:
1532        assert (1==0)
1533
1534def print_block_stmts(s, indent = 0):
1535    indent_unit = 4
1536    code = ""
1537    if len(s) == 0:
1538        return ""
1539    if isinstance(s[0], bitexpr.If):
1540        code = generate_statement(s[0].control_expr, indent, "if")
1541        code += print_block_stmts(s[0].true_branch, indent+indent_unit)
1542        code += " "*indent+"}\n"
1543        code += " "*indent+"else {\n"
1544        code += print_block_stmts(s[0].false_branch, indent+indent_unit)
1545        code += " "*indent+"}\n"
1546
1547    elif isinstance(s[0], bitexpr.WhileLoop) and (not s[0].fake):
1548        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
1549        code += print_block_stmts(s[0].stmts, indent+indent_unit)
1550        code += " "*indent+"}\n"
1551    elif isinstance(s[0], bitexpr.WhileLoop) and (s[0].fake):
1552        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"if"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
1553        code += print_block_stmts(s[0].stmts, indent+indent_unit)
1554        code += " "*indent+"}\n"
1555    elif (isinstance(s[0].RHS, bitexpr.Const)):
1556        pass
1557 
1558    elif isinstance(s[0], bitexpr.BitAssign) and not (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1559
1560        code += " "*indent + s[0].LHS.varname
1561        code += " = "
1562
1563        if s[0].RHS.data_type == "vector":
1564            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1565                code += s[0].RHS.operand1.varname
1566                code += ";\n"
1567            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1568                code += s[0].RHS.operand1.varname
1569                code += ";\n"
1570            elif isinstance(s[0].RHS, bitexpr.Var):
1571                code += s[0].RHS.operand1.varname
1572                code += ";\n"
1573
1574            else:
1575                code += s[0].RHS.op_C + "("
1576                code += s[0].RHS.operand1.varname
1577                code += ','
1578                code += s[0].RHS.operand2.varname
1579                code += ");\n"
1580        elif s[0].RHS.data_type == "int":
1581            if isinstance(s[0].RHS, bitexpr.Or):
1582                code += "carry_or(%s, %s);\n" % (s[0].RHS.operand1.varname, s[0].RHS.operand2.varname)
1583            elif isinstance(s[0].RHS, bitexpr.FalseLiteral):
1584                code += "Carry0;\n"
1585    elif isinstance(s[0], bitexpr.BitAssign) and (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1586            if isinstance(s[0].RHS, bitexpr.Add) and (s[0].RHS.operand1.varname == s[0].RHS.operand2.varname):
1587                code += " "*indent + s[0].RHS.op_advance + "("
1588                code += s[0].RHS.operand1.varname
1589                code += ', '
1590                code += s[0].RHS.carry
1591                code += ", "
1592                code += s[0].LHS.varname
1593                code += ");\n"
1594            elif isinstance(s[0].RHS, bitexpr.Add):
1595                code += " "*indent + s[0].RHS.op_C + "("
1596                code += s[0].RHS.operand1.varname
1597                code += ', '
1598                code += s[0].RHS.operand2.varname
1599                code += ', '
1600                code += s[0].RHS.carry
1601                code += ", "
1602                code += s[0].LHS.varname
1603                code += ");\n"
1604            elif isinstance(s[0].RHS, bitexpr.Sub):
1605                code += " "*indent + s[0].RHS.op_C + "("
1606                code += s[0].RHS.operand1.varname
1607                code += ', '
1608                code += s[0].RHS.operand2.varname
1609                code += ', '
1610                code += s[0].RHS.brw
1611                code += ", "
1612                code += s[0].LHS.varname
1613                code += ");\n"
1614    s.pop(0)
1615    return code+print_block_stmts(s, indent)
1616
1617
1618def print_stream_stmts(s, indent = 0):
1619    indent_unit = 4
1620    code = ""
1621    if len(s) == 0:
1622        return ""
1623    elif isinstance(s[0], bitexpr.BitAssign):
1624        if (isinstance(s[0].RHS, bitexpr.Const)):
1625            code += " "*indent + s[0].LHS.varname
1626            code += " = "
1627            code += "sisd_from_int(%i);"%s[0].RHS.n
1628            code += "\n"
1629
1630    s.pop(0)
1631    return code+print_stream_stmts(s, indent)
1632
1633#################################################################################################################
1634## Auxiliary Functions
1635#################################################################################################################
1636
1637def get_bb_of(code, index):
1638    if index < 0 or index >= len(code):
1639        return None
1640
1641    if isinstance(code[index], bitexpr.If):
1642        return code[index]
1643    if isinstance(code[index], bitexpr.WhileLoop):
1644        return code[index]
1645    if isinstance(code[index], bitexpr.BitAssign):
1646        last = index
1647        for loc in code[index+1:]:
1648            if isinstance(loc, bitexpr.BitAssign):
1649                last += 1
1650            else:
1651                break
1652 
1653        first = index
1654        for loc in reversed(code[:index]):
1655            if isinstance(loc, bitexpr.BitAssign):
1656                first -= 1
1657            else:
1658                break
1659
1660        return code[first:last+1]
1661
1662def get_previous_bb(code, index):
1663    if index < 0:
1664        return None
1665    if index >= len(code):
1666        return get_bb_of(code, len(code)-1)
1667    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1668        return get_bb_of(code, index-1)
1669    if isinstance(code[index], bitexpr.BitAssign):
1670        for ind, loc in reversed(enumerate(code[:index])):
1671            if not isinstance(loc, bitexpr.BitAssign):
1672                return get_bb_of(code, ind)
1673        return get_bb_of(code, -1)
1674    assert(1==0)
1675
1676def recurse_forward(code):
1677    if isinstance(code[0], bitexpr.If):
1678        return code[0], code[1:]
1679    if isinstance(code[0], bitexpr.WhileLoop):
1680        return code[0], code[1:]
1681    if isinstance(code[0], bitexpr.BitAssign):
1682        last = 0
1683        for loc in code[1:]:
1684            if isinstance(loc, bitexpr.BitAssign):
1685                last += 1
1686            else:
1687                break
1688        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.