source: proto/Compiler/py2bitexpr.py @ 460

Last change on this file since 460 was 460, checked in by cameron, 9 years ago

Generate carry_or

File size: 58.8 KB
Line 
1import ast, bitexpr, copy, basic_block
2
3type_dict = {}
4structs = {}
5arrays = {}
6
7AllOne = 'AllOne'
8AllZero = 'AllZero'
9
10STRUCT_TEMPLATE="strct_%s__%s_"
11ARRAY_TEMPLATE="array_%s__%s_"
12
13
14class PyBitError(Exception):
15        pass
16
17#############################################################################################
18## Extracting list of live variables
19#############################################################################################
20def get_lives(s):
21    for item in s.body:
22        if isinstance(item, ast.FunctionDef):
23            if (item.name =="main"):
24                assert (isinstance(item.body[-1], ast.Return))
25                ret = item.body[-1].value
26                if isinstance(ret, ast.Tuple):
27                    del item.body[-1]
28                    return [translate_var(x) for x in ret.elts],s
29                else:
30                    assert(isinstance(ret, ast.Name))
31                    del item.body[-1]
32                    return [translate_var(ret)],s
33
34#############################################################################################
35## Function Inlining
36#############################################################################################
37TEMP_VAR_TEMPLATE = "InlineTemp%i"
38
39def get_var_name(var):
40    var_name = translate_var(var)
41    if isinstance(var, ast.Name):
42        type_name = "simple"
43        extra = None
44    if isinstance(var, ast.Attribute):
45        type_name = "structure"
46        extra = var.attr
47    if isinstance(var, ast.Index):
48        type_name = "array"
49        extra = translate_index(v.slice.value)
50
51    return type_name, var_name, extra
52
53def register_var(final_name, original_name, type_name, extra):
54    pass
55
56def collect_functions(module, file_path):
57    imps = []
58    to_remove = []
59    for index, item in enumerate(module.body):
60        if isinstance(item, ast.Import):
61            imps.append(item)
62            to_remove.append(index)
63
64    for n in reversed(to_remove):
65        del module.body[n]
66
67    for imp in imps:
68        for mod in imp.names:
69            mod_file = open(file_path + mod.name + '.py')
70            mod_code = ""
71            for line in mod_file:
72                mod_code += line
73            mod_file.close()
74            mod_parsed = ast.parse(mod_code)
75            modified = []
76            for item in mod_parsed.body:
77                if isinstance(item, ast.FunctionDef):
78                    item.name = translate_var(ast.Attribute(value=ast.Name(id=mod.name), attr=item.name))
79            module.body += mod_parsed.body
80    return module
81
82
83def remove_unsupported_code(callee, fn_names, predefined, main=False):
84    to_remove = []
85    for index, loc in enumerate(callee.body):
86        if isinstance(loc, ast.While):
87            if not main:
88                to_remove.append(index)
89            continue
90
91        if isinstance(loc, ast.Expr):
92            if isinstance(loc.value, ast.Call):
93                if main and not (translate_var(loc.value.func) in fn_names+predefined):
94                    to_remove.append(index)
95                if not main and not translate_var(loc.value.func) in predefined:
96                    to_remove.append(index)
97
98            else:
99                to_remove.append(index)
100
101            continue
102
103        if not (isinstance(loc, ast.Assign) or isinstance(loc, ast.AugAssign) or isinstance(loc, ast.Return)):
104            to_remove.append(index)
105
106        #we do not support AugAssign with function calls
107        if isinstance(loc, ast.AugAssign):
108            if isinstance(loc.value, ast.Call):
109                assert (1==0)
110                to_remove.append(index)
111
112        #we do not support function calls in functions other than main
113        #these are OK if they are object construction, otherwise error.
114        if isinstance(loc, ast.Assign):
115            if isinstance(loc.value, ast.Call):
116                if main and not (translate_var(loc.value.func) in fn_names+predefined):
117                    to_remove.append(index)
118                if not main and not translate_var(loc.value.func) in predefined:
119                    to_remove.append(index)
120    to_remove.sort(reverse=True)
121    for line_no in to_remove:
122        del callee.body[line_no]
123
124    return callee
125
126
127def clean_functions(module, fn_names, predefined):
128    new = []
129    for item in module.body:
130        if isinstance(item, ast.FunctionDef):
131            new.append(remove_unsupported_code(item, fn_names, predefined, item.name == "main")) 
132        else:
133            new.append(item)
134    module.body = new
135    return module
136
137def collect_names(body):
138    names = {}
139    for index, item in enumerate(body):
140        if isinstance(item, ast.FunctionDef):
141            names[item.name] = index
142
143    return names
144
145def modify_name(var, args, cargs, unique_prefix):
146    carg_names = [x.id for x in cargs]
147    arg_names = [translate_var(x) for x in args]
148
149    varname = translate_var(var)
150    if isinstance(var, ast.Name):
151        if varname in carg_names:
152            return ast.Name(id=arg_names[carg_names.index(varname)])
153        return ast.Name(id=unique_prefix+varname)
154    if isinstance(var, ast.Attribute) or isinstance(var, ast.Subscript):
155        if var.value.id in carg_names:
156            var.value.id = arg_names[carg_names.index(var.value.id)]
157            return ast.Name(id=translate_var(var))
158        return ast.Name(id=unique_prefix+translate_var(var))
159
160def replace_in_exp(exp, args, cargs, unique_prefix):
161    if isinstance(exp, ast.BinOp):
162        exp.left = replace_in_exp(exp.left, args, cargs, unique_prefix)
163        exp.right = replace_in_exp(exp.right, args, cargs, unique_prefix)
164    elif isinstance(exp, ast.UnaryOp):
165        exp.operand = replace_in_exp(exp.operand, args, cargs, unique_prefix)
166    elif isinstance(exp, ast.Name):
167        exp = modify_name(exp, args, cargs, unique_prefix)
168    elif isinstance(exp, ast.Attribute):
169        exp = modify_name(exp, args, cargs, unique_prefix)
170    elif isinstance(exp, ast.Subscript):
171        exp = modify_name(exp, args, cargs, unique_prefix)
172    elif isinstance(exp, ast.Call):
173        for index, item in enumerate(exp.args):
174            exp.args[index] = replace_in_exp(item, args, cargs, unique_prefix)
175    elif isinstance(exp, ast.Tuple):
176        for index, elt in enumerate(exp.elts):
177            exp.elts[index] = replace_in_exp(elt, args, cargs, unique_prefix)
178    elif isinstance(exp, ast.Num):
179        pass
180    else:
181        assert(1==0)
182    return exp
183
184def get_all_calls(main, all_func):
185    call_list = []
186    for index, loc in enumerate(main.body):
187        if isinstance(loc, ast.Assign) and isinstance(loc.value, ast.Call):
188                func_name = translate_var(loc.value.func)
189                if func_name in all_func:
190                    call_list.append(index)
191        if isinstance(loc, ast.AugAssign) and isinstance(loc.value, ast.Call):
192                func_name = translate_var(loc.value.func)
193                if func_name in all_func:
194                    call_list.append(index)
195    return call_list
196
197def update_var_names(callee, args):
198    unique_prefix = '_'+callee.name.replace('.', '_')+'_'
199
200    assert (len(args)==len(callee.args.args))
201
202    for index, loc in enumerate(callee.body):
203        if isinstance(loc, ast.Assign):
204            for index, var in enumerate(loc.targets):
205                loc.targets[index] = modify_name(var, args, callee.args.args, unique_prefix)
206            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
207        elif isinstance(loc, ast.AugAssign):
208            loc.target = modify_name(loc.target, args, callee.args.args, unique_prefix)
209            loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
210        elif isinstance(loc, ast.Return):
211            pass
212            #loc.value = replace_in_exp(loc.value, args, callee.args.args, unique_prefix)
213        else:
214            assert (1==0)
215
216    return callee, unique_prefix
217
218
219def get_name(loc, index):
220    if isinstance(loc.value, ast.Tuple):
221        return translate_var(loc.value.elts[index])
222    elif index == 0:
223        return translate_var(loc.value)
224    else:
225        assert(1==0)
226
227def modify_assignments_in_caller(main, line_no, return_list, prefix, callee):
228    if isinstance(main.body[line_no].targets[0], ast.Tuple):
229        all_vars = main.body[line_no].targets[0].elts
230    else:
231        all_vars = main.body[line_no].targets
232    new_code = []
233
234    if isinstance(callee.body[-1].value, ast.Tuple):
235        return_vars = [translate_var(x) for x in callee.body[-1].value.elts]
236    else:
237        return_vars = [translate_var(callee.body[-1].value)]
238
239    for index, item in enumerate(all_vars):
240        actual_name = translate_var(item)
241        if actual_name in [translate_var(x) for x in main.body[line_no].value.args]:
242            continue
243        temp_name = return_vars[index]
244
245        if temp_name in [translate_var(x) for x in callee.args.args]:
246            second_index = [translate_var(x) for x in callee.args.args].index(temp_name)
247            final_name = translate_var(main.body[line_no].value.args[second_index])
248        else:
249            final_name = prefix+temp_name
250
251        if return_list[temp_name][0] == "struct":
252            my_template = STRUCT_TEMPLATE
253        elif return_list[temp_name][0] == "array":
254            my_template = ARRAY_TEMPLATE
255 
256        if return_list[temp_name][0] == "bitblock":
257            new_code.append(ast.Assign( [ast.Name(id=item.id)] , ast.Name(id=final_name) ))
258        else:
259            for elem in return_list[temp_name][1]:
260                new_code.append(ast.Assign([ast.Name(id=my_template%(item.id, elem))], ast.Name(id=prefix+my_template%(temp_name, elem))))
261    main.body = main.body[:line_no]+new_code+main.body[line_no+1:]
262    return main
263
264def get_return_list(callee):
265    ret_list = []
266    loc = callee.body[-1]
267    if isinstance(loc, ast.Return):
268        if isinstance(loc.value, ast.Tuple):
269            all_rets = loc.value.elts
270        else:
271            all_rets = [loc.value]
272        for item in all_rets:
273            if isinstance(item, ast.Attribute) or isinstance(item, ast.Subscript):
274                ret_list.append(item.value.id)
275            if isinstance(item, ast.Name):
276                ret_list.append(item.id)
277    return ret_list
278
279def get_updated(loc):
280    if isinstance(loc, ast.AugAssign):
281        updated = [loc.target]
282    elif isinstance(loc, ast.Assign):
283        if isinstance(loc.targets[0], ast.Tuple):
284            updated = loc.targets[0].elts
285        else:
286            updated = loc.targets
287    else:
288        updated = []
289    return updated
290
291def get_update_details(callee, ret_list):
292    field_info = {}
293    for item in ret_list:
294        field_info[item] = ['simple', set([])]
295    for loc in callee.body:
296        updated = get_updated(loc)
297        for item in updated:
298            type, name, extra = parse_var(translate_var(item, True))
299            if name in field_info:
300                field_info[name][0] = type
301                if type=='struct':
302                    field_info[name][1].add(extra)
303                if type == 'array':
304                    field_info[name][1].add(extra)
305    return field_info
306
307def process_return(callee):
308    if not isinstance(callee.body[-1], ast.Return):
309        return {}
310    ret_list = get_return_list(callee)
311    field_info = get_update_details(callee, ret_list)
312    return field_info
313
314
315def generate_expanded_code(loc):
316    new_code = []
317    array_name = translate_var(loc.targets[0])
318    for ind, i in enumerate(range(len(loc.value.elts))):
319        new_loc = ast.Assign([ast.Subscript(value=ast.Name(id=array_name), slice=ast.Index(ast.Num(n=ind)) )], loc.value.elts[i])
320        new_code.append(new_loc)
321
322    return new_code
323
324def expand_list_init(module):
325    for item in module.body:
326        replace_code = {}
327        if isinstance(item, ast.FunctionDef):
328            for index,loc in enumerate(item.body):
329                if isinstance(loc, ast.Assign):
330                    if isinstance(loc.value, ast.List):
331                        equivalent_code = generate_expanded_code(loc)
332                        replace_code[index] = equivalent_code
333            keys = [key for key in replace_code]
334            keys.sort(reverse = True)
335            for line_no in keys:
336                item.body = item.body[:line_no]+replace_code[line_no]+item.body[line_no+1:]
337
338    return module
339
340def get_predefined_funcs():
341    advance = ast.Attribute(value=ast.Name(id="bitutil"), attr="Advance")
342    scan_thru = ast.Attribute(value=ast.Name(id="bitutil"), attr="ScanThru")
343    optimize = ast.Attribute(value=ast.Name(id="bitutil"), attr="Optimize")
344    return [translate_var(advance), translate_var(scan_thru), translate_var(optimize)]
345
346def do_inlining(module, file_path):
347    module = collect_functions(module, file_path)
348    func_dict = collect_names(module.body)
349    module = expand_list_init(module)
350
351    valid_fn = [name for name in func_dict] 
352    predef_fn = get_predefined_funcs()
353    module = clean_functions(module, valid_fn, predef_fn)
354
355    main = module.body[func_dict["main"]]
356    call_list = get_all_calls(main, [func for func in func_dict])
357    call_list.sort(reverse=True)
358    for line_no in call_list:
359        callee = copy.deepcopy(module.body[func_dict[translate_var(main.body[line_no].value.func)]])
360        args = main.body[line_no].value.args
361        return_list = process_return(callee)
362        callee, prefix = update_var_names(callee, args)
363
364        main = modify_assignments_in_caller(main, line_no, return_list, prefix, callee)
365
366        if isinstance(callee.body[-1], ast.Return):
367            del callee.body[-1]
368        main.body = main.body[:line_no]+callee.body+main.body[line_no:]
369    return module.body[func_dict["main"]].body
370
371#############################################################################################
372## Translation to bitExpr classes --- Pass 2
373#############################################################################################
374
375def translate_index(ast_value):
376        if isinstance(ast_value, ast.Str): return ast_value.s
377        elif isinstance(ast_value, ast.Num): return repr(ast_value.n)
378        else: raise PyBitError("Unknown value %s\n" % ast.dump(ast_value))
379
380def is_Advance_Call(fncall):
381        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'Advance'
382        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
383                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'Advance'
384        return iscall and len(fncall.args) == 1 and fncall.kwargs == None and fncall.starargs == None
385
386def is_ScanThru_Call(fncall):
387        if isinstance(fncall.func, ast.Name): iscall = fncall.func.id == 'ScanThru'
388        elif isinstance(fncall.func, ast.Attribute) and isinstance(fncall.func.value, ast.Name):
389                 iscall = fncall.func.value.id == 'bitutil' and fncall.func.attr == 'ScanThru'
390        return iscall and len(fncall.args) == 2 and fncall.kwargs == None and fncall.starargs == None
391
392def translate(ast_expr):
393        if isinstance(ast_expr, ast.Name):
394                if ast_expr.id=="allzero":
395                        return bitexpr.FalseLiteral()
396                if ast_expr.id=="allone":
397                        return bitexpr.TrueLiteral()
398                return bitexpr.Var(translate_var(ast_expr))
399        elif isinstance(ast_expr, ast.Num):
400                if ast_expr.n == 0:
401                    return bitexpr.FalseLiteral()
402                else:
403                    return bitexpr.Const(ast_expr.n)
404        elif isinstance(ast_expr, ast.Attribute):
405                if isinstance(ast_expr.value, ast.Name):
406                        return bitexpr.Var(translate_var(ast_expr))
407                        return bitexpr.Var("%s.%s" % (ast_expr.value.id, ast_expr.attr))
408                else: raise PyBitError("Illegal attribute %s\n" % ast.dump(ast_expr))
409        elif isinstance(ast_expr, ast.Subscript):
410                if isinstance(ast_expr.value, ast.Name) and isinstance(ast_expr.slice, ast.Index):
411                        return bitexpr.Var(translate_var(ast_expr))
412                        return bitexpr.Var("%s[%s]" % (ast_expr.value.id, translate_index(ast_expr.slice.value)))
413                else: raise PyBitError("Illegal array %s\n" % ast.dump(ast_expr))
414        elif isinstance(ast_expr, ast.UnaryOp):
415                e1 = translate(ast_expr.operand)
416                if isinstance(ast_expr.op, ast.Invert):
417                        return bitexpr.make_not(e1)
418                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
419        elif isinstance(ast_expr, ast.BinOp):
420                e1 = translate(ast_expr.left)
421                e2 = translate(ast_expr.right)
422                if isinstance(ast_expr.op, ast.BitOr):
423                        return bitexpr.make_or(e1, e2)
424                elif isinstance(ast_expr.op, ast.BitAnd):
425                        return bitexpr.make_and(e1, e2)
426                elif isinstance(ast_expr.op, ast.BitXor):
427                        return bitexpr.make_xor(e1, e2)
428                elif isinstance(ast_expr.op, ast.Add):
429                        return bitexpr.make_add(e1, e2)
430                elif isinstance(ast_expr.op, ast.Sub):
431                        return bitexpr.make_sub(e1, e2)
432                else: raise PyBitError("Unknown operator %s\n" % ast.dump(ast_expr.op))
433        elif isinstance(ast_expr, ast.Call):
434                if is_Advance_Call(ast_expr):
435                        e = translate(ast_expr.args[0])
436                        temp = bitexpr.make_add(e,e)
437                        return temp
438                elif is_ScanThru_Call(ast_expr):
439                        e0 = translate(ast_expr.args[0])
440                        e1 = translate(ast_expr.args[1])
441                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
442                else: 
443                    raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
444        elif isinstance(ast_expr, ast.Compare):
445                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
446                        e0 = translate(ast_expr.left)
447                        return bitexpr.isNoneZero(e0)
448                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
449               
450        else: 
451            raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
452
453def translate_var(v, aggregate_type = False):
454        if isinstance(v, ast.Name):
455            agg_name = v.id
456            var_name = v.id
457        elif isinstance(v, ast.Attribute):
458           if isinstance(v.value, ast.Name):
459               var_name = STRUCT_TEMPLATE%(v.value.id, v.attr)
460               agg_name = "%s.%s" % (v.value.id, v.attr)
461           else: raise PyBitError("Illegal attribute %s\n" % ast.dump(v))
462        elif isinstance(v, ast.Subscript):
463           if isinstance(v.value, ast.Name) and isinstance(v.slice, ast.Index):
464               var_name = ARRAY_TEMPLATE%(v.value.id, translate_index(v.slice.value))
465               agg_name = "%s[%s]" % (v.value.id, translate_index(v.slice.value))
466               
467        else: raise PyBitError("Unknown operator %s\n" % ast.dump(v))
468
469        if aggregate_type:
470            return agg_name
471        else:
472            return var_name
473
474def translate_stmts(ast_stmts):
475        translated = []
476        for s in ast_stmts:
477                if isinstance(s, ast.Expr):
478                        if (translate_var(s.value.func)=='strct_bitutil__Optimize_'):
479                            target = translate_var(s.value.args[0])
480                            replace = translate_var(s.value.args[1])
481                            translated.append(bitexpr.Reduce(target, replace))
482                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
483                elif isinstance(s, ast.Assign):
484                        e = translate(s.value)
485                        for astv in s.targets: 
486                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
487                elif isinstance(s, ast.AugAssign):
488                        v = bitexpr.Var(translate_var(s.target))
489                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
490                elif isinstance(s, ast.While):
491                        e = translate(s.test)
492                        body = translate_stmts(s.body)
493                        translated.append(bitexpr.WhileLoop(bitexpr.isNoneZero(e), body))
494                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
495        return translated
496#############################################################################################
497## Conversion to SSA form --- Pass 3
498#############################################################################################
499
500def extract_vars(rhs):
501    if isinstance(rhs, bitexpr.Var):
502        return [rhs.varname]
503    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral) or isinstance(rhs, bitexpr.Const):
504        return []
505    if isinstance(rhs, bitexpr.Not):
506        return extract_vars(rhs.operand1)
507    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
508    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
509
510def update_lineno(inner_table, shift):
511    for key in inner_table:
512        for index in range(len(inner_table[key][0])):
513            inner_table[key][0][index] += shift
514        for index in range(len(inner_table[key][1])):
515            inner_table[key][1][index] += shift
516    return inner_table
517
518def merge_tables(table, inner_table):
519    for key in inner_table:
520        if key in table:
521            table[key][0] += inner_table[key][0]
522            table[key][1] += inner_table[key][1]
523        else:
524            table[key] = inner_table[key]
525    return table
526
527def gen_sym_table(code, goInside = False):
528    """Generate a simple symbol table for a three address code
529        each entry is of this form: var_name:[[defs][uses]]
530    """
531    table = {}
532    index = 0
533    for stmt in code:
534        if isinstance(stmt, bitexpr.Reduce):
535            if stmt.target in table:
536                table[stmt.target][1].append(index)
537            else:
538                table[stmt.target] = [[], [index]]
539            index += 1
540        elif isinstance(stmt, bitexpr.BitAssign):
541            current = stmt.LHS.varname
542            if current in table:
543                table[current][0].append(index)
544            else:
545                table[current] = [[index],[]]
546
547            varlist = extract_vars(stmt.RHS)
548            for var in varlist:
549                if var in table:
550                    table[var][1].append(index)
551                else:
552                    table[var] = [[], [index]]
553            index += 1
554        elif isinstance(stmt, bitexpr.WhileLoop):
555            cond_var = stmt.control_expr.var.varname
556            if cond_var in table:
557                table[cond_var][1].append(index)
558            else:
559                #while loop conditioned on an undefined variable
560                assert(1==0)
561            if goInside:
562                inner_table = gen_sym_table(stmt.stmts, True)
563                inner_table = update_lineno(inner_table, index+1)
564                table = merge_tables(table, inner_table)
565                index += (1+len(stmt.stmts))
566            else:
567                index += 1
568        else:
569            assert(1==0)
570    return table
571##################################################################################
572def get_line(code, line):
573    lineno = 0
574    for stmt in code:
575        if isinstance(stmt, bitexpr.BitAssign):
576            if lineno == line:
577                return stmt
578            lineno += 1
579        elif isinstance(stmt, bitexpr.WhileLoop):
580            if lineno == line:
581                return stmt
582            elif lineno+len(stmt.stmts) < line:
583                lineno += 1
584                line -= len(stmt.stmts)
585                continue
586            else:
587                return get_line(stmt.stmts, line-(lineno+1))
588
589        elif isinstance(stmt, bitexpr.Reduce):
590            if lineno == line:
591                return stmt
592            lineno += 1
593        else:
594            assert(1==0)
595
596def update_def(code, var, line, suffix_num):
597    loc = code[line]
598    if isinstance(loc, bitexpr.BitAssign):
599        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
600        return loc
601    else:
602        #either Reduce or While Loop in both cases it's a use
603        pass
604
605def update_rhs(rhs, varname, newname):
606    if isinstance(rhs, bitexpr.Var):
607        if rhs.varname == varname:
608            rhs.varname = newname
609        return rhs
610
611    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
612        return rhs
613
614    if isinstance(rhs, bitexpr.Not):
615        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
616        return rhs
617    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
618    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
619    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
620    return rhs
621
622def update_use(code, var, line, suffix_num):
623    loc = code[line]
624    if isinstance(loc, bitexpr.BitAssign):
625        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
626    elif isinstance(loc, bitexpr.Reduce):
627        pass
628    elif isinstance(loc, bitexpr.WhileLoop):
629        loc.control_expr.var.varname += "_%i"%suffix_num
630    return loc
631
632def parse_var(var):
633    index=var.find('.')
634    if index >= 0:
635        return ('struct', var[0:index], var[index+1:])
636               
637    index = var.find('[')
638    if index >= 0:
639        right_index = var.find(']')
640        return ('array', var[0:index], var[index+1:right_index])       
641   
642    if var.startswith("carry") or var.startswith("Carry"):
643        return('int', var, None)
644   
645    return ('bitblock', var, None)
646
647def simplify_name(var):
648    (vartype, name, extra) = parse_var(var)
649    if vartype == 'bitblock':
650        return name
651    if vartype == 'array':
652        return "a_%s_%s"%(name, extra)
653    if vartype == 'struct':
654        return "s_%s_%s"%(name, extra)
655    if vartype == 'int':
656        assert(1==0)
657    assert(1==0)
658
659def pairs(lst):
660    if lst == []:
661            return []
662    return zip(lst,lst[1:]+[lst[0]])
663
664
665def make_SSA(code, st):
666    new_vars = []
667    total_lines = len(code)
668    for var in st:
669        st[var][0].append(total_lines)
670        st[var][1].append(total_lines)
671
672    unique_number = 0
673    for var in st:
674        use_index = 0
675        for current, next in pairs(st[var][0])[0:-2]:
676            code[current] = update_def(code, var, current, unique_number)
677            uline = st[var][1][use_index]
678
679            while uline <= next and uline < total_lines:
680                if uline > current:
681                    code[uline] = update_use(code, var, uline, unique_number)
682                use_index += 1
683                uline = st[var][1][use_index]
684            unique_number += 1
685    return code
686
687#################################################################################################################
688## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
689#################################################################################################################
690def get_opt_list(s):
691    opt_list = []
692    remove_index = []
693    for line, stmt in enumerate(s):
694        if isinstance(stmt, bitexpr.Reduce):
695            opt_pair = (stmt.target, stmt.replace)
696            opt_list.append(opt_pair)
697            remove_index.append(line)
698    for ind in reversed(remove_index):
699        del s[ind]
700
701    return opt_list
702
703def break_to_segs(s, breaks):
704    segs = []
705    breaks = [-1]+breaks+[len(s)]
706    for start, stop in pairs(breaks)[:-1]:
707        next = s[start+1:stop]
708        if len(next)>0:
709            segs.append(next)
710    return segs
711
712def extract_pragmas(s):
713    targets = []
714    exprs = []
715    lines = []
716
717    #extracting pragmas and their targets
718    for line, loc in enumerate(s):
719        if isinstance(loc, bitexpr.Reduce):
720            targets.append(loc.target)
721            exprs.append(loc)
722            lines.append(line)
723
724    #removing pragmas from the source code
725    for i in reversed(lines):
726        del s[i]
727
728    return targets, exprs
729
730def get_defs(targets, s):
731    temp = [(x,-1) for x in targets]
732    targ_dic = dict(temp)
733    for line, loc in enumerate(s):
734        if loc.LHS.varname in targets:
735            targ_dic[loc.LHS.varname] = line
736
737    items = targ_dic.items()
738    items = sorted(items, key=(lambda x: x[1]))
739
740    lineno = [key for value, key in items]
741    sorted_targets = [value for value, key in items]
742
743    return sorted_targets, lineno
744
745def defines(varname, stmts):
746    is_defined = False
747    for loc in stmts:
748        if isinstance(loc, bitexpr.WhileLoop):
749            is_defined |= defines(varname, loc.stmts)
750        elif isinstance(loc, bitexpr.BitAssign): 
751            if loc.LHS.varname == varname:
752                is_defined = True
753
754    return is_defined
755
756def adjust_earliest(s, earliest, varname):
757    if earliest == -1:
758        return -1
759
760    line = earliest+1
761    for stmt in s[earliest+1:]:
762        if isinstance(stmt, bitexpr.WhileLoop):
763            if defines(varname, stmt.stmts):
764               earliest = line+1 
765        line+=1
766    return earliest
767
768def chop_code(s, start, stop):
769    line = 0
770
771    if start == -1:
772        cut1 = -1
773
774    if stop == len(s):
775        cut2 = len(s)
776
777    for index, stmt in enumerate(s):
778        if isinstance(stmt, bitexpr.BitAssign):
779            if line == start:
780                cut1 = index
781            if line == stop:
782                cut2 = index
783            line += 1
784
785        if isinstance(stmt, bitexpr.WhileLoop):
786            if line == start:
787                cut1 = index
788            if line  == stop:
789                cut2 = index
790            line+=1
791
792    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
793
794def count_lines(code):
795    cnt = len(code)
796    for stmt in code:
797        if isinstance(stmt, bitexpr.WhileLoop):
798            cnt += count_lines(stmt.stmts)
799    return cnt
800
801
802def gen_cond_obj(opt_pragma):
803    if opt_pragma[1] == 'AllZero' or opt_pragma[1] == 'allzero':
804        cond_obj = bitexpr.isAllZero(opt_pragma[0])
805    if opt_pragma[1] == 'AllOne' or opt_pragma[1] == 'allone':
806        cond_obj = bitexpr.isAllOne(opt_pragma[0])
807    return cond_obj
808
809def gen_bb(s, opt_list, def_line):
810    assert (len(opt_list) == len(def_line))
811    if opt_list==[]:
812        return s
813
814    earliest = min(def_line)
815    latest = len(s)
816    earliest_index = def_line.index(earliest)
817    early_varname = opt_list[earliest_index][0]
818   
819    earliest = adjust_earliest(s, earliest, early_varname)
820
821    first, second, third = chop_code(s, earliest, latest)
822
823    cond_obj = gen_cond_obj(opt_list[earliest_index])
824   
825    del opt_list[earliest_index]
826    del def_line[earliest_index]
827    def_line = [x-earliest-1 for x in def_line]
828   
829    second = gen_bb(second, opt_list, def_line)
830
831    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))] + third
832    return result
833
834def partition2bb(s):
835    basicblocks = []
836    lineno = []
837    def_line = []
838    opt_list = get_opt_list(s)
839
840    st2 = gen_sym_table(s)
841
842    for item in opt_list:
843        if len(st2[item[0]][0]) > 0:
844            def_line.append(st2[item[0]][0][-1])
845        else:
846            def_line.append(-1)
847
848
849    return gen_bb(s, opt_list, def_line)
850
851#################################################################################################################
852## Generating declarations for variables. This pass is based on the syntax of C programming language ---
853## Normalizing the code
854#################################################################################################################
855
856def normalize(s, predec = {}, ccelim=True):
857    if len(s)==0:
858        return []
859    if isinstance(s[0], bitexpr.If):
860        gc = basic_block.BasicBlock.gensym_counter
861        cc = basic_block.BasicBlock.carry_counter
862        bc = basic_block.BasicBlock.brw_counter
863        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
864        maxgc = basic_block.BasicBlock.gensym_counter
865        maxcc = basic_block.BasicBlock.carry_counter
866        maxbc = basic_block.BasicBlock.brw_counter
867        basic_block.BasicBlock.gensym_counter = gc
868        basic_block.BasicBlock.carry_counter = cc
869        basic_block.BasicBlock.brw_counter = bc
870        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
871        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
872        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
873        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
874       
875        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
876    if isinstance(s[0], bitexpr.WhileLoop):
877        orig = copy.deepcopy(predec)
878        s[0].stmts = normalize(s[0].stmts, predec, False)
879        return [s[0]]+normalize(s[1:], {})
880    if isinstance(s[0], bitexpr.BitAssign):
881        code, next = recurse_forward(s)
882        bb = basic_block.BasicBlock(predec)
883        predec.update(bb.normalize(code, ccelim))
884        return bb.get_code() + normalize(next, copy.deepcopy(predec))
885
886#################################################################################################################
887## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
888#################################################################################################################
889def get_vars(stmt):
890        ints = set([])
891        bitblocks = set(['AllOne', 'AllZero'])
892        arrays = {}
893        structs = {}
894
895        #for stmt in code:
896
897        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
898       
899
900
901        if isinstance(stmt.RHS, bitexpr.Add): 
902                ints.add(stmt.RHS.carry)
903        if isinstance(stmt.RHS, bitexpr.Sub):
904                ints.add(stmt.RHS.brw)
905
906        for var in all_vars:
907                if isinstance(var, bitexpr.Const):
908                    continue
909               
910                (var_type, name, extra) = parse_var(var.varname)
911
912                if (var_type == "bitblock"):
913                        bitblocks.add(name)
914
915                if var_type == "array":
916                        if not name in arrays:
917                                arrays[name]= extra
918                        else:
919                                arrays[name] = max(arrays[name], extra)
920
921                if var_type == "struct":
922                        if not name in structs:
923                                structs[name] = set([extra])
924                        else:
925                                structs[name].add(extra)
926                if var_type == "int":
927                    ints.add(name)
928                       
929        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
930
931def merge_var_dic(first, second):
932    res = {}
933    res['int'] = first['int'].union(second['int'])
934    res['bitblock'] = first['bitblock'].union(second['bitblock'])
935
936    res['array'] = first['array']
937    temp = dict({'array':copy.copy(second['array'])})
938    for i in res['array']:
939        if i in temp['array']:
940            res['array'][i] = max(res['array'][i], temp['array'][i])
941            del temp['array'][i]
942    res['array'].update(temp['array'])
943
944
945    res['struct'] = first['struct']
946    temp = dict({'struct':copy.copy(second['struct'])})
947    for i in res['struct']:
948        if i in temp['struct']:
949            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
950            del temp['struct'][i]
951    res['struct'].update(temp['struct'])
952    return res
953
954def gen_output(var_dic, int_carry):
955        s = ''
956        for i in var_dic['bitblock']:
957                if i == AllOne:
958                        s += "BitBlock %s = simd_const_1(1);\n"%i
959                elif i == AllZero:
960                        s+="BitBlock %s = simd_const_1(0);\n"%i
961                elif i == "EOF_mask":
962                        s+="BitBlock EOF_mask = simd_const_1(1);\n"
963                else:
964                        s+="BitBlock %s;\n"%i
965
966        for i in var_dic['array']:
967                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
968
969        for i in var_dic['struct']:
970                s+="struct __%s__{\n"%i
971                for j in var_dic['struct'][i]:
972                        s+= "\tBitBlock %s;\n"%j
973                s+="};\n"
974                s+="struct __%s__ %s;\n"%(i,i)
975
976        for i in  var_dic['int']:
977                if int_carry:
978                        s+="CarryType %s = Carry0;\n"%i
979                else:
980                        s+="CarryType %s = Carry0;\n"%i
981
982
983
984        return s
985
986def gen_var_dic(s):
987    if len(s)==0:
988        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
989    if isinstance(s[0], bitexpr.BitAssign):
990        vd = get_vars(s[0])
991        more = gen_var_dic(s[1:])
992        vd = merge_var_dic(vd, more)
993        return vd
994
995    if isinstance(s[0], bitexpr.If):
996        vd1 = gen_var_dic(s[0].true_branch)
997        vd2 = gen_var_dic(s[0].false_branch)
998        vd = merge_var_dic(vd1, vd2)
999        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1000        vd =  merge_var_dic(vd, vd3)
1001        more = gen_var_dic(s[1:])
1002        vd  = merge_var_dic(vd, more)
1003        return vd
1004
1005    if isinstance(s[0], bitexpr.WhileLoop):
1006        vd = gen_var_dic(s[0].stmts)
1007        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
1008        vd =  merge_var_dic(vd, vd1)
1009        more = gen_var_dic(s[1:])
1010        vd  = merge_var_dic(vd, more)
1011        return vd
1012
1013def gen_declarations(s, int_carry):
1014    vd = gen_var_dic(s)
1015    return gen_output(vd, int_carry)
1016#################################################################################################################
1017## This class replaces all occurrences of a reduced variable to its value --- Pass 7
1018## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
1019## *** the new AST notation                                                                                   ***
1020#################################################################################################################
1021def replace_in_rhs(rhs, target, replace):
1022    if isinstance(rhs, bitexpr.Const):
1023        return rhs
1024    if isinstance(rhs, bitexpr.Var):
1025        if rhs.varname == target:
1026            rhs = replace
1027            return replace
1028
1029    if isinstance(rhs.operand1, bitexpr.Var):
1030        if rhs.operand1.varname == target:
1031            rhs.operand1 = replace
1032    else:
1033        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
1034            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
1035
1036    if isinstance(rhs.operand2, bitexpr.Var):
1037        if rhs.operand2.varname == target:
1038            rhs.operand2 = replace
1039    else:
1040        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
1041            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
1042    return rhs
1043
1044def apply_single_opt(code, target, replace):
1045    if len(code) == 0:
1046        return []
1047
1048    if replace=='AllZero':
1049        replace = bitexpr.FalseLiteral()
1050    if replace == 'AllOne':
1051        replace = bitexpr.TrueLiteral()
1052
1053    if isinstance(code[0], bitexpr.BitAssign):
1054        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
1055
1056    if isinstance(code[0], bitexpr.If):
1057        if code[0].control_expr.var.varname == target:
1058            code[0].control_expr.var = replace
1059        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
1060        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
1061
1062    if isinstance(code[0], bitexpr.WhileLoop):
1063        if code[0].control_expr.var.varname == target:
1064            code[0].control_expr.var = replace
1065        #code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
1066
1067    return [code[0]]+apply_single_opt(code[1:], target, replace)
1068
1069def apply_all_opt(s):
1070    if len(s) == 0:
1071        return []
1072    elif isinstance(s[0], bitexpr.If):
1073        target = s[0].control_expr.var.varname
1074        replace = s[0].control_expr.val
1075        apply_single_opt(s[0].true_branch, target, replace)
1076        apply_all_opt(s[0].true_branch)
1077        apply_all_opt(s[0].false_branch)
1078
1079    elif isinstance(s[0], bitexpr.WhileLoop):
1080        apply_all_opt(s[0].stmts)
1081
1082    return [s[0]]+apply_all_opt(s[1:])
1083 
1084#################################################################################################################
1085## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
1086## method prune is supposed to remove unreachable branches but it is incomplete
1087#################################################################################################################
1088
1089def prune(fixed, tree):
1090    """removes unreachable branches of the tree"""
1091    target = tree[0].target
1092    replace = tree[0].replace
1093    if target in fixed:
1094        if replace == 'allone':
1095            if isinstance(fixed[target], bitexpr.TrueLiteral):
1096                tree[0] = None
1097                tree[1].join(tree[2][1])
1098                del tree[3]
1099                del tree[2]
1100                fixed.update(tree[1].simplify(fixed))
1101            if isinstance(fixed[target], bitexpr.FalseLiteral):
1102                tree[0] = None
1103                tree[1].join(tree[3][1])
1104                del tree[3]
1105                del tree[2]
1106                fixed.update(tree[1].simplify(fixed))
1107        if replace == 'allzero':
1108            if isinstance(fixed[target], bitexpr.FalseLiteral):
1109                tree[0] = None
1110                tree[1].join(tree[2][1])
1111                del tree[3]
1112                del tree[2]
1113                fixed.update(tree[1].simplify(fixed))
1114            if isinstance(fixed[target], bitexpr.TrueLiteral):
1115                tree[0] = None
1116                tree[1].join(tree[3][1])
1117                del tree[3]
1118                del tree[2]
1119                fixed.update(tree[1].simplify(fixed))
1120
1121    empty_list = []
1122    for index, branch in enumerate(tree[2:]):
1123        if branch[0] is None:
1124            if branch[1].code == []:
1125                empty_list.append(2+index)
1126
1127    empty_list.reverse()
1128    for i in empty_list:
1129        del tree[i]
1130def filter_fixed(fixed, stmts):
1131    for loc in stmts:
1132        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
1133            del fixed[loc.LHS.varname]
1134        if isinstance(loc, bitexpr.If):
1135            fixed = filter_fixed(fixed, loc.true_branch)
1136            fixed = filter_fixed(fixed, loc.false_branch)
1137        if isinstance(loc, bitexpr.WhileLoop):
1138            fixed = filter_fixed(fixed, loc.stmts)
1139    return fixed
1140
1141def simplify_tree(code, fixed = {}, prev = []):
1142    if len(code) == 0:
1143        return []
1144
1145    assumptions = {}
1146
1147    if isinstance(code[0], bitexpr.BitAssign):
1148        this, next = recurse_forward(code)
1149        fixed.update(basic_block.simplify(this, fixed))
1150        return this+simplify_tree(next, fixed, this)
1151
1152    elif isinstance(code[0], bitexpr.If):
1153        fixed1 = copy.deepcopy(fixed)
1154        fixed2 = copy.deepcopy(fixed)
1155        if isinstance(code[0].control_expr, bitexpr.isAllOne):
1156            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
1157        if isinstance(code[0].control_expr, bitexpr.isAllZero):
1158            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
1159        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
1160        fixed1.update(assumptions)
1161        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
1162        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
1163        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1164
1165    elif isinstance(code[0], bitexpr.WhileLoop):
1166        fixed = filter_fixed(fixed, code[0].stmts)
1167        fixed1 = copy.deepcopy(fixed)
1168        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
1169        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
1170
1171#################################################################################################################
1172## Dead Code Elimination
1173#################################################################################################################
1174def get_effective_name(varname):
1175    if varname.startswith("array") or varname.startswith("strct"):
1176        try:
1177            ind = varname.index("__")
1178            return varname[6:ind]
1179        except:
1180            return varname
1181    return varname
1182
1183def check_loc(loc, must_liv):
1184
1185    effective_name = get_effective_name(loc.LHS.varname)
1186    if effective_name in must_liv or loc.LHS.varname in must_liv:
1187        if isinstance(loc.RHS, bitexpr.Not):
1188            return set([loc.RHS.operand1.varname]), False
1189        elif isinstance(loc.RHS, bitexpr.Var):
1190            return set([loc.RHS.varname]), False
1191        elif isinstance(loc.RHS, bitexpr.FalseLiteral):
1192            return set([]), False
1193        elif isinstance(loc.RHS, bitexpr.TrueLiteral):
1194            return set([]), False
1195        elif isinstance(loc.RHS, bitexpr.Const):
1196            return set([]), False
1197        else:
1198            return set([loc.RHS.operand1.varname, loc.RHS.operand2.varname]), False
1199    else:
1200        return set([]), True
1201
1202def remove_copies(bb):
1203    """removes all copy statements e.g. var1 = var2"""
1204    lhs = [x.LHS.varname for x in bb]
1205    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
1206    for i in rhs:
1207        if i[1] in lhs:
1208            line = lhs.index(i[1])
1209            if i[0] > line:
1210                bb[i[0]].RHS = bb[line].RHS
1211    return bb
1212
1213def remove_dead(bb, must_live):
1214    bb = remove_copies(bb)
1215    my_alives = set([])
1216    dead = []
1217
1218    for line, loc in reversed(list(enumerate(bb))):
1219        new_lives, removable = check_loc(loc, my_alives.union(must_live))
1220
1221        if removable:
1222            dead.append(line)
1223        else:
1224            my_alives = my_alives.union(new_lives)
1225
1226    for i in dead:
1227        del bb[i]
1228
1229    return my_alives, bb
1230
1231def get_loop_variables(code):
1232    must_live = set([])
1233    for loc in code:
1234        if isinstance(loc, bitexpr.WhileLoop):
1235            if not loc.carry_expr is None:
1236                must_live.add(loc.carry_expr.var.varname)
1237        if isinstance(loc, bitexpr.If):
1238            true_lives = get_loop_variables(loc.true_branch)
1239            false_lives = get_loop_variables(loc.false_branch)
1240            must_live = must_live.union(true_lives)
1241            must_live = must_live.union(false_lives)
1242    return must_live
1243
1244def eliminate_dead_code(tree, must_live):
1245    if len(tree) == 0:
1246        return must_live, []
1247
1248    last = len(tree) - 1
1249    new_alives = set([])
1250    bb = []
1251
1252    must_live = must_live.union(get_loop_variables(tree))
1253
1254    if isinstance(tree[-1], bitexpr.BitAssign):
1255        first = 0
1256        for i in reversed(range(len(tree))):
1257            if not isinstance(tree[i], bitexpr.BitAssign):
1258                first = i+1
1259                break
1260
1261        new_alives, bb = remove_dead(tree[first:], must_live)
1262        last = first
1263
1264    elif isinstance(tree[-1], bitexpr.If):
1265        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
1266        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
1267        bb = [tree[-1]]
1268        new_alives.add(tree[-1].control_expr.var.varname)
1269
1270    elif isinstance(tree[-1], bitexpr.WhileLoop):
1271        must_live.add(tree[-1].control_expr.var.varname)
1272        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
1273        bb = [tree[-1]]
1274
1275    all_alives = new_alives.union(must_live)
1276    dummy, new_tree = eliminate_dead_code(tree[:last], all_alives)
1277    tree = new_tree+bb
1278    return all_alives.union(dummy), tree
1279
1280#################################################################################################################
1281## This pass processes the code in the while loop and adds extra code required for handling carry variables
1282#################################################################################################################
1283carry_suffix = "_i"
1284
1285def fix_the_loop(loop, carry_type):
1286    carries = []
1287    for loc in loop.stmts:
1288        if isinstance(loc.RHS, bitexpr.Add):
1289            carries.append(loc.RHS.carry)
1290    for item in carries:
1291        newvar = bitexpr.Var(item+carry_suffix)
1292        loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), carry_type)))
1293    for item in carries:
1294        loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral(carry_type) ))
1295
1296    return loop, carries
1297
1298def process_while_loops(code, int_carry):
1299    all = []
1300    update = {}
1301   
1302    if int_carry:
1303        carry_type = "int"
1304    else:
1305        carry_type = "vector"
1306   
1307    for index, loc in enumerate(code):
1308        if isinstance(loc, bitexpr.WhileLoop):
1309            update[index] = fix_the_loop(loc, carry_type)
1310            for i in update[index][1]:
1311                all.append(i)
1312                all.append(i+carry_suffix)
1313
1314        if isinstance(loc, bitexpr.If):
1315            loc.true_branch, live_list = process_while_loops(loc.true_branch)
1316            loc.false_branch, live_list = process_while_loops(loc.false_branch)
1317
1318    keys = [k for k in update]
1319    keys.sort(reverse=True)
1320    for key in keys:
1321        code[key] = update[key][0]
1322        carry_variable = bitexpr.Var("CarryTemp"+str(key))
1323        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral(carry_type)))
1324        for item in update[key][1][2:]:
1325            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), carry_type)))
1326        if len(update[key][1]) == 1:
1327            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
1328        elif len(update[key][1]) > 1:
1329            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), carry_type)))
1330        for item in update[key][1]:
1331            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral(carry_type)))
1332        for item in update[key][1]:
1333            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
1334
1335        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
1336    return code, all
1337#################################################################################################################
1338## This pass factors out the code that is common between the true branch and false branch of an if
1339## statement.
1340#################################################################################################################
1341
1342def are_the_same(one, two):
1343    if one.__class__ != two.__class__:
1344        return False
1345   
1346    if isinstance(one, bitexpr.BitAssign):
1347        if (one.LHS.varname != two.LHS.varname):
1348            return False
1349        if one.RHS.__class__ != two.RHS.__class__:
1350            return False
1351        if isinstance(one.RHS, bitexpr.FalseLiteral):
1352            return True
1353        if isinstance(one.RHS, bitexpr.TrueLiteral):
1354            return True
1355        if isinstance(one.RHS, bitexpr.Var):
1356            return (one.RHS.varname == two.RHS.varname)
1357        if isinstance(one.RHS, bitexpr.Const):
1358            return one.RHS.n == two.RHS.n
1359        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
1360
1361    if isinstance(one, bitexpr.If):
1362        if (one.control_expr.__class__ != two.control_expr.__class__):
1363            return False
1364        if (one.control_expr.var.varname != two.control_expr.var.varname):
1365            return False
1366        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
1367            return False
1368        for ind in range(len(one.true_branch)):
1369            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
1370                return False
1371        for ind in range(len(one.false_branch)):
1372            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
1373                return False
1374        return True
1375
1376    if isinstance(one, bitexpr.WhileLoop):
1377        if (one.control_expr.__class__ != two.control_expr.__class__):
1378            return False
1379        if (one.control_expr.var.varname != two.control_expr.var.varname):
1380            return False
1381        if (one.carry_expr.__class__ != two.carry_expr.__class__):
1382            return False
1383        #if (one.carry_expr.var.varname != two.carry_expr.var.varname):
1384        #    return False
1385        if (len(one.stmts) != len(two.stmts)):
1386            return False
1387        for ind in range(len(one.stmts)):
1388            if not are_the_same(one.stmts[ind], two.stmts[ind]):
1389                return False
1390        return True
1391
1392def get_factorable(cond):
1393    earliest = 0
1394    l = min(len(cond.true_branch), len(cond.false_branch))
1395    for index in reversed(range(-l,0)):
1396        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
1397            earliest = index
1398        else:
1399            break
1400
1401    common_length = -earliest
1402    return common_length
1403
1404def get_common_code(cond, common):
1405    true_len = len(cond.true_branch)-common
1406    false_len = len(cond.false_branch)-common
1407
1408    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
1409    common = cond.true_branch[true_len:]
1410
1411    return [new_cond], common
1412
1413def do_factorization(code, pos):
1414    indices = [key for key in pos]
1415    indices.sort()
1416
1417    for index in reversed(indices):
1418        new_cond, common = get_common_code(code[index], pos[index])
1419        code = code[:index]+new_cond+common+code[index+1:]
1420    return code
1421
1422def factor_out(code):
1423    for index, loc in enumerate(code):
1424        if isinstance(loc, bitexpr.If):
1425            code[index].true_branch = factor_out(code[index].true_branch)
1426            code[index].false_branch = factor_out(code[index].false_branch)
1427        if isinstance(loc, bitexpr.WhileLoop):
1428            code[index].stmts = factor_out(code[index].stmts)
1429
1430    pos = {}
1431    for index, loc in enumerate(code):
1432        if isinstance(loc, bitexpr.If):
1433            common_length = get_factorable(code[index])
1434            pos[index] = common_length
1435
1436    return do_factorization(code, pos)
1437
1438#################################################################################################################
1439## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
1440## ***Changes needed here are the same as Pass 7 and Pass 10***
1441#################################################################################################################
1442
1443def generate_condition(expr):
1444    if isinstance(expr, bitexpr.isAllZero):
1445        return "!bitblock_has_bit(%s)"%(expr.var.varname)
1446    elif isinstance(expr, bitexpr.isAllOne):
1447        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
1448    elif isinstance(expr, bitexpr.isNoneZero):
1449        return "bitblock_has_bit(%s)"%(expr.var.varname)
1450    else:
1451        assert (1==0)
1452
1453def generate_statement(expr, indent, cond_stmt):
1454    if isinstance(expr, bitexpr.isAllZero):
1455        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1456    elif isinstance(expr, bitexpr.isAllOne):
1457        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1458    elif isinstance(expr, bitexpr.isNoneZero):
1459        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
1460    else:
1461        assert (1==0)
1462
1463def print_block_stmts(s, int_carry, indent = 0):
1464    indent_unit = 4
1465    code = ""
1466    if len(s) == 0:
1467        return ""
1468    if isinstance(s[0], bitexpr.If):
1469        code = generate_statement(s[0].control_expr, indent, "if")
1470        code += print_block_stmts(s[0].true_branch, int_carry, indent+indent_unit)
1471        code += " "*indent+"}\n"
1472        code += " "*indent+"else {\n"
1473        code += print_block_stmts(s[0].false_branch, int_carry, indent+indent_unit)
1474        code += " "*indent+"}\n"
1475
1476    elif isinstance(s[0], bitexpr.WhileLoop):
1477        code = "\n%s((%s)||(test_carry(%s))) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname)
1478        code += print_block_stmts(s[0].stmts, int_carry, indent+indent_unit)
1479        code += " "*indent+"}\n"
1480    elif (isinstance(s[0].RHS, bitexpr.Const)):
1481        pass
1482 
1483    elif isinstance(s[0], bitexpr.BitAssign) and not (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1484
1485        code += " "*indent + s[0].LHS.varname
1486        code += " = "
1487
1488        if s[0].RHS.data_type == "vector":
1489            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
1490                code += s[0].RHS.operand1.varname
1491                code += ";\n"
1492            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
1493                code += s[0].RHS.operand1.varname
1494                code += ";\n"
1495            elif isinstance(s[0].RHS, bitexpr.Var):
1496                code += s[0].RHS.operand1.varname
1497                code += ";\n"
1498
1499            else:
1500                code += s[0].RHS.op_C + "("
1501                code += s[0].RHS.operand1.varname
1502                code += ','
1503                code += s[0].RHS.operand2.varname
1504                code += ");\n"
1505        elif s[0].RHS.data_type == "int":
1506            if isinstance(s[0].RHS, bitexpr.Or):
1507                code += "carry_or(%s, %s);\n" % (s[0].RHS.operand1.varname, s[0].RHS.operand2.varname)
1508            elif isinstance(s[0].RHS, bitexpr.FalseLiteral):
1509                code += "Carry0;\n"
1510    elif isinstance(s[0], bitexpr.BitAssign) and (isinstance(s[0].RHS, bitexpr.Add) or isinstance(s[0].RHS, bitexpr.Sub)):
1511            if isinstance(s[0].RHS, bitexpr.Add) and (s[0].RHS.operand1.varname == s[0].RHS.operand2.varname):
1512                code += " "*indent + s[0].RHS.op_advance + "("
1513                code += s[0].RHS.operand1.varname
1514                code += ', '
1515                code += s[0].RHS.carry
1516                code += ", "
1517                code += s[0].LHS.varname
1518                code += ");\n"
1519            elif isinstance(s[0].RHS, bitexpr.Add):
1520                code += " "*indent + s[0].RHS.op_C + "("
1521                code += s[0].RHS.operand1.varname
1522                code += ', '
1523                code += s[0].RHS.operand2.varname
1524                code += ', '
1525                code += s[0].RHS.carry
1526                code += ", "
1527                code += s[0].LHS.varname
1528                code += ");\n"
1529            elif isinstance(s[0].RHS, bitexpr.Sub):
1530                code += " "*indent + s[0].RHS.op_C + "("
1531                code += s[0].RHS.operand1.varname
1532                code += ', '
1533                code += s[0].RHS.operand2.varname
1534                code += ', '
1535                code += s[0].RHS.brw
1536                code += ", "
1537                code += s[0].LHS.varname
1538                code += ");\n"
1539    s.pop(0)
1540    return code+print_block_stmts(s, int_carry, indent)
1541
1542
1543def print_stream_stmts(s, indent = 0):
1544    indent_unit = 4
1545    code = ""
1546    if len(s) == 0:
1547        return ""
1548    elif isinstance(s[0], bitexpr.BitAssign):
1549        if (isinstance(s[0].RHS, bitexpr.Const)):
1550            code += " "*indent + s[0].LHS.varname
1551            code += " = "
1552            code += "sisd_from_int(%i);"%s[0].RHS.n
1553            code += "\n"
1554
1555    s.pop(0)
1556    return code+print_stream_stmts(s, indent)
1557
1558#################################################################################################################
1559## Auxiliary Functions
1560#################################################################################################################
1561
1562def get_bb_of(code, index):
1563    if index < 0 or index >= len(code):
1564        return None
1565
1566    if isinstance(code[index], bitexpr.If):
1567        return code[index]
1568    if isinstance(code[index], bitexpr.WhileLoop):
1569        return code[index]
1570    if isinstance(code[index], bitexpr.BitAssign):
1571        last = index
1572        for loc in code[index+1:]:
1573            if isinstance(loc, bitexpr.BitAssign):
1574                last += 1
1575            else:
1576                break
1577 
1578        first = index
1579        for loc in reversed(code[:index]):
1580            if isinstance(loc, bitexpr.BitAssign):
1581                first -= 1
1582            else:
1583                break
1584
1585        return code[first:last+1]
1586
1587def get_previous_bb(code, index):
1588    if index < 0:
1589        return None
1590    if index >= len(code):
1591        return get_bb_of(code, len(code)-1)
1592    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
1593        return get_bb_of(code, index-1)
1594    if isinstance(code[index], bitexpr.BitAssign):
1595        for ind, loc in reversed(enumerate(code[:index])):
1596            if not isinstance(loc, bitexpr.BitAssign):
1597                return get_bb_of(code, ind)
1598        return get_bb_of(code, -1)
1599    assert(1==0)
1600
1601def recurse_forward(code):
1602    if isinstance(code[0], bitexpr.If):
1603        return code[0], code[1:]
1604    if isinstance(code[0], bitexpr.WhileLoop):
1605        return code[0], code[1:]
1606    if isinstance(code[0], bitexpr.BitAssign):
1607        last = 0
1608        for loc in code[1:]:
1609            if isinstance(loc, bitexpr.BitAssign):
1610                last += 1
1611            else:
1612                break
1613        return code[0:last+1], code[last+1:]
Note: See TracBrowser for help on using the repository browser.