Changeset 344


Ignore:
Timestamp:
Jan 6, 2010, 2:30:29 PM (10 years ago)
Author:
eamiri
Message:

Can compiler Parabix 2 now

Location:
proto/parabix2/Compiler
Files:
2 added
3 edited

Legend:

Unmodified
Added
Removed
  • proto/parabix2/Compiler/bitexpr.py

    r327 r344  
    1 #
    2 # bitexpr.py - bitstream expressions in Python
    3 #
    4 # (c) 2009 Robert D. Cameron with modifications by Ehsan Amiri
    5 # All rights reserved.
    6 # Licensed to International Characters, Inc. under Academic Free License 3.0
    7 #
    8 # Class BitExpr provides a model for symbolic expressions
    9 # involving bitwise operations on bitstreams.  Operations
    10 # include various bitwise logical operators, as well as
    11 # add and subtract.
    12 #
    131
    142class Pragma:
     
    219
    2210class BitExpr:
    23    """The BitExpr class and its subclasses provide a symbolic
     11    """The BitExpr class and its subclasses provide a symbolic
    2412      representation of Bitwise logical expressions.
    25    """
    26    pass
    27 
     13    """
     14    def __init__(self, op1, op2, op = "", data_type="vector"):
     15        self.op = op
     16        self.operand1 = op1
     17        self.operand2 = op2
     18        self.data_type = data_type
    2819class Var(BitExpr):
    2920    def __init__(self, varname):
    3021        self.varname = varname
    31         self.vars = [self]
     22        BitExpr.__init__(self, self, self)
    3223    def show(self): return 'Var("' + self.varname + '")'
    3324    def show_C(self):
     
    3728    def __init__(self):
    3829        self.value = True
    39         self.vars = []
    40         self.varname = 'allone'
     30        self.varname = 'AllOne'
     31        BitExpr.__init__(self, self, self)
    4132    def show(self): return 'T'
    4233    def show_C(self): return 'AllOne'
    4334
    4435class FalseLiteral(BitExpr):
    45     def __init__(self):
     36    def __init__(self, data_type="vector"):
    4637        self.value = False
    47         self.varname = 'allzero'
    48         self.vars = []
     38        self.varname = 'AllZero'
     39        BitExpr.__init__(self, self, self, data_type=data_type)
    4940    def show(self): return 'F'
    5041    def show_C(self):return 'AllZero'
     
    5243class Not(BitExpr):
    5344    def __init__(self, expr):
    54         self.operand = expr
    55         self.vars = [self.operand]
    56     def show(self): return 'Not(%s)' % (self.operand.show())
     45        #self.operand = expr
     46        BitExpr.__init__(self, expr, expr, "Not")
     47    def show(self): return 'Not(%s)' % (self.operand1.show())
    5748    def show_C(self): return "NOT IMPLEMENTED"
    5849
     
    6152        self.operand1 = expr1
    6253        self.operand2 = expr2
    63         self.vars = [self.operand1, self.operand2]
     54        self.op_C = "simd_and"
     55        BitExpr.__init__(self, expr1, expr2, "And")
    6456    def show(self):
    6557        return 'And(%s, %s)' % (self.operand1.show(), self.operand2.show())
     
    6860class Andc(BitExpr):
    6961    def __init__(self, expr1, expr2):
    70         self.operand1 = expr1
    71         self.operand2 = expr2
    72         self.vars = [self.operand1, self.operand2]
     62        self.op_C = "simd_andc"
     63        BitExpr.__init__(self, expr1, expr2, "Andc")
     64
    7365    def show(self): return 'Andc(%s, %s)' % (self.operand1.show(), self.operand2.show())
    7466    def show_C(self): return 'simd_andc(%s, %s)' % (self.operand1.show_C(), self.operand2.show_C())
    7567
    7668class Or(BitExpr):
    77     def __init__(self, expr1, expr2):
    78         self.operand1 = expr1
    79         self.operand2 = expr2
    80         self.vars = [self.operand1, self.operand2]
     69    def __init__(self, expr1, expr2, data_type="vector"):
     70        self.op_C = "simd_or"
     71        BitExpr.__init__(self, expr1, expr2, "Or", data_type)
    8172    def show(self): return 'Or(%s, %s)' % (self.operand1.show(), self.operand2.show())
    8273    def show_C(self): return 'simd_or(%s, %s)' % (self.operand1.show_C(), self.operand2.show_C())
     74
    8375class Xor(BitExpr):
    8476    def __init__(self, expr1, expr2):
    85         self.operand1 = expr1
    86         self.operand2 = expr2
    87         self.vars = [self.operand1, self.operand2]
     77        self.op_C = "simd_xor"
     78        BitExpr.__init__(self, expr1, expr2, "Xor")
     79
    8880    def show(self): return 'Xor(%s, %s)' % (self.operand1.show(), self.operand2.show())
    8981    def show_C(self): return 'simd_xor(%s, %s)' % (self.operand1.show_C(), self.operand2.show_C())
     
    9587        self.false_branch = expr3
    9688        self.vars = [self.operand1, self.operand2]
     89        BitExpr.__init__(self, "Sel")
    9790    def show(self): return 'Sel(%s, %s, %s)' % (self.sel.show(), self.true_branch.show(), self.false_branch.show())
    9891    def show_C(self): return 'Sel(%s, %s, %s)' % (self.sel.show_C(), self.true_branch.show_C(), self.false_branch.show_C())
     
    10093class Add(BitExpr):
    10194    def __init__(self, expr1, expr2, carry = None):
    102         self.operand1 = expr1
    103         self.operand2 = expr2
     95        self.op_C = "adc128"
    10496        self.carry = carry
    105         self.vars = [self.operand1, self.operand2, self.carry]
     97        BitExpr.__init__(self, expr1, expr2, "Add")
    10698    def show(self): return 'Add(%s, %s)' % (self.operand1.show(), self.operand2.show())
    10799    def show_C(self): return 'adc128(%s, %s, %s)' % (self.operand1.show_C(), self.operand2.show_C(), self.carry)
     
    109101class Sub(BitExpr):
    110102    def __init__(self, expr1, expr2, brw = None):
    111         self.operand1 = expr1
    112         self.operand2 = expr2
    113         self.vars = [self.operand1, self.operand2]
     103        self.op_C = "sbb128"
    114104        self.brw = None
     105        BitExpr.__init__(self, expr1, expr2, "Sub")
    115106    def show(self): return 'Sub(%s, %s)' % (self.operand1.show(), self.operand2.show())
    116107    def show_C(self): return 'sbb128(%s, %s)' % (self.operand1.show_C(), self.operand2.show_C())
     
    149140#
    150141
    151 class WhileLoop(BitStmt):
     142class WhileLoop(StmtList):
    152143    def __init__(self, expr, stmts):
    153144        self.control_expr = expr
    154         self.loop_body = stmts
     145        self.carry_expr = None
     146        StmtList.__init__(self, stmts)
     147        #self.stmt = stmts
    155148    def show(self):
    156149        rslt = ''
    157150        for s in self.loop_body: rslt += s.show() + '\n'
    158151        return 'while (%s) {%s}' % (self.control_expr.show(), rslt)
     152
     153
     154
     155class If(StmtList):
     156    def __init__(self, expr, true_branch, false_branch):
     157        self.control_expr = expr
     158        self.true_branch = true_branch
     159        self.false_branch = false_branch
     160
     161
     162class BitCondition:
     163    def __init__(self, var, val = ""):
     164        if isinstance(var, Var):
     165            self.var = var
     166        elif isinstance(var, str):
     167            self.var = Var(var)
     168        else:
     169            assert(1==0)
     170       
     171        self.val = val
     172
     173class isNoneZero(BitCondition):
     174    def __init__(self, var):
     175        BitCondition.__init__(self, var)
     176    def show(self):
     177        return '%s > 0'%self.var.show()
     178
     179class isAllZero(BitCondition):
     180    def __init__(self, var):
     181        BitCondition.__init__(self, var, "AllZero")
     182    def show(self):
     183        return '%s == 0'%self.var.show()
     184
     185class isAllOne(BitCondition):
     186    def __init__(self, var):
     187        BitCondition.__init__(self, var, "AllOne")
     188    def show(self):
     189        return '%s  == -1'%self.var.show()
    159190
    160191
     
    174205        return TrueLiteral()
    175206    elif isinstance(expr, Not):
    176         return expr.operand
     207        return expr.operand1
    177208    else: return Not(expr)
    178209
     
    189220    elif isinstance(expr1, Not):
    190221        if isinstance(expr2, Not):
    191             return make_not(make_or(expr1.operand, expr2.operand))
    192         elif equal_exprs(expr1.operand, expr2): return FalseLiteral()
     222            return make_not(make_or(expr1.operand1, expr2.operand1))
     223        elif equal_exprs(expr1.operand1, expr2): return FalseLiteral()
    193224        else: return And(expr1, expr2)
    194225    elif isinstance(expr2, Not):
    195         if equal_exprs(expr1, expr2.operand): return FalseLiteral()
     226        if equal_exprs(expr1, expr2.operand1): return FalseLiteral()
    196227        else: return And(expr1, expr2)
    197228    else: return And(expr1, expr2)
     
    328359    elif isinstance(expr, Andc):
    329360        return make_and(simplify_expr(expr.operand1), make_not(simplify_expr(expr.operand2)))
    330 
    331 
  • proto/parabix2/Compiler/bitstream_compiler.py

    r326 r344  
    88#
    99
    10 import ast, bitexpr, py2bitexpr, copy
    11 
    12 
    13 #
    14 #  Code Generation
    15 #
    16 
     10import ast, py2bitexpr, string
     11
     12#import bitexpr
     13"""
    1714Nop = 'nop'
    1815Andc = 'simd_andc'
     
    2623If = 'simd_has_bit'
    2724
    28 AllOne = 'AllOne'
    29 AllZero = 'AllZero'
    30 
    31 
    32 def expr2simd(genobj, expr):
    33     """Translate a Boolean expression into three-address simd code
    34        using code generator object genobj.
    35        
    36     """
    37     if isinstance(expr, bitexpr.TrueLiteral): return expr
    38     elif isinstance(expr, bitexpr.FalseLiteral): return expr
    39     elif isinstance(expr, bitexpr.Var):
    40         genobj.common_expression_map[expr.show()] = expr
    41         return expr
    42     elif isinstance(expr, bitexpr.Not):
    43        e = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand))
    44        return bitexpr.Andc(bitexpr.Var('AllOne'), e)
    45     elif isinstance(expr, bitexpr.Or):
    46        if isinstance(expr.operand1, bitexpr.FalseLiteral):
    47            return expr.operand2
    48        elif isinstance(expr.operand1, bitexpr.TrueLiteral):
    49            return bitexpr.TrueLiteral()
    50        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    51            return expr.operand1
    52        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    53            return bitexpr.TrueLiteral()
    54        e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    55        i = expr2simd(genobj, expr.operand2)
    56        e2 = genobj.expr_string_to_variable(i)
    57        return bitexpr.Or(e1, e2)
    58     elif isinstance(expr, bitexpr.Xor):
    59        e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    60        e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
    61        return bitexpr.Xor(e1,e2)
    62     elif isinstance(expr, bitexpr.And):
    63        if isinstance(expr.operand1, bitexpr.Not):
    64            e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1.operand))
    65            e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
    66            return bitexpr.Andc(e2, e1)
    67        elif isinstance(expr.operand2, bitexpr.Not):
    68            e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    69            e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2.operand))
    70            return bitexpr.Andc(e1,e2)
    71        elif isinstance(expr.operand1, bitexpr.FalseLiteral):
    72            return bitexpr.FalseLiteral()
    73        elif isinstance(expr.operand1, bitexpr.TrueLiteral):
    74            return expr.operand2
    75        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    76            return bitexpr.FalseLiteral()
    77        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    78            return expr.operand1
    79        else:
    80            e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    81            j = expr2simd(genobj, expr.operand2)
    82            e2 = genobj.expr_string_to_variable(j)
    83            return bitexpr.And(e1, e2)
    84     elif isinstance(expr, bitexpr.Sel):
    85        sel = genobj.expr_string_to_variable(expr2simd(genobj, expr.sel))
    86        e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.true_branch))
    87        e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.false_branch))
    88        return bitexpr.Sel(sel,e1,e2)
    89     elif isinstance(expr, bitexpr.Add):
    90        e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    91        e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
    92        carry = "carry%i"%BasicBlock.carry_counter
    93        return bitexpr.Add(e1, e2, carry)
    94     elif isinstance(expr, bitexpr.Sub):
    95        e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    96        e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2))
    97        brw = "brw%i"%BasicBlock.brw_counter
    98        return bitexpr.Sub(e1, e2, brw)
    99     elif isinstance(expr, bitexpr.Andc):
    100        if isinstance(expr.operand2, bitexpr.Not):
    101            e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    102            e2 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand2.operand))
    103            return bitexpr.And(e1,e2)
    104        elif isinstance(expr.operand1, bitexpr.FalseLiteral):
    105            return bitexpr.FalseLiteral()
    106        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    107            return expr.operand1
    108        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    109            return bitexpr.FalseLiteral()
    110        else:
    111            e1 = genobj.expr_string_to_variable(expr2simd(genobj, expr.operand1))
    112            j = expr2simd(genobj, expr.operand2)
    113            e2 = genobj.expr_string_to_variable(j)
    114            return bitexpr.Andc(e1, e2)
    115 
    116 
    117 def simplify_expr(expr):
    118     """
    119        simplifies a logical expression due to replacement of variables by constants
    120     """
    121     if isinstance(expr, bitexpr.TrueLiteral): return expr
    122     elif isinstance(expr, bitexpr.FalseLiteral): return expr
    123     elif isinstance(expr, bitexpr.Var): return expr
    124     elif isinstance(expr, bitexpr.Not): return make_not(simplify(expr.operand))
    125     elif isinstance(expr, bitexpr.Or):
    126        if isinstance(expr.operand1, bitexpr.FalseLiteral):
    127            return expr.operand2
    128        elif isinstance(expr.operand1, bitexpr.TrueLiteral):
    129            return bitexpr.TrueLiteral()
    130        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    131            return expr.operand1
    132        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    133            return bitexpr.TrueLiteral()
    134        else:
    135            return expr
    136     elif isinstance(expr, bitexpr.Xor):
    137        if isinstance(expr.operand1, bitexpr.FalseLiteral):
    138            return expr.operand2
    139        elif isinstance(expr.operand1, bitexpr.TrueLiteral):
    140            return bitexpr.Andc(bitexpr.TrueLiteral() ,expr.operand2)
    141        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    142            return expr.operand1
    143        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    144            return bitexpr.Andc(bitexpr.TrueLiteral() ,expr.operand1)
    145        else:
    146            return expr
    147     elif isinstance(expr, bitexpr.And):
    148        if isinstance(expr.operand1, bitexpr.FalseLiteral):
    149            return bitexpr.FalseLiteral()
    150        elif isinstance(expr.operand1, bitexpr.TrueLiteral):
    151            return expr.operand2
    152        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    153            return bitexpr.FalseLiteral()
    154        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    155            return expr.operand1
    156        else:
    157             return expr
    158     elif isinstance(expr, bitexpr.Sel):
    159        if isinstance(expr.sel, bitexpr.FalseLiteral):
    160            return expr.false_branch
    161        elif isinstance(expr.sel, bitexpr.TrueLiteral):
    162            return expr.true_branch
    163        elif isinstance(expr.true_branch, bitexpr.FalseLiteral):
    164            return bitexpr.Andc(expr.false_branch, expr.sel)
    165        elif isinstance(expr.true_branch, bitexpr.TrueLiteral):
    166            return bitexpr.Or(expr.false_branch, expr.sel)
    167        elif isinstance(expr.false_branch, bitexpr.FalseLiteral):
    168            return bitexpr.And(expr.true_branch, expr.sel)
    169        elif isinstance(expr.false_branch, bitexpr.TrueLiteral):
    170            return expr #SIMPLIFICATION IS POSSIBLE BUT ACCESS TO BASIC BLOCK IS NEEDED
    171        else:
    172             return expr
    173     elif isinstance(expr, bitexpr.Add):
    174 #  The following simplification does not work because there may be a carry bit.
    175 #        if isinstance(expr.operand1, bitexpr.FalseLiteral) and isinstance(expr.operand1, bitexpr.FalseLiteral):
    176 #            return bitexpr.FalseLiteral()
    177 #        else:
    178             return expr
    179     elif isinstance(expr, bitexpr.Sub):
    180         return expr
    181     elif isinstance(expr, bitexpr.Andc):
    182        if isinstance(expr.operand1, bitexpr.FalseLiteral):
    183            return bitexpr.FalseLiteral()
    184        elif isinstance(expr.operand2, bitexpr.FalseLiteral):
    185            return expr.operand1
    186        elif isinstance(expr.operand2, bitexpr.TrueLiteral):
    187            return bitexpr.FalseLiteral()
    188        else:
    189            return expr
    190 
    191 def gen_sym_table(code):
    192         """Generate a simple symbol table for a three address code
    193            each entry is of this form: var_name:[[defs][uses]]
    194         """
    195         table = {}
    196         for index, stmt in enumerate(code):
    197                 if stmt.LHS in table:
    198                         table[stmt.LHS][0].append(index)
    199                 else:
    200                         table[stmt.LHS] = [[index],[]]
    201                
    202                 for var in stmt.RHS.vars:
    203                         if var in table:
    204                                 table[var][1].append(index)
    205                         else:
    206                                 table[var] = [[], [index]]
    207         return table
    208 
    209 
    210 ################################################################
    211 ## This class abstracts one basic block of code
    212 ################################################################
    213 
    214 class BasicBlock:
    215     gensym_counter = 0
    216     carry_counter = 0
    217     brw_counter = 0
    218     def __init__(self, predeclared={}):
    219         self.gensym_template = 'Temp%i'
    220         self.code = []
    221         self.common_expression_map = {}
    222         self.new_vars = []
    223         self.vars = {}
    224         #for sym in predeclared: self.common_expression_map[sym] =sym
    225         self.common_expression_map.update(predeclared)
    226     def __len__(self):
    227         return len(self.code)
    228     def join(self, block):
    229         self.code += block.code
    230         self.common_expression_map.update(block.common_expression_map)
    231     def get_defs(self, varname):
    232         defs = []
    233         for line, loc in enumerate(self.code):
    234             if loc.LHS.varname == varname:
    235                 defs.append(line)
    236         if len(defs)==0:
    237             defs.append(None)
    238         return defs
    239     def get_uses(self, varname):
    240         uses = []
    241         for line, loc in enumerate(self.code):
    242             if isinstance(loc.RHS, bitexpr.Not):
    243                 if loc.RHS.operand.varname == varname:
    244                     uses.append(line)
    245                 continue
    246             if isinstance(loc.RHS, bitexpr.Var):
    247                 if loc.RHS.varname == varname:
    248                     uses.append(line)
    249                 continue
    250                
    251             if loc.RHS.operand1.varname == varname:
    252                 uses.append(line)
    253             if loc.RHS.operand2.varname == varname:
    254                 uses.append(line)
    255         if len(uses)==0:
    256             uses.append(None)
    257         return uses
    258 
    259     def split(self, line):
    260         code1 = self.code[:line]
    261         code2 = self.code[line:]
    262         bb2 = BasicBlock()
    263         bb2.code = code2
    264         bb2.common_expression_map = copy.copy(self.common_expression_map)
     25
     26"""
     27
     28class Program:
     29    def __init__(self):
     30        pass
     31        self.templ_name = "template.c"
     32        self.outfile_name = "code.c"
     33
     34    def output(self, decl, stmts, templ):
     35        templ = string.replace(templ, '@decl', decl)
     36        templ = string.replace(templ, '@stmts', stmts)
     37        return templ
     38
     39    def read_template(self):
     40        f=open(self.templ_name)
     41        inp = ""
     42        for line in f:
     43            inp += line
     44        f.close()
     45        return inp
     46
     47    def write_final_code(self, code):
     48        f=open(self.outfile_name, 'w')
     49        f.write(code)
     50        f.close()
     51
     52    def generate_code(self, s):
     53        s = ast.parse(s)
     54        s = s.body[0].body
     55
     56        s = py2bitexpr.translate_stmts(s)
     57
     58        st = py2bitexpr.gen_sym_table(s)
     59
     60        s=py2bitexpr.make_SSA(s, st)
     61
     62        s = py2bitexpr.partition2bb(s, )
     63
     64        s = py2bitexpr.apply_all_opt(s)
     65        s = py2bitexpr.normalize(s)
     66        py2bitexpr.simplify_tree(s)
     67
     68        livelist = ['u16hi[0]', 'u16hi[1]', 'u16hi[2]', 'u16hi[3]', 'u16hi[4]', 'u16hi[5]', 'u16hi[6]', 'u16hi[7]']
     69        livelist += ['u16lo[0]','u16lo[1]','u16lo[2]','u16lo[3]','u16lo[4]','u16lo[5]','u16lo[6]','u16lo[7]', 'delmask', 'u8.error', 'error_mask', 'u8lastbyte', 'Cursor2',]
     70        all_lives, s = py2bitexpr.eliminate_dead_code(s, set(livelist))
     71
     72        s=py2bitexpr.factor_out(s)
     73
     74        s, livelist = py2bitexpr.process_while_loops(s)
     75        declarations = py2bitexpr.gen_declarations(s)
     76
     77        templ = self.read_template()
     78        templ = self.output(declarations, py2bitexpr.print_prog(s), templ)
     79        self.write_final_code(templ)
     80        return s
     81
     82        #livelist = ['u16hi[0]','u16hi[1]','u16hi[2]','u16hi[3]','u16hi[4]','u16hi[5]','u16hi[6]','u16hi[7]']
     83        #livelist += ['u16lo[0]','u16lo[1]','u16lo[2]','u16lo[3]','u16lo[4]','u16lo[5]','u16lo[6]','u16lo[7]', 'delmask', 'u8.error', 'error_mask', 'u8lastbyte']
     84
     85
     86###############################################################
     87
     88#if __name__ == '__main__':
     89#if __name__ == '__main__':
     90
     91if True:
     92        s=r"""def u8u16(u8bit):
     93        temp1 = (u8bit[0] | u8bit[1]);
     94        temp2 = (u8bit[2] & u8bit[3]);
     95        temp3 = (temp2 &~ temp1);
     96        temp4 = (u8bit[4] & u8bit[5]);
     97        temp5 = (u8bit[6] | u8bit[7]);
     98        temp6 = (temp4 &~ temp5);
     99        LAngle =(temp3 & temp6);
     100        temp7 = (u8bit[6] &~ u8bit[7]);
     101        temp8 = (temp4 & temp7);
     102        RAngle =(temp3 & temp8);
     103        temp9 = (u8bit[7] &~ u8bit[6]);
     104        temp10 =(temp4 & temp9);
     105        Equal =(temp3 & temp10);
     106        temp11 =(u8bit[2] &~ u8bit[3]);
     107        temp12 =(temp11 &~ temp1);
     108        temp13 =(u8bit[5] &~ u8bit[4]);
     109        temp14 =(u8bit[6] & u8bit[7]);
     110        temp15 =(temp13 & temp14);
     111        SQuote =(temp12 & temp15);
     112        temp16 =(u8bit[4] | u8bit[5]);
     113        temp17 =(temp7 &~ temp16);
     114        DQuote =(temp12 & temp17);
     115        temp18 =(temp4 & temp14);
     116        Slash = (temp12 & temp18);
     117        temp19 =(temp16 | temp5);
     118        temp20 =(temp12 &~ temp19);
     119        temp21 =(u8bit[2] | u8bit[3]);
     120        temp22 =(temp1 | temp21);
     121        temp23 =(temp10 &~ temp22);
     122        temp24 =(temp20 | temp23);
     123        temp25 =(u8bit[4] &~ u8bit[5]);
     124        temp26 =(temp25 & temp9);
     125        temp27 =(temp26 &~ temp22);
     126        temp28 =(temp24 | temp27);
     127        temp29 =(temp25 & temp7);
     128        temp30 =(temp29 &~ temp22);
     129        WS = (temp28 | temp30);
     130        temp31 = (temp14 &~ temp16);
     131        temp32 = (temp12 & temp31);
     132        temp33 = (temp32 | Equal);
     133        temp34 = (temp33 | Slash);
     134        temp35 = (temp34 | RAngle);
     135        temp36 = (temp35 | LAngle);
     136        temp37 = (u8bit[1] &~ u8bit[0]);
     137        temp38 = (u8bit[3] &~ u8bit[2]);
     138        temp39 = (temp37 & temp38);
     139        temp40 = (temp39 & temp6);
     140        temp41 = (temp36 | temp40);
     141        temp42 = (temp41 | DQuote);
     142        temp43 = (temp42 | SQuote);
     143        temp44 = (temp25 & temp14);
     144        temp45 = (temp3 & temp44);
     145        temp46 = (temp43 | temp45);
     146        temp47 = (temp46 | temp20);
     147        temp48 = (temp47 | temp27);
     148        temp49 = (temp48 | temp30);
     149        temp50 = (temp49 | temp23);
     150        temp51 = (temp13 & temp7);
     151        temp52 = (temp12 & temp51);
     152        temp53 = (temp50 | temp52);
     153        temp54 = (temp12 & temp6);
     154        temp55 = (temp53 | temp54);
     155        temp56 = (temp37 & temp2);
     156        temp57 = (temp56 & temp6);
     157        temp58 = (temp55 | temp57);
     158        temp59 = (temp9 &~ temp16);
     159        temp60 = (temp12 & temp59);
     160        temp61 = (temp58 | temp60);
     161        temp62 = (temp3 & temp18);
     162        temp63 = (temp61 | temp62);
     163        temp64 = (temp39 & temp10);
     164        NameDelim = (temp63 | temp64);
     165
     166        DQuoteDelim = DQuote | LAngle
     167        SQuoteDelim = SQuote | LAngle
     168        AttListDelim = Slash | RAngle
    265169       
    266         bb2vars = []
    267         for loc in code2:
    268             bb2vars.append(loc.LHS.varname)
    269         bb1vars = []
    270         for loc in code1:
    271             bb1vars.append(loc.LHS.varname)
    272        
    273         for i in self.common_expression_map:
    274             if self.common_expression_map[i] in bb2vars and not self.common_expression_map[i] in bb1vars:
    275                 del self.common_expression_map[i]
    276        
    277         bb1 = BasicBlock()
    278         bb1.code = code1
    279         bb1.common_expression_map = self.common_expression_map
    280         return bb1, bb2
    281     def add_stmt(self, varname, expr):
    282         assert(not isinstance(varname, str))
    283         self.common_expression_map[expr.show()] = varname
    284         self.code.append(bitexpr.BitAssign(varname, expr))
    285         if isinstance(expr, bitexpr.Add):
    286                 BasicBlock.carry_counter += 1
    287         if isinstance(expr, bitexpr.Sub):
    288                 BasicBlock.brw_counter += 1
    289 
    290     def expr_string_to_variable(self, expr_string):
    291         #if isinstance(expr_string, bitexpr.And):
    292         #    print expr_string.operand1, expr_string.operand2
    293         #print expr_string.show()
    294         if self.common_expression_map.has_key(expr_string.show()):
    295             return self.common_expression_map[expr_string.show()]
    296         else:
    297             BasicBlock.gensym_counter += 1
    298             sym = bitexpr.Var(self.gensym_template % BasicBlock.gensym_counter)
    299             self.add_stmt(sym, expr_string)
    300             return sym
    301 
    302     def showcode(self, indent, line_no = False):
    303         s=""
    304         for index, stmt in enumerate(self.code):
    305                 #if index == 36:
    306                 #    print "::::", stmt, stmt.LHS, stmt.RHS, stmt.RHS.varname
    307                 if line_no:
    308                         s+= "%i %s%s;\n"%(index, " "*indent,stmt.show())
    309                 else:
    310                         #if not isinstance(stmt, bitexpr.BitAssign):
    311                         #   print "~~~~", index, stmt
    312                         s += " "*indent + stmt.show() + ";\n"
    313         return s
    314 
    315     def get_code(self):
    316         return self.code
    317 
    318     def calc_implications(self, assumptions):
    319         changed = False
    320         notchecked = [x for x in assumptions]
    321         lhs = [x.LHS.varname for x in self.code]
    322         while True:
    323             if len(notchecked) == 0:
    324                 break
    325             var = notchecked.pop(0)
    326             if var in lhs:
    327                 index = lhs.index(var)
    328                 s = self.code[index]
    329             else:
    330                 continue
    331             if var in assumptions:
    332                 if isinstance(assumptions[s.LHS.varname], bitexpr.FalseLiteral):
    333                     if isinstance(s.RHS, bitexpr.Not):
    334                         assumptions[s.RHS.operand.varname] = bitexpr.TrueLiteral()
    335                         notchecked.append(s.RHS.operand.varname)
    336                     elif isinstance(s.RHS, bitexpr.Var):
    337                         assumptions[s.RHS.varname] = bitexpr.FalseLiteral()
    338                         notchecked.append(s.RHS.operand.varname)
    339                     elif isinstance(s.RHS, bitexpr.Or):
    340                         assumptions[s.RHS.operand1.varname] = bitexpr.FalseLiteral()
    341                         assumptions[s.RHS.operand2.varname] = bitexpr.FalseLiteral()
    342                         notchecked.append(s.RHS.operand1.varname)
    343                         notchecked.append(s.RHS.operand2.varname)
    344                     elif isinstance(s.RHS, bitexpr.Xor):
    345                         pass
    346                 if isinstance(assumptions[s.LHS.varname], bitexpr.TrueLiteral):
    347                     if isinstance(s.RHS, bitexpr.Not):
    348                         assumptions[s.RHS.operand.varname] = bitexpr.FalseLiteral()
    349                         notchecked.append(s.RHS.operand.varname)
    350                     elif isinstance(s.RHS, bitexpr.Var):
    351                         assumptions[s.RHS.varname] = bitexpr.TrueLiteral()
    352                         notchecked.append(s.RHS.operand.varname)
    353                     elif isinstance(s.RHS, bitexpr.And):
    354                         assumptions[s.RHS.operand1.varname] = bitexpr.TrueLiteral()
    355                         assumptions[s.RHS.operand2.varname] = bitexpr.TrueLiteral()
    356                         notchecked.append(s.RHS.operand1.varname)
    357                         notchecked.append(s.RHS.operand2.varname)
    358                     elif isinstance(s.RHS, bitexpr.Andc):
    359                         if isinstance(s.RHS.operand1, bitexpr.TrueLiteral):
    360                             assumptions[s.RHS.operand2.varname] = bitexpr.FalseLiteral()
    361                             notchecked.append(s.RHS.operand2.varname)
    362                         else:
    363                             assumptions[s.RHS.operand1.varname] = bitexpr.TrueLiteral()
    364                             assumptions[s.RHS.operand2.varname] = bitexpr.FalseLiteral()
    365                             notchecked.append(s.RHS.operand1.varname)
    366                             notchecked.append(s.RHS.operand2.varname)
    367         return assumptions
    368     def propagate_constant(self, previous):
    369         changed = False
    370         fixed = previous
    371         #print fixed
    372         for s in self.code:
    373             #if len(self.code)==1: print s, s.LHS.varname, s.RHS
    374 
    375             if isinstance(s.RHS, bitexpr.FalseLiteral):
    376                 fixed[s.LHS.varname] = bitexpr.FalseLiteral()
    377                 continue
    378             if isinstance(s.RHS, bitexpr.TrueLiteral):
    379                 fixed[s.LHS.varname] = bitexpr.TrueLiteral()
    380                 continue
    381             if isinstance(s.RHS, bitexpr.Var):
    382                 if s.RHS.varname in fixed:
    383                     s.RHS = fixed[s.RHS.varname]
    384                 fixed[s.LHS.varname] = s.RHS
    385                 continue
    386             #   fixed[s.LHS] = s.RHS
    387             if isinstance(s.RHS, bitexpr.Not):
    388                 if s.RHS.operand.varname in fixed:
    389                     s.RHS.operand = fixed[s.RHS.operand.varname]
    390                 continue
    391             if s.RHS.operand1.varname in fixed:
    392                 changed = True
    393                 s.RHS.operand1 = fixed[s.RHS.operand1.varname]
    394             if s.RHS.operand2.varname in fixed:
    395                 s.RHS.operand2 = fixed[s.RHS.operand2.varname]
    396                 changed = True
    397 
    398         return fixed, changed
    399 
    400     def normalize(self, stmts):
    401         for s in stmts:
    402             self.add_stmt(s.LHS, expr2simd(self, s.RHS))
    403     def transform(self):
    404         for loc in self.code:
    405             loc.RHS = simplify_expr(loc.RHS)
    406     def simplify(self, previous):
    407         changed = True
    408         while changed:
    409             self.transform()
    410             fixed, changed = self.propagate_constant(previous)
    411         return fixed
    412 
    413 ################################################################
    414 # Going through compilation passes one by one
    415 ################################################################
    416 def generate_code(s):
    417 
    418     #Pass 1
    419     s = ast.parse(s)
    420     s = s.body[0].body
    421 
    422     #Pass 2
    423     s = py2bitexpr.translate_stmts(s)
    424 
    425     #Pass 3
    426     st = py2bitexpr.gen_sym_table(s)
    427     s=py2bitexpr.make_SSA(s, st)
    428 
    429     #Pass 4
    430     bb, exprs = py2bitexpr.partition2bb(s)
    431 
    432     #Pass 5
    433     declarations = py2bitexpr.gen_declarations(bb)
    434 
    435     #Pass 6
    436     tree = py2bitexpr.construct_tree(bb, exprs)
    437 
    438     #Pass 7
    439     tree = py2bitexpr.Reducer().apply_all_opt(tree)
    440 
    441     #Pass 8
    442     py2bitexpr.simplify_tree(tree)
    443 
    444    
    445     livelist = ['u16hi[0]','u16hi[1]','u16hi[2]','u16hi[3]','u16hi[4]','u16hi[5]','u16hi[6]','u16hi[7]']
    446     livelist += ['u16lo[0]','u16lo[1]','u16lo[2]','u16lo[3]','u16lo[4]','u16lo[5]','u16lo[6]','u16lo[7]', 'delmask', 'u8.error', 'error_mask']
    447     py2bitexpr.eliminate_dead_code(tree, set(livelist))
    448    
    449     #Pass 9
    450     s = py2bitexpr.unwind(tree)
    451     return declarations+s
    452 
    453 
    454 ###############################################################
    455 if __name__ == '__main__':
    456         s=r"""def u8u16(u8, u8bit):
    457         u8.unibyte = (~u8bit[0]);
    458         optimize(u8bit[0], allzero)
    459 #        optimize(u8.unibyte, allone)
    460         u8.prefix = (u8bit[0] & u8bit[1]);
    461         u8.prefix2 = (u8.prefix &~ u8bit[2]);
    462         temp1 = (u8bit[2] &~ u8bit[3]);
    463         u8.prefix3 = (u8.prefix & temp1);
    464         temp2 = (u8bit[2] & u8bit[3]);
    465         u8.prefix4 = (u8.prefix & temp2);
    466         u8.suffix = (u8bit[0] &~ u8bit[1]);
    467 #        maxtwo = u8.prefix2 | u8.unibyte | u8.suffix
    468 #        optimize(maxtwo, allone)
    469         temp3 = (u8bit[2] | u8bit[3]);
    470         temp4 = (u8.prefix &~ temp3);
    471         temp5 = (u8bit[4] | u8bit[5]);
    472         temp6 = (temp5 | u8bit[6]);
    473         temp7 = (temp4 &~ temp6);
    474         temp8 = (u8bit[6] | u8bit[7]);
    475         temp9 = (u8bit[5] & temp8);
    476         temp10 = (u8bit[4] | temp9);
    477         temp11 = (u8.prefix4 & temp10);
    478         u8.badprefix = (temp7 | temp11);
    479         temp12 = (temp5 | temp8);
    480         u8.xE0 = (u8.prefix3 &~ temp12);
    481         temp13 = (u8bit[4] & u8bit[5]);
    482         temp14 = (u8bit[7] &~ u8bit[6]);
    483         temp15 = (temp13 & temp14);
    484         u8.xED = (u8.prefix3 & temp15);
    485         u8.xF0 = (u8.prefix4 &~ temp12);
    486         temp16 = (u8bit[5] &~ u8bit[4]);
    487         temp17 = (temp16 &~ temp8);
    488         u8.xF4 = (u8.prefix4 & temp17);
    489         u8.xA0_xBF = (u8.suffix & u8bit[2]);
    490         u8.x80_x9F = (u8.suffix &~ u8bit[2]);
    491         u8.x90_xBF = (u8.suffix & temp3);
    492         u8.x80_x8F = (u8.suffix &~ temp3);
    493        
    494         u8.scope22 = bitutil.Advance(u8.prefix2)
    495         u8.scope32 = bitutil.Advance(u8.prefix3)
    496         u8.scope33 = bitutil.Advance(u8.scope32)
    497         u8.scope42 = bitutil.Advance(u8.prefix4)
    498         u8.scope43 = bitutil.Advance(u8.scope42)
    499         u8.scope44 = bitutil.Advance(u8.scope43)
    500         u8lastscope = u8.scope22 | u8.scope33 | u8.scope44
    501         u8anyscope = u8lastscope | u8.scope32 | u8.scope42 | u8.scope43
    502         optimize(u8anyscope,allzero)
    503         # C0-C1 and F5-FF are illegal
    504         error_mask = u8.badprefix
    505        
    506         error_mask |= bitutil.Advance(u8.xE0) & u8.x80_x9F
    507         error_mask |= bitutil.Advance(u8.xED) & u8.xA0_xBF
    508         error_mask |= bitutil.Advance(u8.xF0) & u8.x80_x8F
    509         error_mask |= bitutil.Advance(u8.xF4) & u8.x90_xBF
    510        
    511         error_mask |= u8anyscope ^ u8.suffix
    512         u8.error = error_mask
    513        
    514         u8lastscope = u8.scope22 | u8.scope33 | u8.scope44
    515         u8lastbyte = u8.unibyte | u8lastscope
    516         u16lo[2] = u8lastbyte & u8bit[2]
    517         u16lo[3] = u8lastbyte & u8bit[3]
    518         u16lo[4] = u8lastbyte & u8bit[4]
    519         u16lo[5] = u8lastbyte & u8bit[5]
    520         u16lo[6] = u8lastbyte & u8bit[6]
    521         u16lo[7] = u8lastbyte & u8bit[7]
    522         u16lo[1] = (u8.unibyte & u8bit[1]) | (u8lastscope & bitutil.Advance(u8bit[7]))
    523         u16lo[0] = u8lastscope & bitutil.Advance(u8bit[6])
    524        
    525         u16hi[5] = u8lastscope & bitutil.Advance(u8bit[3])
    526         u16hi[6] = u8lastscope & bitutil.Advance(u8bit[4])
    527         u16hi[7] = u8lastscope & bitutil.Advance(u8bit[5])
    528         u16hi[0] = u8.scope33 & bitutil.Advance(bitutil.Advance(u8bit[4]))
    529         u16hi[1] = u8.scope33 & bitutil.Advance(bitutil.Advance(u8bit[5]))
    530         u16hi[2] = u8.scope33 & bitutil.Advance(bitutil.Advance(u8bit[6]))
    531         u16hi[3] = u8.scope33 & bitutil.Advance(bitutil.Advance(u8bit[7]))
    532         u16hi[4] = u8.scope33 & bitutil.Advance(u8bit[2])
    533 
    534         u8surrogate = u8.scope43 | u8.scope44
    535         u16hi[0] = u16hi[0] | u8surrogate       
    536         u16hi[1] = u16hi[1] | u8surrogate       
    537         u16hi[3] = u16hi[3] | u8surrogate       
    538         u16hi[4] = u16hi[4] | u8surrogate       
    539         u16hi[5] = u16hi[5] | u8.scope44
    540 
    541         s42lo1 = ~u8bit[3] # subtract 1
    542         u16lo[1] = u16lo[1] | (u8.scope43 & bitutil.Advance(s42lo1))
    543         s42lo0 = u8bit[2] ^ s42lo1 # borrow *
    544         u16lo[0] = u16lo[0] | (u8.scope43 & bitutil.Advance(s42lo0))
    545         borrow1 = s42lo1 & ~u8bit[2]
    546         bitutil.Advance_bit7 = bitutil.Advance(u8bit[7])
    547         s42hi7 = bitutil.Advance_bit7 ^ borrow1
    548         u16hi[7]= u16hi[7] | (u8.scope43 & bitutil.Advance(s42hi7))
    549         borrow2 = borrow1 & ~bitutil.Advance_bit7
    550         s42hi6 = bitutil.Advance(u8bit[6]) ^ borrow2
    551         u16hi[6] = u16hi[6] | (u8.scope43 & bitutil.Advance(s42hi6))
    552 
    553         u16lo[2] = u16lo[2] | (u8.scope43 & bitutil.Advance(u8bit[4]))
    554         u16lo[3] = u16lo[3] | (u8.scope43 & bitutil.Advance(u8bit[5]))
    555         u16lo[4] = u16lo[4] | (u8.scope43 & bitutil.Advance(u8bit[6]))
    556         u16lo[5] = u16lo[5] | (u8.scope43 & bitutil.Advance(u8bit[7]))
    557         u16lo[6] = u16lo[6] | (u8.scope43 & u8bit[2])
    558         u16lo[7] = u16lo[7] | (u8.scope43 & u8bit[3])
    559 
    560         delmask = u8.prefix | u8.scope32 | u8.scope42"""
    561         print generate_code(s)
     170        LAngleFollow = bitutil.Advance(LAngle)
     171        ElemNamePositions = LAngleFollow & ~Slash
     172        EndTagSeconds = LAngleFollow & Slash
     173
     174        ElemNameFollows = bitutil.ScanThru(ElemNamePositions, ~NameDelim)
     175        #ElemNames = ElemNameFollows - ElemNamePositions
     176        ParseError = ElemNamePositions & ElemNameFollows
     177
     178        AttNameStarts = AllZero
     179        AttNameFollows = AllZero
     180        EqToCheck = AllZero
     181        AttValStarts = AllZero
     182        AttValEnds = AllZero
     183        AttValFollows = AllZero
     184
     185        AfterWS = bitutil.ScanThru(ElemNameFollows, WS)
     186        AttListEnd = AfterWS & AttListDelim
     187        AttNameStart = AfterWS & ~AttListDelim
     188        ParseError |= ElemNameFollows & AttNameStart
     189
     190        while AttNameStart > 0:
     191            AttNameStarts |= AttNameStart
     192            AttNameFollow = bitutil.ScanThru(AttNameStart, ~NameDelim)
     193            AttNameFollows |= AttNameFollow
     194
     195            # Scan through WS to the expected '=' delimiter.
     196            EqExpected = bitutil.ScanThru(AttNameFollow, WS)
     197            EqToCheck |= EqExpected
     198            AttValPos = bitutil.ScanThru(bitutil.Advance(EqExpected), WS)
     199            AttValStarts |= AttValPos
     200            DQuoteAttVal = AttValPos & DQuote
     201            SQuoteAttVal = AttValPos & SQuote
     202            DQuoteAttEnd = bitutil.ScanThru(bitutil.Advance(DQuoteAttVal), ~DQuoteDelim)
     203            SQuoteAttEnd = ScanThru(bitutil.Advance(SQuoteAttVal), ~SQuoteDelim)
     204            AttValEnd = DQuoteAttEnd | SQuoteAttEnd
     205            AttValEnds |= AttValEnd
     206            AttValFollow = bitutil.Advance(AttValEnd)
     207            AttValFollows |= AttValFollow
     208            AfterWS = bitutil.ScanThru(AttValFollow, WS)
     209            AttListEnd |= AfterWS & AttListDelim
     210            AttNameStart = AfterWS & ~AttListDelim
     211
     212        # No more attribute values to process when AttNameStart == 0.
     213
     214        #AttNames = AttNameFollows - AttNameStarts
     215        #AttVals = AttValFollows - AttValStarts
     216        STagEnds = AttListEnd & RAngle
     217
     218        # Mark any "/" characters found as the ends of empty element tags.
     219        EmptyTagEnds = Advance(AttListEnd & Slash)
     220        Tags = (STagEnds | EmptyTagEnds) - ElemNamePositions
     221
     222        # Check for errors.
     223        ParseError |= AttValFollows & AttNameStarts
     224        ParseError |= AttNameStarts & AttNameFollows
     225        ParseError |= EqToCheck & ~Equal
     226        ParseError |= AttValStarts & ~ (DQuote | SQuote)
     227        ParseError |= AttValEnds & ~ (DQuote | SQuote)
     228        ParseError |= EmptyTagEnds & ~RAngle
     229
     230        # End Tag Parsing
     231        EndTagEnds = bitutil.ScanThru(bitutil.ScanThru(bitutil.Advance(EndTagSeconds), ~NameDelim), WS)
     232        ParseError |= EndTagEnds &~ RAngle
     233        error_mask=ParseError"""
     234        s = Program().generate_code(s)
    562235
    563236#TODO: merge expr2simd() and a simplify_expr()
     
    565238#TODO: removing unreachable code is developed but is incomplete. Currently the code is not used.
    566239#TODO: Support optimizing over subset of values
    567 #TODO: optimize pragma can be put in the beginning of the code, or any point. Right now, it is sensitive to the location
    568240#TODO: A mechanism for programmer to let the compiler know about relations between bitstreams. Some relations might
    569241#       be possibly extracted automatically, but it might be expensive.Look at the optimize(maxtwo, allone) in the current example.
    570 #       it does not help at all, while it is an important case. This is because thecompiler can not extract all the information 
     242#       it does not help at all, while it is an important case. This is because thecompiler can not extract all the information
    571243#       implied by the assumption maxtwo = allone.
     244#TODO: How variables redefined within a while loop should be dealt with for SSA generation?
     245#TODO: How variables redefined within a while loop should be dealt if there is an optimization on them?
     246#TODO: Numbering of temp variables generated by class BasicBlock
     247#TODO: Cross-BasicBlock use of common expressions in class BasicBlock
     248#######################################
     249# Errors to consider and report to the programmer:
     250# 1- Optimizaing on a variable that does not exist
     251# 2- infinite while loop when the loop looks like this: while (AllOne>0)
  • proto/parabix2/Compiler/py2bitexpr.py

    r323 r344  
    1313# Requires ast module of python 2.6
    1414
    15 import ast, bitexpr, copy, bitstream_compiler
     15import ast, bitexpr, copy, basic_block
     16
     17AllOne = 'AllOne'
     18AllZero = 'AllZero'
    1619
    1720class PyBitError(Exception):
     
    2124## Translation to bitExpr classes --- Pass 2
    2225#############################################################################################
    23 
    2426
    2527def translate_index(ast_value):
     
    8587                        return bitexpr.make_and(bitexpr.make_add(e0, e1), bitexpr.make_not(e1))
    8688                else: raise PyBitError("Bad PyBit function call: %s\n" % ast.dump(ast_expr))
     89        elif isinstance(ast_expr, ast.Compare):
     90                if (isinstance(ast_expr.ops[0], ast.Gt) and (ast_expr.comparators[0].n==0)):
     91                        e0 = translate(ast_expr.left)
     92                        return bitexpr.isNoneZero(e0)
     93                else:   raise PyBitError("Bad condition in while loop: %s\n" % ast.dump(ast_expr))
    8794        else: raise PyBitError("Unknown expression %s\n" % ast.dump(ast_expr))
    8895
     
    104111        for s in ast_stmts:
    105112                if isinstance(s, ast.Expr):
    106                     target = translate_var(s.value.args[0])
    107                     replace = translate_var(s.value.args[1])
    108                     translated.append(bitexpr.Reduce(target, replace))
     113                        if (s.value.func.id=='optimize'):
     114                            target = translate_var(s.value.args[0])
     115                            replace = translate_var(s.value.args[1])
     116                            translated.append(bitexpr.Reduce(target, replace))
     117                        else: raise PyBitError("Unknown operation %s\n", ast.dump(s))
    109118                elif isinstance(s, ast.Assign):
    110119                        e = translate(s.value)
     
    112121                                translated.append(bitexpr.BitAssign(bitexpr.Var(translate_var(astv)), e))
    113122                elif isinstance(s, ast.AugAssign):
    114                         #print "~~~~~", translate_var(s.target)
    115123                        v = bitexpr.Var(translate_var(s.target))
    116124                        translated.append(bitexpr.BitAssign(v, translate(ast.BinOp(s.target, s.op, s.value))))
     
    121129                else: raise PyBitError("Unknown PyBit statement type %s\n" % ast.dump(s))
    122130        return translated
    123                
    124                
     131
    125132#############################################################################################
    126133## Conversion to SSA form --- Pass 3
     
    133140        return []
    134141    if isinstance(rhs, bitexpr.Not):
    135         return extract_vars(rhs.operand)
     142        return extract_vars(rhs.operand1)
    136143    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
    137144    return extract_vars(rhs.operand1)+extract_vars(rhs.operand2)
    138145
    139 def gen_sym_table(code):
    140         """Generate a simple symbol table for a three address code
    141            each entry is of this form: var_name:[[defs][uses]]
    142         """
    143         table = {}
    144         for index, stmt in enumerate(code):
    145                 if isinstance(stmt, bitexpr.Reduce):
    146                     if stmt.target in table:
    147                         table[stmt.target][1].append(index)
    148                     else:
    149                         table[stmt.target][1] = [[], [index]]
     146def update_lineno(inner_table, shift):
     147    for key in inner_table:
     148        for index in range(len(inner_table[key][0])):
     149            inner_table[key][0][index] += shift
     150        for index in range(len(inner_table[key][1])):
     151            inner_table[key][1][index] += shift
     152    return inner_table
     153
     154def merge_tables(table, inner_table):
     155    for key in inner_table:
     156        if key in table:
     157            table[key][0] += inner_table[key][0]
     158            table[key][1] += inner_table[key][1]
     159        else:
     160            table[key] = inner_table[key]
     161    return table
     162
     163def gen_sym_table(code, goInside = False):
     164    """Generate a simple symbol table for a three address code
     165        each entry is of this form: var_name:[[defs][uses]]
     166    """
     167    table = {}
     168    index = 0
     169    for stmt in code:
     170        if isinstance(stmt, bitexpr.Reduce):
     171            if stmt.target in table:
     172                table[stmt.target][1].append(index)
     173            else:
     174                table[stmt.target] = [[], [index]]
     175            index += 1
     176        elif isinstance(stmt, bitexpr.BitAssign):
     177            current = stmt.LHS.varname
     178            if current in table:
     179                table[current][0].append(index)
     180            else:
     181                table[current] = [[index],[]]
     182
     183            varlist = extract_vars(stmt.RHS)
     184            for var in varlist:
     185                if var in table:
     186                    table[var][1].append(index)
    150187                else:
    151                         assert(isinstance(stmt, bitexpr.BitAssign))
    152                         current = stmt.LHS.varname
    153                         if current in table:
    154                                 table[current][0].append(index)
    155                         else:
    156                                 table[current] = [[index],[]]
    157 
    158                         varlist = extract_vars(stmt.RHS)
    159                         for var in varlist:
    160                                 if var in table:
    161                                         table[var][1].append(index)
    162                                 else:
    163                                         table[var] = [[], [index]]
    164         return table
     188                    table[var] = [[], [index]]
     189            index += 1
     190        elif isinstance(stmt, bitexpr.WhileLoop):
     191            cond_var = stmt.control_expr.var.varname
     192            if cond_var in table:
     193                table[cond_var][1].append(index)
     194            else:
     195                #while loop conditioned on an undefined variable
     196                assert(1==0)
     197            if goInside:
     198                inner_table = gen_sym_table(stmt.stmts, True)
     199                inner_table = update_lineno(inner_table, index+1)
     200                table = merge_tables(table, inner_table)
     201                index += (1+len(stmt.stmts))
     202            else:
     203                index += 1
     204        else:
     205            assert(1==0)
     206    return table
     207##################################################################################
     208def get_line(code, line):
     209
     210    lineno = 0
     211    for stmt in code:
     212        if isinstance(stmt, bitexpr.BitAssign):
     213            if lineno == line:
     214                return stmt
     215            lineno += 1
     216        elif isinstance(stmt, bitexpr.WhileLoop):
     217            if lineno == line:
     218                return stmt
     219            elif lineno+len(stmt.stmts) < line:
     220                lineno += 1
     221                line -= len(stmt.stmts)
     222                continue
     223            else:
     224                return get_line(stmt.stmts, line-(lineno+1))
     225
     226        elif isinstance(stmt, bitexpr.Reduce):
     227            if lineno == line:
     228                return stmt
     229            lineno += 1
     230        else:
     231            assert(1==0)
     232
     233def update_def(code, var, line, suffix_num):
     234    #for i in code:
     235    #    print i
     236    #print "------------------------", len(code), line, var
     237    loc = code[line]
     238    if isinstance(loc, bitexpr.BitAssign):
     239        loc.LHS.varname = simplify_name(loc.LHS.varname)+ "_%i"%suffix_num
     240        return loc
     241    else:
     242        #either Reduce or While Loop in both cases it's a use
     243        pass
     244
     245def update_rhs(rhs, varname, newname):
     246    if isinstance(rhs, bitexpr.Var):
     247        if rhs.varname == varname:
     248            #if varname == "u8.unibyte":
     249            #    print ")))))))))))))))))"
     250           
     251            rhs.varname = newname
     252        return rhs
     253
     254    if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
     255        return rhs
     256
     257    if isinstance(rhs, bitexpr.Not):
     258        rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
     259        return rhs
     260    #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
     261    rhs.operand1 = update_rhs(rhs.operand1, varname, newname)
     262    rhs.operand2 = update_rhs(rhs.operand2, varname, newname)
     263    return rhs
     264
     265def update_use(code, var, line, suffix_num):
     266    loc = code[line]
     267    if isinstance(loc, bitexpr.BitAssign):
     268        loc.RHS = update_rhs(loc.RHS, var, simplify_name(var)+("_%i"%suffix_num))
     269    elif isinstance(loc, bitexpr.Reduce):
     270        pass
     271    elif isinstance(loc, bitexpr.WhileLoop):
     272        loc.control_expr.var.varname += "_%i"%suffix_num
     273    return loc
    165274
    166275def parse_var(var):
     
    173282        right_index = var.find(']')
    174283        return ('array', var[0:index], var[index+1:right_index])       
    175        
     284   
     285    if var.startswith("carry") or var.startswith("Carry"):
     286        return('int', var, None)
     287   
    176288    return ('bitblock', var, None)
    177289
     
    184296    if vartype == 'struct':
    185297        return "s_%s_%s"%(name, extra)
     298    if vartype == 'int':
     299        assert(1==0)
    186300    assert(1==0)
    187301
    188 def update_rhs(rhs, varname, newname):
    189     if isinstance(rhs, bitexpr.Var):
    190         if rhs.varname == varname:
    191             rhs.varname = newname
    192         return
    193            
    194     if isinstance(rhs, bitexpr.FalseLiteral) or isinstance(rhs, bitexpr.TrueLiteral):
    195         return
    196    
    197     if isinstance(rhs, bitexpr.Not):
    198         update_rhs(rhs.operand, varname, newname)
    199         return
    200     #So it is a binary operation WHAT ABOUT ADD AND SUBTRACT?
    201     update_rhs(rhs.operand1, varname, newname)
    202     update_rhs(rhs.operand2, varname, newname)
    203 
    204302def pairs(lst):
    205         if lst == []:
    206                 return []
    207         return zip(lst,lst[1:]+[lst[0]])
     303    if lst == []:
     304            return []
     305    return zip(lst,lst[1:]+[lst[0]])
     306
    208307
    209308def make_SSA(code, st):
    210         new_vars = []
    211         total_lines = len(code)
    212         for var in st:
    213                 st[var][0].append(total_lines)
    214                 st[var][1].append(total_lines)
    215 
    216         unique_number = 0
    217         for var in st:
    218                 use_index = 0
    219                 for current, next in pairs(st[var][0])[0:-2]:
    220                         code[current].LHS.varname = "%s_%i"%(simplify_name(var), unique_number)
    221                         new_vars.append("%s_%i"%(simplify_name(var), unique_number))
    222                         uline = st[var][1][use_index]
    223                         while uline <= next and uline < total_lines:
    224                                 if uline > current:
    225                                         update_rhs(code[uline].RHS, var, "%s_%i"%(simplify_name(var), unique_number))
    226                                         #code[uline].RHS.update_var(var, )
    227                                 use_index += 1
    228                                 uline = st[var][1][use_index]
    229                         unique_number += 1
    230         return code
     309    new_vars = []
     310    total_lines = len(code)
     311    for var in st:
     312        st[var][0].append(total_lines)
     313        st[var][1].append(total_lines)
     314
     315    unique_number = 0
     316    for var in st:
     317        use_index = 0
     318        for current, next in pairs(st[var][0])[0:-2]:
     319            code[current] = update_def(code, var, current, unique_number)
     320            uline = st[var][1][use_index]
     321
     322            while uline <= next and uline < total_lines:
     323                if uline > current:
     324                    code[uline] = update_use(code, var, uline, unique_number)
     325                use_index += 1
     326                uline = st[var][1][use_index]
     327            unique_number += 1
     328    return code
    231329
    232330#################################################################################################################
    233331## Breaking the code into basic blocks depending on where optimize pragma appears --- Pass 4
    234332#################################################################################################################
     333def get_opt_list(s):
     334    opt_list = []
     335    remove_index = []
     336    for line, stmt in enumerate(s):
     337        if isinstance(stmt, bitexpr.Reduce):
     338            opt_pair = (stmt.target, stmt.replace)
     339            opt_list.append(opt_pair)
     340            remove_index.append(line)
     341    for ind in reversed(remove_index):
     342        del s[ind]
     343
     344    return opt_list
     345
    235346def break_to_segs(s, breaks):
    236347    segs = []
    237348    breaks = [-1]+breaks+[len(s)]
    238349    for start, stop in pairs(breaks)[:-1]:
     350        #print start, stop
    239351        next = s[start+1:stop]
    240352        if len(next)>0:
     
    242354    return segs
    243355
     356def extract_pragmas(s):
     357    targets = []
     358    exprs = []
     359    lines = []
     360
     361    #extracting pragmas and their targets
     362    for line, loc in enumerate(s):
     363        if isinstance(loc, bitexpr.Reduce):
     364            targets.append(loc.target)
     365            exprs.append(loc)
     366            lines.append(line)
     367
     368    #removing pragmas from the source code
     369    for i in reversed(lines):
     370        del s[i]
     371
     372    return targets, exprs
     373
     374def get_defs(targets, s):
     375    temp = [(x,-1) for x in targets]
     376    targ_dic = dict(temp)
     377    for line, loc in enumerate(s):
     378        if loc.LHS.varname in targets:
     379            targ_dic[loc.LHS.varname] = line
     380
     381    items = targ_dic.items()
     382    items = sorted(items, key=(lambda x: x[1]))
     383
     384    lineno = [key for value, key in items]
     385    sorted_targets = [value for value, key in items]
     386
     387    return sorted_targets, lineno
     388
     389def get_boundary(def_line, lastuse_line):
     390    earliest = min(def_line)
     391
     392    # The indices of all variables defined in the same line of code and earlier than all other variables
     393    # There is more than one such variable only if these variables are input variables
     394    all_early = []
     395    all_early_cnt = 0
     396    j = 0
     397
     398    while earliest in def_line[j:]:
     399        m = def_line[j:].index(earliest)
     400        all_early_cnt += 1
     401        all_early.append(m+j)
     402        j += m+1
     403
     404    #all_early contains the indices of the all earliest defined vars
     405    lasts = []
     406    map(lambda y: lasts.append(lastuse_line[y]), all_early)
     407
     408    latest = max(lasts)
     409    return earliest, latest
     410
     411def adjust_latest(s, latest):
     412    line = 0
     413    for stmt in s:
     414        if isinstance(stmt, bitexpr.BitAssign):
     415            if latest == line:
     416                return latest
     417            line += 1
     418        if isinstance(stmt, bitexpr.WhileLoop):
     419            end_of_loop = line+len(stmt.stmts)
     420            if latest <= end_of_loop:
     421                return end_of_loop
     422            else:
     423                line = end_of_loop+1
     424    return line
     425
     426#There is slight difference between the function below and adjust_latest (look at the second return)
     427def adjust_earliest(s, earliest):
     428    line = 0
     429    if earliest == -1:
     430        return -1
     431
     432    for stmt in s:
     433        if isinstance(stmt, bitexpr.BitAssign):
     434            if earliest == line:
     435                return earliest
     436            line += 1
     437        if isinstance(stmt, bitexpr.WhileLoop):
     438            end_of_loop = line+len(stmt.stmts)
     439            if earliest <= end_of_loop:
     440                return end_of_loop+1
     441            else:
     442                line = end_of_loop+1
     443
     444def chop_code(s, start, stop):
     445    line = 0
     446    if start == -1:
     447        cut1 = -1
     448   
     449    if stop == len(s):
     450        cut2 = len(s)
     451   
     452    for index, stmt in enumerate(s):
     453        if isinstance(stmt, bitexpr.BitAssign):
     454            if line == start:
     455                cut1 = index
     456            if line == stop:
     457                cut2 = index
     458            line += 1
     459
     460        if isinstance(stmt, bitexpr.WhileLoop):
     461            end_of_loop = line+len(stmt.stmts)
     462            if line == start:
     463                cut1 = index
     464            if end_of_loop == stop:
     465                cut2 = index
     466            line = end_of_loop+1
     467
     468    return s[:cut1+1], s[cut1+1:cut2+1], s[cut2+1:]
     469
     470def count_lines(code):
     471    cnt = len(code)
     472    for stmt in code:
     473        if isinstance(stmt, bitexpr.WhileLoop):
     474            cnt += count_lines(stmt.stmts)
     475    return cnt
     476
     477
     478def gen_bb(s, opt_list, def_line, lastuse_line):
     479    assert (len(opt_list) == len(def_line) == len(lastuse_line))
     480    if opt_list==[]:
     481        return s
     482
     483    earliest, latest = get_boundary(def_line, lastuse_line)
     484
     485    indices = []
     486    for index, i in enumerate(def_line): #was enumerate(lastuse_line)
     487        if i <= latest:
     488            indices.append(index)
     489    #if latest or earliest are inside a while loop they are pushed to the end of while loop or beginning of the next block (respectively)
     490
     491    e = earliest
     492    l = latest
     493
     494    opt1 = {}
     495    def1 = {}
     496    use1 = {}
     497    for index, i in enumerate(opt_list):
     498        opt1.setdefault(index in indices, []).append(i)
     499    for index, i in enumerate(def_line):
     500        def1.setdefault(index in indices, []).append(i)
     501    for index, i in enumerate(lastuse_line):
     502        use1.setdefault(index in indices, []).append(i)
     503
     504    latest = max(use1[True])
     505    latest = adjust_latest(s, latest)
     506    earliest = adjust_earliest(s, earliest)
     507
     508    #use earliest to construct an initial block with no if-then-else
     509    #use use1[True], def1[True], opt1[True] to construct if then else and recurse on inner blocks
     510    #use use1[False], def1[False], opt1[False] to construct the block after if-then-else and recurse on that block
     511    first, second, third = chop_code(s, earliest, latest)
     512    the_opt = None
     513    the_index = None
     514    for index, value in enumerate(opt1[True]):
     515        if e==def1[True][index] and l==use1[True][index]:
     516            the_opt = value
     517            the_index = index
     518    del opt1[True][the_index]
     519    del use1[True][the_index]
     520    del def1[True][the_index]
     521
     522    second = gen_bb(second, opt1[True], def1[True],use1[True])
     523
     524    if the_opt[1] == 'allzero':
     525        cond_obj = bitexpr.isAllZero(the_opt[0])
     526    if the_opt[1] == 'allone':
     527        cond_obj = bitexpr.isAllOne(the_opt[0])
     528
     529    change = count_lines(first) + count_lines(second)
     530    new_def = []
     531    new_use = []
     532
     533    if (False in opt1):
     534        for i in def1[False]:
     535            new_def.append(max(i-change, -1))
     536        for i in use1[False]:
     537            new_use.append(i-change)
     538
     539        third = gen_bb(third, opt1[False], new_def, new_use)
     540
     541    result = first + [bitexpr.If(cond_obj, second, copy.deepcopy(second))]+third
     542
     543    return result
     544
    244545def partition2bb(s):
    245546    basicblocks = []
    246     exprs = []
    247547    lineno = []
    248 
    249     for line, loc in enumerate(s):
    250         if isinstance(loc, bitexpr.Reduce):
    251             lineno.append(line)
    252             exprs.append(loc)
    253     code_segs = break_to_segs(s, lineno)
    254     previous = {}
    255 
    256     for seg in code_segs:
    257         bb = bitstream_compiler.BasicBlock(previous)
    258         bb.normalize(seg)
    259         previous.update(bb.common_expression_map)
    260         basicblocks.append(bb)
    261 
    262     return basicblocks, exprs
     548    #####
     549    def_line = []
     550    lastuse_line = []
     551    opt_list = get_opt_list(s)
     552
     553    st = gen_sym_table(s, True)
     554    st2 = gen_sym_table(s)
     555    total_lines = count_lines(s)
     556
     557    for item in opt_list:
     558        if len(st2[item[0]][0]) > 0:
     559            #The last definition of the variable is extracted
     560            def_line.append(st2[item[0]][0][-1])
     561        else:
     562            #This variable is an input variable and not defined by the programmer
     563            def_line.append(-1)
     564
     565        #if len(st[item[0]][1]) > 0:
     566            #last use of the variables is extracted
     567        #    lastuse_line.append(st[item[0]][1][-1])
     568        #else:
     569            #The variable is not used in the code anywhere, we assume it is needed at the end
     570        lastuse_line.append(total_lines)
     571
     572    return gen_bb(s, opt_list, def_line, lastuse_line)
     573
     574#################################################################################################################
     575## Generating declarations for variables. This pass is based on the syntax of C programming language ---
     576## Normalizing the code
     577#################################################################################################################
     578
     579def normalize(s, predec = {}, ccelim=True):
     580    if len(s)==0:
     581        return []
     582    if isinstance(s[0], bitexpr.If):
     583        gc = basic_block.BasicBlock.gensym_counter
     584        cc = basic_block.BasicBlock.carry_counter
     585        bc = basic_block.BasicBlock.brw_counter
     586        s[0].true_branch = normalize(s[0].true_branch, copy.deepcopy(predec))
     587        maxgc = basic_block.BasicBlock.gensym_counter
     588        maxcc = basic_block.BasicBlock.carry_counter
     589        maxbc = basic_block.BasicBlock.brw_counter
     590        basic_block.BasicBlock.gensym_counter = gc
     591        basic_block.BasicBlock.carry_counter = cc
     592        basic_block.BasicBlock.brw_counter = bc
     593        s[0].false_branch = normalize(s[0].false_branch, copy.deepcopy(predec))
     594        basic_block.BasicBlock.gensym_counter = max(basic_block.BasicBlock.gensym_counter, maxgc)
     595        basic_block.BasicBlock.carry_counter = max(basic_block.BasicBlock.carry_counter, maxcc)
     596        basic_block.BasicBlock.brw_counter = max(basic_block.BasicBlock.brw_counter, maxbc)
     597       
     598        return [s[0]]+normalize(s[1:], copy.deepcopy(predec))
     599    if isinstance(s[0], bitexpr.WhileLoop):
     600        orig = copy.deepcopy(predec)
     601        s[0].stmts = normalize(s[0].stmts, predec, False)
     602        return [s[0]]+normalize(s[1:], orig)
     603    if isinstance(s[0], bitexpr.BitAssign):
     604        code, next = recurse_forward(s)
     605        bb = basic_block.BasicBlock(predec)
     606        predec.update(bb.normalize(code, ccelim))
     607        return bb.get_code() + normalize(next, copy.deepcopy(predec))
    263608
    264609#################################################################################################################
    265610## Generating declarations for variables. This pass is based on the syntax of C programming language --- Pass 5
    266611#################################################################################################################
    267 def get_vars(code):
     612def get_vars(stmt):
    268613        ints = set([])
    269614        bitblocks = set(['AllOne', 'AllZero'])
     
    271616        structs = {}
    272617
    273         for stmt in code:
    274 
    275                 all_vars = [stmt.LHS]+stmt.RHS.vars
    276 
    277                 if isinstance(stmt.RHS, bitexpr.Add) or isinstance(stmt.RHS, bitexpr.Sub):
    278                         ints.add(all_vars.pop())
    279 
    280                 for var in all_vars:
    281                         #print var
    282                         (var_type, name, extra) = parse_var(var.varname)
    283 
    284                         if (var_type == "bitblock"):
    285                                 bitblocks.add(name)
    286 
    287                         if var_type == "array":
    288                                 if not name in arrays:
    289                                         arrays[name]= extra
    290                                 else:
    291                                         arrays[name] = max(arrays[name], extra)
    292 
    293                         if var_type == "struct":
    294                                 if not name in structs:
    295                                         structs[name] = set([extra])
    296                                 else:
    297                                         structs[name].add(extra)
    298                                 #print structs[name]
    299 
     618        #for stmt in code:
     619
     620        all_vars = [stmt.LHS, stmt.RHS.operand1, stmt.RHS.operand2]
     621
     622        if isinstance(stmt.RHS, bitexpr.Add):
     623                ints.add(stmt.RHS.carry)
     624        if isinstance(stmt.RHS, bitexpr.Sub):
     625                ints.add(stmt.RHS.brw)
     626
     627        for var in all_vars:
     628                (var_type, name, extra) = parse_var(var.varname)
     629
     630                if (var_type == "bitblock"):
     631                        bitblocks.add(name)
     632
     633                if var_type == "array":
     634                        if not name in arrays:
     635                                arrays[name]= extra
     636                        else:
     637                                arrays[name] = max(arrays[name], extra)
     638
     639                if var_type == "struct":
     640                        if not name in structs:
     641                                structs[name] = set([extra])
     642                        else:
     643                                structs[name].add(extra)
     644                if var_type == "int":
     645                    ints.add(name)
    300646        return {'int':ints, 'bitblock': bitblocks, 'array': arrays, 'struct': structs}
    301647
    302648def merge_var_dic(first, second):
    303649    res = {}
     650    #print first
     651    #print second
    304652    res['int'] = first['int'].union(second['int'])
    305653    res['bitblock'] = first['bitblock'].union(second['bitblock'])
     
    329677
    330678        for i in var_dic['bitblock']:
    331                 if i == bitstream_compiler.AllOne:
     679                if i == AllOne:
    332680                        s += "BitBlock %s = simd_const_1(1);\n"%i
    333                 elif i==bitstream_compiler.AllZero:
     681                elif i== AllZero:
    334682                        s+="BitBlock %s = simd_const_1(0);\n"%i
    335683                else:
     
    348696        return s
    349697
    350 def gen_declarations(bb):
     698def gen_var_dic(s):
     699    if len(s)==0:
     700        return {'int':set([]), 'bitblock': set([]), 'array': set([]), 'struct': set([])}
     701    if isinstance(s[0], bitexpr.BitAssign):
     702        vd = get_vars(s[0])
     703        more = gen_var_dic(s[1:])
     704        vd = merge_var_dic(vd, more)
     705        return vd
     706
     707    if isinstance(s[0], bitexpr.If):
     708        vd1 = gen_var_dic(s[0].true_branch)
     709        vd2 = gen_var_dic(s[0].false_branch)
     710        vd = merge_var_dic(vd1, vd2)
     711        vd3 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
     712        vd =  merge_var_dic(vd, vd3)
     713        more = gen_var_dic(s[1:])
     714        vd  = merge_var_dic(vd, more)
     715        return vd
     716
     717    if isinstance(s[0], bitexpr.WhileLoop):
     718        vd = gen_var_dic(s[0].stmts)
     719        vd1 = get_vars(bitexpr.BitAssign(s[0].control_expr.var, s[0].control_expr.var))
     720        vd =  merge_var_dic(vd, vd1)
     721        more = gen_var_dic(s[1:])
     722        vd  = merge_var_dic(vd, more)
     723        return vd
     724
     725    """for loc in s[1:]:
     726        more = get_vars(loc)
     727
    351728    vd = get_vars(bb[0].code)
    352729    for block in bb[1:]:
     
    355732    declarations = gen_output(vd)
    356733    return declarations
    357 
    358 #################################################################################################################
    359 ## Converting to condition-tree form --- Pass 6
    360 #################################################################################################################
    361 def construct_tree(all_code, exprs):
    362     if all_code == []:
    363         dummy = bitstream_compiler.BasicBlock()
    364         dummy.code = []
    365         return [None, dummy]
    366 
    367     if len(all_code) == 1+len(exprs):
    368         head = all_code.pop(0)
    369     elif len(all_code) == len(exprs):
    370         dummy = BasicBlock()
    371         dummy.code = []
    372         head = dummy
     734    """
     735
     736def gen_declarations(s):
     737    vd = gen_var_dic(s)
     738    return gen_output(vd)
     739#################################################################################################################
     740## This class replaces all occurences of a reduced variable to its value --- Pass 7
     741## *** This pass should change so that instead of recursing on the tree structure used before, it recurses on ***
     742## *** the new AST notation                                                                                   ***
     743#################################################################################################################
     744def replace_in_rhs(rhs, target, replace):
     745    if isinstance(rhs, bitexpr.Var):
     746        if rhs.varname == target:
     747            rhs = replace
     748            return replace
     749
     750    if isinstance(rhs.operand1, bitexpr.Var):
     751        if rhs.operand1.varname == target:
     752            rhs.operand1 = replace
    373753    else:
    374         assert (1==0)
    375 
    376     if exprs == []:
    377         return [None, head]
    378     first_cond = exprs.pop(0)
    379 
    380     first_def = (None, None) #Basic Block number and line number of first def
    381     last_use = (None, None) #Basic Block number and line number of last_use
    382 
    383     target = first_cond.target
    384 
    385     for bbn, bb in enumerate(all_code):
    386         fdef = bb.get_defs(target)[0]
    387         luse = bb.get_uses(target)[-1]
    388         if not fdef is None and first_def[0] is None:
    389             first_def = (bbn, fdef)
    390         if not luse is None:
    391             last_use = (bbn, luse)
    392 
    393     ################### This few lines were temporary put here to remove the removal of code duplication in the output code#######################
    394     dummy = bitstream_compiler.BasicBlock()
    395     dummy.code = []
    396     return [first_cond, head, construct_tree(copy.deepcopy(all_code), copy.copy(exprs)), construct_tree(all_code, copy.copy(exprs)), [None, dummy]]
    397     ################### This few lines were temporary put here to remove the removal of code duplication in the output code#######################
    398     if first_def[0] is None and last_use[0] is None:
    399         cut = (None, None)
    400         dummy = bitstream_compiler.BasicBlock()
    401         dummy.code = []
    402         return [first_cond, head, construct_tree(copy.deepcopy(all_code), copy.copy(exprs)), construct_tree(all_code, copy.copy(exprs)), [None, dummy]]
     754        if not (isinstance(rhs.operand1, bitexpr.FalseLiteral) or isinstance(rhs.operand1, bitexpr.TrueLiteral)):
     755            rhs.operand1 = replace_in_rhs(rhs.operand1, target, replace)
     756
     757    if isinstance(rhs.operand2, bitexpr.Var):
     758        if rhs.operand2.varname == target:
     759            rhs.operand2 = replace
    403760    else:
    404         if first_def[0] is None:
    405             cut = last_use
    406         else:
    407             cut = first_def
    408         return [first_cond, head, construct_tree(copy.deepcopy(all_code[:cut[0]+1]), copy.copy(exprs[:cut[0]])), construct_tree(all_code[:cut[0]+1], copy.copy(exprs[:cut[0]])), construct_tree(all_code[cut[0]+1:], exprs[cut[0]:])]
    409 
    410 #################################################################################################################
    411 ## This class replaces all occurences of a reduced variable to its value --- Pass 7
    412 #################################################################################################################
    413 class Reducer:
    414     def __init__(self):
    415         pass
    416 
    417     def replace_in_rhs(self, expr, target, replace):
    418         #print target,replace
    419         #print ",,,,",expr.LHS.varname
    420         if replace=='allzero':
    421             replace = bitexpr.FalseLiteral()
    422         if replace == 'allone':
    423             replace = bitexpr.TrueLiteral()
    424 
    425         if isinstance(expr.RHS, bitexpr.Var):
    426             if expr.RHS.varname == target:
    427                 expr.RHS = replace
    428             return
    429 
    430         if expr.RHS.operand1.varname == target:
    431             expr.RHS.operand1 = replace
    432 
    433         if expr.RHS.operand2.varname == target:
    434             expr.RHS.operand2 = replace
    435 
    436     def apply_single_opt(self, expr, tree):
    437 
    438         target = expr.target
    439         replace = expr.replace
    440         assert(replace=='allzero' or replace=='allone')
    441 
    442         for loc in tree[1].code:
    443             if loc.LHS.varname == target:
    444                 break
    445             self.replace_in_rhs(loc, target, replace)
    446         if tree[0] is None:
    447             return
    448         else:
    449             for branch in tree[2:]:
    450                 self.apply_single_opt(expr, branch)
    451 
    452         return
    453 
    454     def apply_all_opt(self, tree):
    455         """
    456           all_code is  list of basic blocks.
    457           exp is a list of optimize pragmas found in the code
    458         """
    459         if tree[0] is None:
    460             return tree
    461         else:
    462             self.apply_single_opt(tree[0], tree[2])
    463             tree = [tree[0], tree[1], self.apply_all_opt(tree[2]), self.apply_all_opt(tree[3]), self.apply_all_opt(tree[4])]
    464             return tree
    465 
    466         return [first_cond, head, self.apply_all_opt(optimized_code, copy.copy(exp)), self.apply_all_opt(all_code, exp)]
     761        if not (isinstance(rhs.operand2, bitexpr.FalseLiteral) or isinstance(rhs.operand2, bitexpr.TrueLiteral)):
     762            rhs.operand2 = replace_in_rhs(rhs.operand2, target, replace)
     763    return rhs
     764
     765def apply_single_opt(code, target, replace):
     766    if len(code) == 0:
     767        return []
     768
     769    if replace=='AllZero':
     770        replace = bitexpr.FalseLiteral()
     771    if replace == 'AllOne':
     772        replace = bitexpr.TrueLiteral()
     773
     774    if isinstance(code[0], bitexpr.BitAssign):
     775        code[0].RHS = replace_in_rhs(code[0].RHS, target, replace)
     776
     777    if isinstance(code[0], bitexpr.If):
     778        if code[0].control_expr.var.varname == target:
     779            code[0].control_expr.var = replace
     780        code[0].true_branch = apply_single_opt(code[0].true_branch, target, replace)
     781        code[0].false_branch = apply_single_opt(code[0].false_branch, target, replace)
     782
     783    if isinstance(code[0], bitexpr.WhileLoop):
     784        if code[0].control_expr.var.varname == target:
     785            code[0].control_expr.var = replace
     786        code[0].stmts = apply_single_opt(code[0].stmts, target, replace)
     787
     788    return [code[0]]+apply_single_opt(code[1:], target, replace)
     789
     790def apply_all_opt(s):
     791    if len(s) == 0:
     792        return []
     793    elif isinstance(s[0], bitexpr.If):
     794        target = s[0].control_expr.var.varname
     795        replace = s[0].control_expr.val
     796        apply_single_opt(s[0].true_branch, target, replace)
     797        apply_all_opt(s[0].true_branch)
     798        apply_all_opt(s[0].false_branch)
     799
     800    elif isinstance(s[0], bitexpr.WhileLoop):
     801        apply_all_opt(s[0].stmts)
     802
     803    return [s[0]]+apply_all_opt(s[1:])
     804 
    467805#################################################################################################################
    468806## Simplifying conditions-tree by applying various optimizations like: constant and copy propagation. --- Pass 8
    469807## method prune is supposed to remove unreachable branches but it is incomplete
    470808#################################################################################################################
     809
    471810def prune(fixed, tree):
    472811    """removes unreachable branches of the tree"""
     
    510849    for i in empty_list:
    511850        del tree[i]
    512 
    513 def simplify_tree(tree, fixed = {}):
     851def filter_fixed(fixed, stmts):
     852    for loc in stmts:
     853        if isinstance(loc, bitexpr.BitAssign) and loc.LHS.varname in fixed:
     854            del fixed[loc.LHS.varname]
     855        if isinstance(loc, bitexpr.If):
     856            fixed = filter_fixed(fixed, loc.true_branch)
     857            fixed = filter_fixed(fixed, loc.false_branch)
     858        if isinstance(loc, bitexpr.WhileLoop):
     859            fixed = filter_fixed(fixed, loc.stmts)
     860    return fixed
     861           
     862def simplify_tree(code, fixed = {}, prev = []):
     863    #print len(code)
     864    #print "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
     865    if len(code) == 0:
     866        return []
     867
    514868    assumptions = {}
    515     fixed.update(tree[1].simplify(fixed))
    516    
    517     if tree[0] is None:
    518         return
    519     else:
    520         #self.prune(fixed, tree)
    521         if tree[0].replace == 'allzero':
    522             assumptions[tree[0].target] = bitexpr.FalseLiteral()
    523         else:
    524             assumptions[tree[0].target] = bitexpr.TrueLiteral()
    525 
    526         assumptions = tree[1].calc_implications(assumptions)
    527         #print "~~~~~~", assumptions
    528         fixed_a = copy.copy(fixed)
    529         fixed_a.update(assumptions)
    530         #print "~~~~~~", fixed_a
    531 
    532         simplify_tree(tree[2], fixed_a)
    533         for branch in tree[2:]:
    534             simplify_tree(branch, copy.copy(fixed))
    535 
    536 #################################################################################################################
    537 ## Dead Code Elimination
    538 #################################################################################################################
    539 
     869
     870    if isinstance(code[0], bitexpr.BitAssign):
     871        this, next = recurse_forward(code)
     872        fixed.update(basic_block.simplify(this, fixed))
     873        #assumptions = basic_block.calc_implications(this, copy.deepcopy(fixed))
     874        #fixed.update(assumptions)
     875        return this+simplify_tree(next, fixed, this)
     876
     877    elif isinstance(code[0], bitexpr.If):
     878        fixed1 = copy.deepcopy(fixed)
     879        fixed2 = copy.deepcopy(fixed)
     880        if isinstance(code[0].control_expr, bitexpr.isAllOne):
     881            assumptions[code[0].control_expr.var.varname] = bitexpr.TrueLiteral()
     882        if isinstance(code[0].control_expr, bitexpr.isAllZero):
     883            assumptions[code[0].control_expr.var.varname] = bitexpr.FalseLiteral()
     884        assumptions = basic_block.calc_implications(copy.deepcopy(prev), assumptions)
     885        #print assumptions, len(prev), prev[0].LHS.varname
     886        fixed1.update(assumptions)
     887        code[0].true_branch = simplify_tree(code[0].true_branch, fixed1)
     888        code[0].false_branch = simplify_tree(code[0].false_branch, fixed2)
     889        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
     890
     891    elif isinstance(code[0], bitexpr.WhileLoop):
     892        fixed = filter_fixed(fixed, code[0].stmts)
     893        fixed1 = copy.deepcopy(fixed)
     894        code[0].stmts = simplify_tree(code[0].stmts, fixed1)
     895        return [code[0]]+simplify_tree(code[1:], fixed, [code[0]])
     896
     897#################################################################################################################
     898## Dead Code Elimination ---- Pass 9
     899## *** Changes required here is the same as Pass 7
     900#################################################################################################################
    540901def check_loc(loc, must_liv):
    541902    if loc.LHS.varname in must_liv:
    542903        if isinstance(loc.RHS, bitexpr.Not):
    543             return set([loc.RHS.operand.varname]), False
     904            return set([loc.RHS.operand1.varname]), False
    544905        elif isinstance(loc.RHS, bitexpr.Var):
    545906            return set([loc.RHS.varname]), False
     
    553914        return set([]), True
    554915
    555 
    556916def remove_copies(bb):
    557917    """removes all copy statements e.g. var1 = var2"""
    558     lhs = [x.LHS.varname for x in bb.code]
    559     rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb.code) if isinstance(loc.RHS, bitexpr.Var)]
     918    lhs = [x.LHS.varname for x in bb]
     919    rhs = [(line, loc.RHS.varname) for line, loc in enumerate(bb) if isinstance(loc.RHS, bitexpr.Var)]
    560920    for i in rhs:
    561921        if i[1] in lhs:
    562922            line = lhs.index(i[1])
    563             bb.code[i[0]].RHS = bb.code[line].RHS
    564 
     923            if i[0] > line:
     924                bb[i[0]].RHS = bb[line].RHS
     925    return bb
    565926
    566927def remove_dead(bb, must_live):
    567     remove_copies(bb)
     928    #eliminates dead code from a basic block
     929    bb = remove_copies(bb)
    568930    my_alives = set([])
    569931    dead = []
    570     for line, loc in reversed(list(enumerate(bb.code))):
     932
     933    for line, loc in reversed(list(enumerate(bb))):
     934        #print line, my_alives.union(must_live)
    571935        new_lives, removable = check_loc(loc, my_alives.union(must_live))
     936
    572937        if removable:
    573938            dead.append(line)
     
    576941
    577942    for i in dead:
    578         del bb.code[i]
    579 
    580     return my_alives
    581 
     943        del bb[i]
     944
     945    return my_alives, bb
    582946
    583947def eliminate_dead_code(tree, must_live):
    584     if not tree[0] is None:
    585         new_alives = set([tree[0].target])
    586         for branch in tree[2:]:
    587             new_alives = new_alives.union(eliminate_dead_code(branch, must_live))
    588         must_live = must_live.union(new_alives)
    589        
    590     #print len(tree[1].code), must_live
     948
     949    if len(tree) == 0:
     950        return [], []
     951
     952    last = len(tree) - 1
     953    new_alives = set([])
     954    bb = []
     955
     956    for loc in tree:
     957        if isinstance(loc, bitexpr.WhileLoop):
     958            if not loc.carry_expr is None:
     959                must_live.add(loc.carry_expr.var.varname)
     960
     961    if isinstance(tree[-1], bitexpr.BitAssign):
     962        first = 0
     963        for i in reversed(range(len(tree))):
     964            if not isinstance(tree[i], bitexpr.BitAssign):
     965                first = i+1
     966                break
     967        new_alives, bb = remove_dead(tree[first:], must_live)
     968        last = first
     969
     970    elif isinstance(tree[-1], bitexpr.If):
     971        new_alives, tree[-1].true_branch = eliminate_dead_code(tree[-1].true_branch, must_live)
     972
     973        new_alives, tree[-1].false_branch = eliminate_dead_code(tree[-1].false_branch, must_live)
     974        bb = [tree[-1]]
     975        new_alives.add(tree[-1].control_expr.var.varname)
     976 
     977    elif isinstance(tree[-1], bitexpr.WhileLoop):
     978        must_live.add(tree[-1].control_expr.var.varname)
     979        new_alives, tree[-1].stmts = eliminate_dead_code(tree[-1].stmts, must_live)
     980        bb = [tree[-1]]
     981 
     982    all_alives = new_alives.union(must_live)
     983    all_lives, new_tree = eliminate_dead_code(tree[:last], all_alives)
     984    tree = new_tree+bb
     985    return all_alives, tree
     986
     987
     988#################################################################################################################
     989## This pass processes the code in the while loop and adds extra code required for handling carry variables
     990#################################################################################################################
     991carry_suffix = "_i"
     992
     993def fix_the_loop(loop):
     994    carries = []
     995    for loc in loop.stmts:
     996        if isinstance(loc.RHS, bitexpr.Add):
     997            carries.append(loc.RHS.carry)
     998    for item in carries:
     999        newvar = bitexpr.Var(item+carry_suffix)
     1000        loop.stmts.append(bitexpr.BitAssign(newvar, bitexpr.Or(newvar, bitexpr.Var(item), "int")))
     1001    for item in carries:
     1002        loop.stmts.append(bitexpr.BitAssign( bitexpr.Var(item), bitexpr.FalseLiteral("int") ))
    5911003   
    592     return remove_dead(tree[1], must_live)
    593     #print len(tree[1].code)
    594     #print "~~~~~~~~~~~~~~~~~"
    595     #print his_alives
    596     #print
    597     #print
    598 
    599 
    600 #################################################################################################################
    601 ## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 9
    602 #################################################################################################################
    603 def generate_if_stmt(expr, indent):
    604     target = expr.target
    605     replace = expr.replace
    606     if replace == "allzero":
    607         return "\n%sif (!bitblock_has_bit(%s))  {\n"%(" "*indent, target)
    608     elif replace == "allone":
    609         return "\n%sif (!bitblock_has_bit(simd_not(%s))) {\n"%(" "*indent, target)
     1004    return loop, carries
     1005
     1006def process_while_loops(code):
     1007    all = []
     1008    update = {}
     1009    for index, loc in enumerate(code):
     1010        if isinstance(loc, bitexpr.WhileLoop):
     1011            update[index] = fix_the_loop(loc)
     1012            for i in update[index][1]:
     1013                all.append(i)
     1014                all.append(i+carry_suffix)
     1015           
     1016    keys = [k for k in update]
     1017    keys.sort(reverse=True)
     1018    for key in keys:
     1019        code[key] = update[key][0]
     1020        carry_variable = bitexpr.Var("CarryTemp"+str(key))
     1021        code[key].stmts.insert(0, bitexpr.BitAssign(carry_variable, bitexpr.FalseLiteral("int")))
     1022        for item in update[key][1][2:]:
     1023            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(carry_variable, bitexpr.Var(item), "int")))
     1024        if len(update[key][1]) == 1:
     1025            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Var(update[key][1][0])))
     1026        elif len(update[key][1]) > 1:
     1027            code.insert(key+1, bitexpr.BitAssign(carry_variable, bitexpr.Or(bitexpr.Var(update[key][1][0]), bitexpr.Var(update[key][1][1]), "int")))
     1028        for item in update[key][1]:
     1029            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item+carry_suffix), bitexpr.FalseLiteral("int")))
     1030        for item in update[key][1]:
     1031            code.insert(key+1, bitexpr.BitAssign(bitexpr.Var(item), bitexpr.Var(item+carry_suffix)))
     1032
     1033        code[key].carry_expr = bitexpr.isNoneZero(carry_variable)
     1034    return code, all
     1035
     1036
     1037
     1038
     1039
     1040
     1041#################################################################################################################
     1042## This pass factors out the code that is common between the true branch and false branch of an if
     1043## statement.
     1044#################################################################################################################
     1045
     1046def are_the_same(one, two):
     1047    if one.__class__ != two.__class__:
     1048        return False
     1049
     1050    if isinstance(one, bitexpr.BitAssign):
     1051        if (one.LHS.varname != two.LHS.varname):
     1052            return False
     1053        if one.RHS.__class__ != two.RHS.__class__:
     1054            return False
     1055        if isinstance(one.RHS, bitexpr.FalseLiteral):
     1056            return True
     1057        if isinstance(one.RHS, bitexpr.TrueLiteral):
     1058            return True
     1059        if isinstance(one.RHS, bitexpr.Var):
     1060            return (one.RHS.varname == two.RHS.varname)
     1061
     1062        return (one.RHS.operand1.varname == two.RHS.operand1.varname)and(one.RHS.operand2.varname == two.RHS.operand2.varname)
     1063
     1064    if isinstance(one, bitexpr.If):
     1065        if (one.control_expr.__class__ != two.control_expr.__class__):
     1066            return False
     1067        if (one.control_expr.var.varname != two.control_expr.var.varname):
     1068            return False
     1069        if (len(one.true_branch) != len(two.true_branch)) or (len(one.false_branch) != len(two.false_branch)):
     1070            return False
     1071        for ind in range(len(one.true_branch)):
     1072            if not are_the_same(one.true_branch[ind], two.true_branch[ind]):
     1073                return False
     1074        for ind in range(len(one.false_branch)):
     1075            if not are_the_same(one.false_branch[ind], two.false_branch[ind]):
     1076                return False
     1077        return True
     1078
     1079    if isinstance(one, bitexpr.WhileLoop):
     1080        if (one.control_expr.__class__ != two.control_expr.__class__):
     1081            return False
     1082        if (one.control_expr.var.varname != two.control_expr.var.varname):
     1083            return False
     1084        if (one.carry_expr.__class__ != two.carry_expr.__class__):
     1085            return False
     1086        if (one.carry_expr.var.varname != two.carry_expr.var.varname):
     1087            return False
     1088        if (len(one.stmts) != len(two.stmts)):
     1089            return False
     1090        for ind in range(len(one.stmts)):
     1091            if not are_the_same(one.stmts[ind], two.stmts[ind]):
     1092                return False
     1093        return True
     1094
     1095def get_factorable(cond):
     1096    earliest = 0
     1097    #print cond.control_expr, cond.control_expr.var.varname
     1098    l = min(len(cond.true_branch), len(cond.false_branch))
     1099    for index in reversed(range(-l,0)):
     1100        if (are_the_same(cond.true_branch[index], cond.false_branch[index])):
     1101            earliest = index
     1102        else:
     1103            break
     1104
     1105    common_length = -earliest
     1106    return common_length
     1107
     1108def get_common_code(cond, common):
     1109    true_len = len(cond.true_branch)-common
     1110    false_len = len(cond.false_branch)-common
     1111
     1112    new_cond=bitexpr.If(cond.control_expr, cond.true_branch[:true_len], cond.false_branch[:false_len])
     1113    common = cond.true_branch[true_len:]
     1114
     1115    return [new_cond], common
     1116
     1117def do_factorization(code, pos):
     1118    indices = [key for key in pos]
     1119    indices.sort()
     1120
     1121    for index in reversed(indices):
     1122        new_cond, common = get_common_code(code[index], pos[index])
     1123        code = code[:index]+new_cond+common+code[index+1:]
     1124    return code
     1125
     1126def factor_out(code):
     1127    for index, loc in enumerate(code):
     1128        if isinstance(loc, bitexpr.If):
     1129            code[index].true_branch = factor_out(code[index].true_branch)
     1130            code[index].false_branch = factor_out(code[index].false_branch)
     1131        if isinstance(loc, bitexpr.WhileLoop):
     1132            code[index].stmts = factor_out(code[index].stmts)
     1133
     1134    pos = {}
     1135    for index, loc in enumerate(code):
     1136        if isinstance(loc, bitexpr.If):
     1137            common_length = get_factorable(code[index])
     1138            pos[index] = common_length
     1139
     1140    return do_factorization(code, pos)
     1141
     1142#################################################################################################################
     1143## Generates C code given conditions-tree, by a recursive traversal of the tree --- Pass 10
     1144## ***Changes needed here are the same as Pass 7 and Pass 10***
     1145#################################################################################################################
     1146
     1147def generate_condition(expr):
     1148    if isinstance(expr, bitexpr.isAllZero):
     1149        return "!bitblock_has_bit(%s)"%(expr.var.varname)
     1150    elif isinstance(expr, bitexpr.isAllOne):
     1151        return "!bitblock_has_bit(simd_not(%s))"%(expr.var.varname)
     1152    elif isinstance(expr, bitexpr.isNoneZero):
     1153        return "bitblock_has_bit(%s)"%(expr.var.varname)
    6101154    else:
     1155        print expr
    6111156        assert (1==0)
    6121157
    613 
    614 def unwind(tree, head = [], indent = 0):
    615     #TODO: There might be a bug here, previous declarations are not passed to CodeGenObject
     1158def generate_statement(expr, indent, cond_stmt):
     1159    if isinstance(expr, bitexpr.isAllZero):
     1160        return "\n%s(%s)  {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
     1161    elif isinstance(expr, bitexpr.isAllOne):
     1162        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
     1163    elif isinstance(expr, bitexpr.isNoneZero):
     1164        return "\n%s(%s) {\n"%(" "*indent+cond_stmt+" ", generate_condition(expr))
     1165    else:
     1166        print expr
     1167        assert (1==0)
     1168
     1169def print_prog(s, indent = 0):
    6161170    indent_unit = 4
    617     s = ""
    618     total = []
    619     if tree[0] is None:
    620             return tree[1].showcode(indent)
    621     else:
    622             bnum = len(tree) - 1 #number of branches
    623             s += tree[1].showcode(indent)
    624             if bnum == 1:
    625                 return s
    626             s += generate_if_stmt(tree[0], indent)
    627             indent += indent_unit
    628             s+=unwind(tree[2], tree[1], indent)
    629             indent -= indent_unit
    630             s+=" "*indent+"}\n"
    631             if bnum == 2:
    632                 return s
    633             s+=" "*indent+"else {\n"
    634             indent += indent_unit
    635             s+=unwind(tree[3], tree[1], indent)
    636             indent -= indent_unit
    637             s+=" "*indent+"}\n"
    638             if bnum == 3:
    639                 return s
    640             s+=unwind(tree[4], tree[1], indent)
    641             return s
    642 
     1171    code = ""
     1172    if len(s) == 0:
     1173        return ""
     1174    if isinstance(s[0], bitexpr.If):
     1175        code = generate_statement(s[0].control_expr, indent, "if")
     1176        code += print_prog(s[0].true_branch, indent+indent_unit)
     1177        code += " "*indent+"}\n"
     1178        code += " "*indent+"else {\n"
     1179        code += print_prog(s[0].false_branch, indent+indent_unit)
     1180        code += " "*indent+"}\n"
     1181
     1182    if isinstance(s[0], bitexpr.WhileLoop):
     1183        #code = generate_statement(s[0].control_expr, indent, "while")
     1184        code = "\n%s(%s|%s) {\n"%(" "*indent+"while"+" ", generate_condition(s[0].control_expr), s[0].carry_expr.var.varname+">0")
     1185        code += print_prog(s[0].stmts, indent+indent_unit)
     1186        code += " "*indent+"}\n"
     1187
     1188    if isinstance(s[0], bitexpr.BitAssign):
     1189        code += " "*indent + s[0].LHS.varname
     1190        code += " = "
     1191
     1192        if s[0].RHS.data_type == "vector":
     1193            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
     1194                code += s[0].RHS.operand1.varname
     1195                code += ";\n"
     1196            elif isinstance(s[0].RHS, bitexpr.TrueLiteral):
     1197                code += s[0].RHS.operand1.varname
     1198                code += ";\n"
     1199            elif isinstance(s[0].RHS, bitexpr.Var):
     1200                code += s[0].RHS.operand1.varname
     1201                code += ";\n"
     1202            elif isinstance(s[0].RHS, bitexpr.Add):
     1203                code += s[0].RHS.op_C + "("
     1204                code += s[0].RHS.operand1.varname
     1205                code += ','
     1206                code += s[0].RHS.operand2.varname
     1207                code += ','
     1208                code += s[0].RHS.carry
     1209                code += ");\n"
     1210
     1211            else:
     1212                code += s[0].RHS.op_C + "("
     1213                code += s[0].RHS.operand1.varname
     1214                code += ','
     1215                code += s[0].RHS.operand2.varname
     1216                code += ");\n"
     1217        elif s[0].RHS.data_type == "int":
     1218            if isinstance(s[0].RHS, bitexpr.Or):
     1219                code += s[0].RHS.operand1.varname
     1220                code += '|'
     1221                code += s[0].RHS.operand2.varname
     1222                code += ";\n"
     1223            if isinstance(s[0].RHS, bitexpr.FalseLiteral):
     1224                code += "0;\n"
     1225
     1226    s.pop(0)
     1227    return code+print_prog(s, indent)
     1228
     1229#################################################################################################################
     1230## Auxiliary Functions
     1231#################################################################################################################
     1232
     1233def get_bb_of(code, index):
     1234    if index < 0 or index >= len(code):
     1235        return None
     1236   
     1237    if isinstance(code[index], bitexpr.If):
     1238        return code[index]
     1239    if isinstance(code[index], bitexpr.WhileLoop):
     1240        return code[index]
     1241    if isinstance(code[index], bitexpr.BitAssign):
     1242        last = index
     1243        for loc in code[index+1:]:
     1244            if isinstance(loc, bitexpr.BitAssign):
     1245                last += 1
     1246            else:
     1247                break
     1248 
     1249        first = index
     1250        for loc in reversed(code[:index]):
     1251            if isinstance(loc, bitexpr.BitAssign):
     1252                first -= 1
     1253            else:
     1254                break
     1255
     1256        return code[first:last+1]
     1257
     1258def get_previous_bb(code, index):
     1259    if index < 0:
     1260        return None
     1261    if index >= len(code):
     1262        return get_bb_of(code, len(code)-1)
     1263    if isinstance(code[index], bitexpr.If) or isinstance(code[index], bitexpr.WhileLoop):
     1264        return get_bb_of(code, index-1)
     1265    if isinstance(code[index], bitexpr.BitAssign):
     1266        for ind, loc in reversed(enumerate(code[:index])):
     1267            if not isinstance(loc, bitexpr.BitAssign):
     1268                return get_bb_of(code, ind)
     1269        return get_bb_of(code, -1)
     1270    assert(1==0)
     1271
     1272def recurse_forward(code):
     1273    if isinstance(code[0], bitexpr.If):
     1274        return code[0], code[1:]
     1275    if isinstance(code[0], bitexpr.WhileLoop):
     1276        return code[0], code[1:]
     1277    if isinstance(code[0], bitexpr.BitAssign):
     1278        last = 0
     1279        for loc in code[1:]:
     1280            if isinstance(loc, bitexpr.BitAssign):
     1281                last += 1
     1282            else:
     1283                break
     1284        return code[0:last+1], code[last+1:]
Note: See TracChangeset for help on using the changeset viewer.