Changeset 320 for proto


Ignore:
Timestamp:
Oct 20, 2009, 5:02:55 PM (10 years ago)
Author:
eamiri
Message:

clean up and restructuring the code

Location:
proto/parabix2/Compiler
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • proto/parabix2/Compiler/bitstream_compiler.py

    r319 r320  
    206206
    207207
    208 def gen_declarations(var_dic):
    209         s = ''
    210         for i in  var_dic['int']:
    211                 s+="int %s=0;\n"%i
    212 
    213         for i in var_dic['bitblock']:
    214                 if i == AllOne:
    215                         s += "BitBlock %s = simd_const_1(1);\n"%i
    216                 elif i==AllZero:
    217                         s+="BitBlock %s = simd_const_1(0);\n"%i
    218                 else:
    219                         s+="BitBlock %s;\n"%i
    220 
    221         for i in var_dic['array']:
    222                 s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
    223 
    224         for i in var_dic['struct']:
    225                 s+="struct __%s__{\n"%i
    226                 for j in var_dic['struct'][i]:
    227                         s+= "\tBitBlock %s;\n"%j
    228                 s+="};\n"
    229                 s+="struct __%s__ %s;\n"%(i,i)
    230 
    231         return s
    232 
    233 def extract_vars(code):
    234         ints = set([])
    235         bitblocks = set(['AllOne', 'AllZero'])
    236         arrays = {}
    237         structs = {}
    238 
    239         for stmt in code:
    240 
    241                 all_vars = [stmt.LHS]+stmt.RHS.vars
    242 
    243                 if isinstance(stmt.RHS, bitexpr.Add) or isinstance(stmt.RHS, bitexpr.Sub):
    244                         ints.add(all_vars.pop())
    245 
    246                 for var in all_vars:
    247                         #print var
    248                         (var_type, name, extra) = py2bitexpr.parse_var(var.varname)
    249 
    250                         if (var_type == "bitblock"):
    251                                 bitblocks.add(name)
    252 
    253                         if var_type == "array":
    254                                 if not name in arrays:
    255                                         arrays[name]= extra
    256                                 else:
    257                                         arrays[name] = max(arrays[name], extra)
    258 
    259                         if var_type == "struct":
    260                                 if not name in structs:
    261                                         structs[name] = set([extra])
    262                                 else:
    263                                         structs[name].add(extra)
    264                                 #print structs[name]
    265 
    266         return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
    267 
    268 def merge_var_dic(first, second):
    269     res = {}
    270     res['int'] = first['int'].union(second['int'])
    271     res['bitblock'] = first['bitblock'].union(second['bitblock'])
    272 
    273     res['array'] = first['array']
    274     temp = dict({'array':copy.copy(second['array'])})
    275     for i in res['array']:
    276         if i in temp['array']:
    277             res['array'][i] = max(res['array'][i], temp['array'][i])
    278             del temp['array'][i]
    279     res['array'].update(temp['array'])
    280 
    281    
    282     res['struct'] = first['struct']
    283     temp = dict({'struct':copy.copy(second['struct'])})
    284     for i in res['struct']:
    285         if i in temp['struct']:
    286             res['struct'][i] = res['struct'][i].union(temp['struct'][i])
    287             del temp['struct'][i]
    288     res['struct'].update(temp['struct'])
    289     return res
    290208################################################################
    291209## This class abstracts one basic block of code
     
    440358            fixed, changed = self.propagate_constant(previous)
    441359        return fixed
     360
    442361################################################################
    443 # This class abstracts the whole program
     362# Going through compilation passes one by one
    444363################################################################
    445 
    446 class Program:
    447     def __init__(self):
    448         self. reducer = Reducer()
    449         self.declar = ""
    450 
    451     def generate_if_stmt(self, expr, indent):
    452         target = expr.target
    453         replace = expr.replace
    454         if replace == "allzero":
    455             return "\n%sif (bitblock_has_bit(%s) == 0)  {\n"%(" "*indent, target)
    456         elif replace == "allone":
    457            return "\n%sif (bitblock_bit_count(%s) == 128) {\n"%(" "*indent, target)
    458         else:
    459             assert (1==0)
    460 
    461     def break_to_segs(self, s, breaks):
    462         segs = []
    463         breaks = [-1]+breaks+[len(s)]
    464         for start, stop in py2bitexpr.pairs(breaks)[:-1]:
    465             next = s[start+1:stop]
    466             if len(next)>0:
    467                 segs.append(next)
    468         return segs
    469 
    470     def unwind(self, tree, head = [], indent = 0):
    471         #TODO: There might be a bug here, previous declarations are not passed to CodeGenObject
    472         indent_unit = 4
    473         s = ""
    474         total = []
    475         if tree[0] is None:
    476                 return tree[1].showcode(indent)
    477         else:
    478                 bnum = len(tree) - 1 #number of branches
    479                 s += tree[1].showcode(indent)
    480                 if bnum == 1:
    481                     return s
    482                 s += self.generate_if_stmt(tree[0], indent)
    483                 indent += indent_unit
    484                 s+=self.unwind(tree[2], tree[1], indent)
    485                 indent -= indent_unit
    486                 s+=" "*indent+"}\n"
    487                 if bnum == 2:
    488                     return s
    489                 s+=" "*indent+"else {\n"
    490                 indent += indent_unit
    491                 s+=self.unwind(tree[3], tree[1], indent)
    492                 indent -= indent_unit
    493                 s+=" "*indent+"}\n"
    494                 if bnum == 3:
    495                     return s
    496                 s+=self.unwind(tree[4], tree[1], indent)
    497                 return s
    498     def partition2bb(self, s):
    499         """
    500         At this point we assume all expression in the input Python code
    501         are optimize statements
    502         """
    503         basicblocks = []
    504         exprs = []
    505         lineno = []
    506         s = py2bitexpr.translate_stmts(s)
    507 
    508         st = py2bitexpr.gen_sym_table(s)
    509         #print st
    510         s=py2bitexpr.make_SSA(s, st)
    511 
    512         for line, loc in enumerate(s):
    513             if isinstance(loc, bitexpr.Reduce):
    514                 lineno.append(line)
    515                 exprs.append(loc)
    516         code_segs = self.break_to_segs(s, lineno)
    517         previous = {}
    518 
    519         for seg in code_segs:
    520             bb = BasicBlock(previous)
    521             bb.normalize(seg)
    522             previous.update(bb.common_expression_map)
    523             basicblocks.append(bb)
    524 
    525         #gen_compound_sym_table(basicblocks)
    526         #make_SSA(basicblocks)
    527 
    528         #make_SSA(basicblocks)
    529         vd = extract_vars(basicblocks[0].code)
    530         for bb in basicblocks[1:]:
    531             more = extract_vars(bb.code)
    532             vd = merge_var_dic(vd, more)
    533         self.declar = gen_declarations(vd)
    534         return basicblocks, exprs
    535 
    536     def construct_tree(self, all_code, exprs):
    537         #print ">>>>>", len(all_code), len(exprs)
    538         #print len(all_code)
    539         #for i in all_code:
    540         #    print "  ", len(i)
    541         if all_code == []:
    542             dummy = BasicBlock()
    543             dummy.code = []
    544             return [None, dummy]
    545 
    546         if len(all_code) == 1+len(exprs):
    547             head = all_code.pop(0)
    548         elif len(all_code) == len(exprs):
    549             dummy = BasicBlock()
    550             dummy.code = []
    551             head = dummy
    552         else:
    553             assert (1==0)
    554 
    555         if exprs == []:
    556             return [None, head]
    557         first_cond = exprs.pop(0)
    558 
    559         first_def = (None, None) #Basic Block number and line number of first def
    560         last_use = (None, None) #Basic Block number and line number of last_use
    561 
    562         target = first_cond.target
    563 
    564         for bbn, bb in enumerate(all_code):
    565             fdef = bb.get_defs(target)[0]
    566             luse = bb.get_uses(target)[-1]
    567             if not fdef is None and first_def[0] is None:
    568                 first_def = (bbn, fdef)
    569             if not luse is None:
    570                 last_use = (bbn, luse)
    571 
    572         if first_def[0] is None and last_use[0] is None:
    573             cut = (None, None)
    574             dummy = BasicBlock()
    575             dummy.code = []
    576             #print ";;;;;;;;;;;;", len(all_code), len(exprs)
    577             return [first_cond, head, self.construct_tree(copy.deepcopy(all_code), copy.copy(exprs)), self.construct_tree(all_code, copy.copy(exprs)), [None, dummy]]
    578         else:
    579             if first_def[0] is None:
    580                 cut = last_use
    581             else:
    582                 cut = first_def
    583 
    584             #bb1, bb2 = all_code[cut[0]].split(cut[1]+1)
    585             #all_code[cut[0]] = bb1
    586             #all_code.insert(cut[0]+1, bb2)
    587             #new_exprs =
    588             #print "~~~~~", len(all_code), len(exprs)
    589             #print "cut[0] = ", cut[0]
    590             return [first_cond, head, self.construct_tree(copy.deepcopy(all_code[:cut[0]+1]), copy.copy(exprs[:cut[0]])), self.construct_tree(all_code[:cut[0]+1], copy.copy(exprs[:cut[0]])), self.construct_tree(all_code[cut[0]+1:], exprs[cut[0]:])]
    591 
    592     def prune(self, fixed, tree):
    593         """removes unreachable branches of the tree"""
    594         target = tree[0].target
    595         replace = tree[0].replace
    596         if target in fixed:
    597             if replace == 'allone':
    598                 if isinstance(fixed[target], bitexpr.TrueLiteral):
    599                     tree[0] = None
    600                     tree[1].join(tree[2][1])
    601                     del tree[3]
    602                     del tree[2]
    603                     fixed.update(tree[1].simplify(fixed))
    604                 if isinstance(fixed[target], bitexpr.FalseLiteral):
    605                     tree[0] = None
    606                     tree[1].join(tree[3][1])
    607                     del tree[3]
    608                     del tree[2]
    609                     fixed.update(tree[1].simplify(fixed))
    610             if replace == 'allzero':
    611                 if isinstance(fixed[target], bitexpr.FalseLiteral):
    612                     tree[0] = None
    613                     tree[1].join(tree[2][1])
    614                     del tree[3]
    615                     del tree[2]
    616                     fixed.update(tree[1].simplify(fixed))
    617                 if isinstance(fixed[target], bitexpr.TrueLiteral):
    618                     tree[0] = None
    619                     tree[1].join(tree[3][1])
    620                     del tree[3]
    621                     del tree[2]
    622                     fixed.update(tree[1].simplify(fixed))
    623 
    624         empty_list = []
    625         for index, branch in enumerate(tree[2:]):
    626             if branch[0] is None:
    627                 if branch[1].code == []:
    628                     empty_list.append(2+index)
    629 
    630         empty_list.reverse()
    631         for i in empty_list:
    632             del tree[i]
    633 
    634     def simplify_tree(self, tree, fixed = {}):
    635         fixed.update(tree[1].simplify(fixed))
    636         if tree[0] is None:
    637             return
    638         else:
    639             #self.prune(fixed, tree)
    640             for branch in tree[2:]:
    641                 self.simplify_tree(branch, copy.copy(fixed))
    642 
    643     def generate_code(self, s):
    644         bb, exprs = self.partition2bb(s)
    645         tree = self.construct_tree(bb, exprs)
    646         #tree = self.reducer.apply_all_opt(bb, exprs)
    647         #print tree
    648         tree = self.reducer.apply_all_opt(tree)
    649 
    650         #print tree
    651         #print
    652         #print
    653 
    654         self.simplify_tree(tree)
    655 
    656         #print tree
    657 
    658         s = self.unwind(tree)
    659         return self.declar+s
    660 
    661 ################################################################
    662 ## This class is an optimization pass
    663 ################################################################
    664 
    665 class Reducer:
    666     def __init__(self):
    667         pass
    668 
    669     def replace_in_rhs(self, expr, target, replace):
    670         #print target,replace
    671         #print ",,,,",expr.LHS.varname
    672         if replace=='allzero':
    673             replace = bitexpr.FalseLiteral()
    674         if replace == 'allone':
    675             replace = bitexpr.TrueLiteral()
    676 
    677         if isinstance(expr.RHS, bitexpr.Var):
    678             if expr.RHS.varname == target:
    679                 expr.RHS = replace
    680             return
    681 
    682         if expr.RHS.operand1.varname == target:
    683             expr.RHS.operand1 = replace
    684 
    685         if expr.RHS.operand2.varname == target:
    686             expr.RHS.operand2 = replace
    687 
    688     def apply_single_opt(self, expr, tree):
    689 
    690         target = expr.target
    691         replace = expr.replace
    692         assert(replace=='allzero' or replace=='allone')
    693 
    694         for loc in tree[1].code:
    695             if loc.LHS.varname == target:
    696                 break
    697             self.replace_in_rhs(loc, target, replace)
    698         if tree[0] is None:
    699             return
    700         else:
    701             for branch in tree[2:]:
    702                 self.apply_single_opt(expr, branch)
    703 
    704         return
    705 
    706     def apply_all_opt(self, tree):
    707         """
    708           all_code is  list of basic blocks.
    709           exp is a list of optimize pragmas found in the code
    710         """
    711         if tree[0] is None:
    712             return tree
    713         else:
    714             self.apply_single_opt(tree[0], tree[2])
    715             tree = [tree[0], tree[1], self.apply_all_opt(tree[2]), self.apply_all_opt(tree[3]), self.apply_all_opt(tree[4])]
    716             return tree
    717 
    718         return [first_cond, head, self.apply_all_opt(optimized_code, copy.copy(exp)), self.apply_all_opt(all_code, exp)]
     364def generate_code(s):
     365
     366    #Pass 1
     367    s = ast.parse(s)
     368    s = s.body[0].body
     369
     370    #Pass 2
     371    s = py2bitexpr.translate_stmts(s)
     372
     373    #Pass 3
     374    st = py2bitexpr.gen_sym_table(s)
     375    s=py2bitexpr.make_SSA(s, st)
     376
     377    #Pass 4
     378    bb, exprs = py2bitexpr.partition2bb(s)
     379
     380    #Pass 5
     381    declarations = py2bitexpr.gen_declarations(bb)
     382
     383    #Pass 6
     384    tree = py2bitexpr.construct_tree(bb, exprs)
     385
     386    #Pass 7
     387    tree = py2bitexpr.Reducer().apply_all_opt(tree)
     388
     389    #Pass 8
     390    py2bitexpr.simplify_tree(tree)
     391
     392    #Pass 9
     393    s = py2bitexpr.unwind(tree)
     394    return declarations+s
     395
     396
    719397###############################################################
    720398if __name__ == '__main__':
    721         s=ast.parse(r"""def u8u16(u8, u8bit):
     399        s=r"""def u8u16(u8, u8bit):
    722400        u8.unibyte = (~u8bit[0]);
     401        optimize(u8.unibyte, allone)
    723402        u8.prefix = (u8bit[0] & u8bit[1]);
    724403        u8.prefix2 = (u8.prefix &~ u8bit[2]);
    725404        temp1 = (u8bit[2] &~ u8bit[3]);
    726405        u8.prefix3 = (u8.prefix & temp1);
    727         optimize(u8.prefix3, allzero)
     406        #optimize(u8.prefix3, allzero)
    728407        temp2 = (u8bit[2] & u8bit[3]);
    729408        u8.prefix4 = (u8.prefix & temp2);
    730         optimize(u8.prefix4, allzero)
     409        #optimize(u8.prefix4, allzero)
    731410        u8.suffix = (u8bit[0] &~ u8bit[1]);
    732411        temp3 = (u8bit[2] | u8bit[3]);
     
    754433        u8.x90_xBF = (u8.suffix & temp3);
    755434        u8.x80_x8F = (u8.suffix &~ temp3);
    756        
     435
    757436        u8.scope22 = bitutil.Advance(u8.prefix2)
    758437        u8.scope32 = bitutil.Advance(u8.prefix3)
     
    760439        u8.scope42 = bitutil.Advance(u8.prefix4)
    761440        u8.scope43 = bitutil.Advance(u8.scope42)
    762         optimize(u8.scope43, allzero)
     441        #optimize(u8.scope43, allzero)
    763442        u8.scope44 = bitutil.Advance(u8.scope43)
    764443        u8lastscope = u8.scope22 | u8.scope33 | u8.scope44
    765444        u8anyscope = u8lastscope | u8.scope32 | u8.scope42 | u8.scope43
    766        
     445
    767446        # C0-C1 and F5-FF are illegal
    768447        error_mask = u8.badprefix
    769        
     448
    770449        error_mask |= bitutil.Advance(u8.xE0) & u8.x80_x9F
    771450        error_mask |= bitutil.Advance(u8.xED) & u8.xA0_xBF
    772451        error_mask |= bitutil.Advance(u8.xF0) & u8.x80_x8F
    773452        error_mask |= bitutil.Advance(u8.xF4) & u8.x90_xBF
    774        
     453
    775454        error_mask |= u8anyscope ^ u8.suffix
    776455        u8.error = error_mask
    777        
     456
    778457        u8lastscope = u8.scope22 | u8.scope33 | u8.scope44
    779458        u8lastbyte = u8.unibyte | u8lastscope
    780         optimize(u8lastbyte, allzero)
     459        #optimize(u8lastbyte, allzero)
    781460        u16lo[2] = u8lastbyte & u8bit[2]
    782461        u16lo[3] = u8lastbyte & u8bit[3]
     
    787466        u16lo[1] = (u8.unibyte & u8bit[1]) | (u8lastscope & bitutil.Advance(u8bit[7]))
    788467        u16lo[0] = u8lastscope & bitutil.Advance(u8bit[6])
    789        
     468
    790469        u16hi[5] = u8lastscope & bitutil.Advance(u8bit[3])
    791470        u16hi[6] = u8lastscope & bitutil.Advance(u8bit[4])
     
    798477
    799478        u8surrogate = u8.scope43 | u8.scope44
    800         optimize(u8surrogate, allzero)
     479        #optimize(u8surrogate, allzero)
    801480        u16hi[0] = u16hi[0] | u8surrogate       
    802481        u16hi[1] = u16hi[1] | u8surrogate       
     
    825504
    826505        delmask = u8.prefix | u8.scope32 | u8.scope42
    827 """)
    828         s = s.body[0].body
    829         prog = Program()
    830         s = prog.generate_code(s)
    831         print s
     506"""
     507        print generate_code(s)
    832508
    833509#TODO: merge expr2simd() and a simplify_expr()
  • proto/parabix2/Compiler/py2bitexpr.py

    r319 r320  
    99# Requires ast module of python 2.6
    1010
    11 import ast, bitexpr, copy
     11import ast, bitexpr, copy, bitstream_compiler
    1212
    1313class PyBitError(Exception):
    1414        pass
     15
     16#############################################################################################
     17## Translation to bitExpr classes
     18#############################################################################################
     19
    1520
    1621def translate_index(ast_value):
     
    110115               
    111116               
    112 def arglist(value):
    113         if isinstance(value, ast.UnaryOp):
    114                 return (translate_var(value.operand))
    115         if isinstance(value, ast.BinOp):
    116                 return (translate_var(value.left), translate_var(value.right))
    117         if isinstance(value, ast.Call):
    118                 return value.args
    119         if isinstance(value, ast.Attribute):
    120                 return (translate_var(value.value))
    121         if isinstance(value, ast.Name):
    122                 return (translate_var(value))
    123 
    124         print "BUG: %s not supported!"%value.__class__
    125 
     117#############################################################################################
     118## Conversion to SSA form
     119#############################################################################################
    126120
    127121def extract_vars(rhs):
     
    227221                        unique_number += 1
    228222        return code
     223
     224#################################################################################################################
     225## Breaking the code into basic blocks depending on where optimize pragma appears
     226#################################################################################################################
     227def break_to_segs(s, breaks):
     228    segs = []
     229    breaks = [-1]+breaks+[len(s)]
     230    for start, stop in pairs(breaks)[:-1]:
     231        next = s[start+1:stop]
     232        if len(next)>0:
     233            segs.append(next)
     234    return segs
     235   
     236def partition2bb(s):
     237    basicblocks = []
     238    exprs = []
     239    lineno = []
     240   
     241    for line, loc in enumerate(s):
     242        if isinstance(loc, bitexpr.Reduce):
     243            lineno.append(line)
     244            exprs.append(loc)
     245    code_segs = break_to_segs(s, lineno)
     246    previous = {}
     247
     248    for seg in code_segs:
     249        bb = bitstream_compiler.BasicBlock(previous)
     250        bb.normalize(seg)
     251        previous.update(bb.common_expression_map)
     252        basicblocks.append(bb)
     253
     254    return basicblocks, exprs
     255
     256#################################################################################################################
     257## Generating declarations for variables. This pass is based on the syntax of C programming language
     258#################################################################################################################
     259def get_vars(code):
     260        ints = set([])
     261        bitblocks = set(['AllOne', 'AllZero'])
     262        arrays = {}
     263        structs = {}
     264
     265        for stmt in code:
     266
     267                all_vars = [stmt.LHS]+stmt.RHS.vars
     268
     269                if isinstance(stmt.RHS, bitexpr.Add) or isinstance(stmt.RHS, bitexpr.Sub):
     270                        ints.add(all_vars.pop())
     271
     272                for var in all_vars:
     273                        #print var
     274                        (var_type, name, extra) = parse_var(var.varname)
     275
     276                        if (var_type == "bitblock"):
     277                                bitblocks.add(name)
     278
     279                        if var_type == "array":
     280                                if not name in arrays:
     281                                        arrays[name]= extra
     282                                else:
     283                                        arrays[name] = max(arrays[name], extra)
     284
     285                        if var_type == "struct":
     286                                if not name in structs:
     287                                        structs[name] = set([extra])
     288                                else:
     289                                        structs[name].add(extra)
     290                                #print structs[name]
     291
     292        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
     293
     294def merge_var_dic(first, second):
     295    res = {}
     296    res['int'] = first['int'].union(second['int'])
     297    res['bitblock'] = first['bitblock'].union(second['bitblock'])
     298
     299    res['array'] = first['array']
     300    temp = dict({'array':copy.copy(second['array'])})
     301    for i in res['array']:
     302        if i in temp['array']:
     303            res['array'][i] = max(res['array'][i], temp['array'][i])
     304            del temp['array'][i]
     305    res['array'].update(temp['array'])
     306
     307   
     308    res['struct'] = first['struct']
     309    temp = dict({'struct':copy.copy(second['struct'])})
     310    for i in res['struct']:
     311        if i in temp['struct']:
     312            res['struct'][i] = res['struct'][i].union(temp['struct'][i])
     313            del temp['struct'][i]
     314    res['struct'].update(temp['struct'])
     315    return res
     316
     317def gen_output(var_dic):
     318        s = ''
     319        for i in  var_dic['int']:
     320                s+="int %s=0;\n"%i
     321
     322        for i in var_dic['bitblock']:
     323                if i == bitstream_compiler.AllOne:
     324                        s += "BitBlock %s = simd_const_1(1);\n"%i
     325                elif i==bitstream_compiler.AllZero:
     326                        s+="BitBlock %s = simd_const_1(0);\n"%i
     327                else:
     328                        s+="BitBlock %s;\n"%i
     329
     330        for i in var_dic['array']:
     331                s+="BitBlock %s[%i];\n"%(i, int(var_dic['array'][i])+1)
     332
     333        for i in var_dic['struct']:
     334                s+="struct __%s__{\n"%i
     335                for j in var_dic['struct'][i]:
     336                        s+= "\tBitBlock %s;\n"%j
     337                s+="};\n"
     338                s+="struct __%s__ %s;\n"%(i,i)
     339
     340        return s
     341
     342def gen_declarations(bb):
     343    vd = get_vars(bb[0].code)
     344    for block in bb[1:]:
     345        more = get_vars(block.code)
     346        vd = merge_var_dic(vd, more)
     347    declarations = gen_output(vd)
     348    return declarations
     349
     350#################################################################################################################
     351## Converting to condition-tree form
     352#################################################################################################################
     353def construct_tree(all_code, exprs):
     354    if all_code == []:
     355        dummy = bitstream_compiler.BasicBlock()
     356        dummy.code = []
     357        return [None, dummy]
     358
     359    if len(all_code) == 1+len(exprs):
     360        head = all_code.pop(0)
     361    elif len(all_code) == len(exprs):
     362        dummy = BasicBlock()
     363        dummy.code = []
     364        head = dummy
     365    else:
     366        assert (1==0)
     367
     368    if exprs == []:
     369        return [None, head]
     370    first_cond = exprs.pop(0)
     371
     372    first_def = (None, None) #Basic Block number and line number of first def
     373    last_use = (None, None) #Basic Block number and line number of last_use
     374
     375    target = first_cond.target
     376
     377    for bbn, bb in enumerate(all_code):
     378        fdef = bb.get_defs(target)[0]
     379        luse = bb.get_uses(target)[-1]
     380        if not fdef is None and first_def[0] is None:
     381            first_def = (bbn, fdef)
     382        if not luse is None:
     383            last_use = (bbn, luse)
     384
     385    ################### This few lines were temporary put here to remove the removal of code duplication in the output code#######################
     386    dummy = bitstream_compiler.BasicBlock()
     387    dummy.code = []
     388    return [first_cond, head, construct_tree(copy.deepcopy(all_code), copy.copy(exprs)), construct_tree(all_code, copy.copy(exprs)), [None, dummy]]
     389    ################### This few lines were temporary put here to remove the removal of code duplication in the output code#######################
     390    if first_def[0] is None and last_use[0] is None:
     391        cut = (None, None)
     392        dummy = bitstream_compiler.BasicBlock()
     393        dummy.code = []
     394        return [first_cond, head, construct_tree(copy.deepcopy(all_code), copy.copy(exprs)), construct_tree(all_code, copy.copy(exprs)), [None, dummy]]
     395    else:
     396        if first_def[0] is None:
     397            cut = last_use
     398        else:
     399            cut = first_def
     400        return [first_cond, head, construct_tree(copy.deepcopy(all_code[:cut[0]+1]), copy.copy(exprs[:cut[0]])), construct_tree(all_code[:cut[0]+1], copy.copy(exprs[:cut[0]])), construct_tree(all_code[cut[0]+1:], exprs[cut[0]:])]
     401
     402#################################################################################################################
     403## This class replaces all occurences of a reduced variable to its value
     404#################################################################################################################
     405class Reducer:
     406    def __init__(self):
     407        pass
     408
     409    def replace_in_rhs(self, expr, target, replace):
     410        #print target,replace
     411        #print ",,,,",expr.LHS.varname
     412        if replace=='allzero':
     413            replace = bitexpr.FalseLiteral()
     414        if replace == 'allone':
     415            replace = bitexpr.TrueLiteral()
     416
     417        if isinstance(expr.RHS, bitexpr.Var):
     418            if expr.RHS.varname == target:
     419                expr.RHS = replace
     420            return
     421
     422        if expr.RHS.operand1.varname == target:
     423            expr.RHS.operand1 = replace
     424
     425        if expr.RHS.operand2.varname == target:
     426            expr.RHS.operand2 = replace
     427
     428    def apply_single_opt(self, expr, tree):
     429
     430        target = expr.target
     431        replace = expr.replace
     432        assert(replace=='allzero' or replace=='allone')
     433
     434        for loc in tree[1].code:
     435            if loc.LHS.varname == target:
     436                break
     437            self.replace_in_rhs(loc, target, replace)
     438        if tree[0] is None:
     439            return
     440        else:
     441            for branch in tree[2:]:
     442                self.apply_single_opt(expr, branch)
     443
     444        return
     445
     446    def apply_all_opt(self, tree):
     447        """
     448          all_code is  list of basic blocks.
     449          exp is a list of optimize pragmas found in the code
     450        """
     451        if tree[0] is None:
     452            return tree
     453        else:
     454            self.apply_single_opt(tree[0], tree[2])
     455            tree = [tree[0], tree[1], self.apply_all_opt(tree[2]), self.apply_all_opt(tree[3]), self.apply_all_opt(tree[4])]
     456            return tree
     457
     458        return [first_cond, head, self.apply_all_opt(optimized_code, copy.copy(exp)), self.apply_all_opt(all_code, exp)]
     459#################################################################################################################
     460## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation.
     461## method prune is supposed to remove unreachable branches but it is incomplete
     462#################################################################################################################
     463def prune(fixed, tree):
     464    """removes unreachable branches of the tree"""
     465    target = tree[0].target
     466    replace = tree[0].replace
     467    if target in fixed:
     468        if replace == 'allone':
     469            if isinstance(fixed[target], bitexpr.TrueLiteral):
     470                tree[0] = None
     471                tree[1].join(tree[2][1])
     472                del tree[3]
     473                del tree[2]
     474                fixed.update(tree[1].simplify(fixed))
     475            if isinstance(fixed[target], bitexpr.FalseLiteral):
     476                tree[0] = None
     477                tree[1].join(tree[3][1])
     478                del tree[3]
     479                del tree[2]
     480                fixed.update(tree[1].simplify(fixed))
     481        if replace == 'allzero':
     482            if isinstance(fixed[target], bitexpr.FalseLiteral):
     483                tree[0] = None
     484                tree[1].join(tree[2][1])
     485                del tree[3]
     486                del tree[2]
     487                fixed.update(tree[1].simplify(fixed))
     488            if isinstance(fixed[target], bitexpr.TrueLiteral):
     489                tree[0] = None
     490                tree[1].join(tree[3][1])
     491                del tree[3]
     492                del tree[2]
     493                fixed.update(tree[1].simplify(fixed))
     494
     495    empty_list = []
     496    for index, branch in enumerate(tree[2:]):
     497        if branch[0] is None:
     498            if branch[1].code == []:
     499                empty_list.append(2+index)
     500
     501    empty_list.reverse()
     502    for i in empty_list:
     503        del tree[i]
     504
     505def simplify_tree(tree, fixed = {}):
     506    fixed.update(tree[1].simplify(fixed))
     507    if tree[0] is None:
     508        return
     509    else:
     510        #self.prune(fixed, tree)
     511        for branch in tree[2:]:
     512            simplify_tree(branch, copy.copy(fixed))
     513
     514
     515#################################################################################################################
     516## Generates C code given conditions-tree, by a recursive traversal of the tree
     517#################################################################################################################
     518def generate_if_stmt(expr, indent):
     519    target = expr.target
     520    replace = expr.replace
     521    if replace == "allzero":
     522        return "\n%sif (bitblock_has_bit(%s) == 0)  {\n"%(" "*indent, target)
     523    elif replace == "allone":
     524        return "\n%sif (bitblock_bit_count(%s) == 128) {\n"%(" "*indent, target)
     525    else:
     526        assert (1==0)
     527
     528
     529def unwind(tree, head = [], indent = 0):
     530    #TODO: There might be a bug here, previous declarations are not passed to CodeGenObject
     531    indent_unit = 4
     532    s = ""
     533    total = []
     534    if tree[0] is None:
     535            return tree[1].showcode(indent)
     536    else:
     537            bnum = len(tree) - 1 #number of branches
     538            s += tree[1].showcode(indent)
     539            if bnum == 1:
     540                return s
     541            s += generate_if_stmt(tree[0], indent)
     542            indent += indent_unit
     543            s+=unwind(tree[2], tree[1], indent)
     544            indent -= indent_unit
     545            s+=" "*indent+"}\n"
     546            if bnum == 2:
     547                return s
     548            s+=" "*indent+"else {\n"
     549            indent += indent_unit
     550            s+=unwind(tree[3], tree[1], indent)
     551            indent -= indent_unit
     552            s+=" "*indent+"}\n"
     553            if bnum == 3:
     554                return s
     555            s+=unwind(tree[4], tree[1], indent)
     556            return s
Note: See TracChangeset for help on using the changeset viewer.