Changeset 880


Ignore:
Timestamp:
Feb 2, 2011, 11:35:01 PM (8 years ago)
Author:
ksherdy
Message:

Add separate compilation functionality.

Location:
proto/Compiler
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • proto/Compiler/compiler2.py

    r865 r880  
    4242        s = ast.parse(input_string)
    4343        # Analysis and Transformation
    44         xfrmr = pablo.MainLoopTransformer(s, "carryQ")
     44        xfrmr = pablo.MainLoopTransformer(s)
     45        xfrmr.gen_globals()
    4546        xfrmr.gen_declarations()
    4647        xfrmr.gen_initializations()
     
    4950        # BACK END
    5051        template_contents = self.read_template()
    51         template_contents = self.output(xfrmr.getCglobals(),
    52                             xfrmr.getCdecls(),
    53                             xfrmr.getCinits(),
    54                             xfrmr.getCstmts(),
     52               
     53        template_contents = self.output(xfrmr.Cglobals,
     54                            xfrmr.Cdecls,
     55                            xfrmr.Cinits,
     56                            xfrmr.Cstmts,
    5557                            xfrmr.any_carry_expr(),
    5658                            template_contents)
  • proto/Compiler/pablo.py

    r865 r880  
    88import ast, copy, sys
    99import Cgen
    10 
    11 
    12 # HELPER functions for AST recognition/construction
    1310
    1411def is_BuiltIn_Call(fncall, builtin_fnname, builtin_arg_cnt, builtin_fnmod_noprefix='bitutil'):
     
    9794    return numnode  # no recursive modifications of index expressions
    9895
    99 
    10096#
    10197#  Generating BitBlock declarations for Local Variables
    102 #
    103 #
    104 class LocalVars(ast.NodeVisitor):
     98#
     99class FunctionVars(ast.NodeVisitor):
     100  def __init__(self,node):
     101        self.params = []
     102        self.stores = []
     103        self.generic_visit(node)
    105104  def visit_Name(self, nm):
    106105    if isinstance(nm.ctx, ast.Param):
     
    108107    if isinstance(nm.ctx, ast.Store):
    109108      if nm.id not in self.stores: self.stores.append(nm.id)
    110   def get(self, node):
    111     self.params=[]
    112     self.stores=[]
    113     self.generic_visit(node)
     109  def getLocals(self):
    114110    return [v for v in self.stores if not v in self.params]
    115111
     
    118114def BitBlock_decls_from_vars(varlist):
    119115  global MAX_LINE_LENGTH
    120   decls = "  BitBlock"
    121   pending = ""
    122   linelgth = 10
    123   for v in varlist:
    124     if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
    125       decls += pending + " " + v
    126       linelgth += len(pending + v) + 1
    127     else:
    128       decls += ";\n  BitBlock " + v
    129       linelgth = 11 + len(v)
    130     pending = ","
    131   decls += ";\n"
     116  decls =  ""
     117  if not len(varlist) == 0:
     118          decls = "             BitBlock"
     119          pending = ""
     120          linelgth = 10
     121          for v in varlist:
     122            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
     123              decls += pending + " " + v
     124              linelgth += len(pending + v) + 1
     125            else:
     126              decls += ";\n             BitBlock " + v
     127              linelgth = 11 + len(v)
     128            pending = ","
     129          decls += ";"
    132130  return decls
    133131
    134132def BitBlock_decls_of_fn(fndef):
    135   return BitBlock_decls_from_vars(LocalVars().get(fndef))
     133  return BitBlock_decls_from_vars(FunctionVars(fndef).getLocals())
    136134
    137135def BitBlock_header_of_fn(fndef):
     
    177175    return node
    178176
    179 
    180177#
    181178# Carry Introduction Transformation
    182179#
    183 
    184180class CarryCounter(ast.NodeVisitor):
    185181  def visit_Call(self, callnode):
     
    197193    self.generic_visit(nodeToVisit)
    198194    return self.carry_count
    199 
    200195
    201196class CarryIntro(ast.NodeTransformer):
     
    321316    class_name = node.name[0].upper() + node.name[1:]
    322317    instance_name = node.name[0].lower() + node.name[1:]
    323     self.Ccode += "struct " + class_name
     318    self.Ccode += "  struct " + class_name
    324319    if self.asType:
    325320            self.Ccode += " {\n"
     
    349344    self.Ccode += ");\n"
    350345
    351 
    352 
    353346#
    354347# Adding Debugging Statements
     
    362355    dump_stmt = mkCallStmt('print_simd_register', [ast.Str(Cgen.py2C().gen(v)), v])
    363356    return [t, dump_stmt]
    364    
    365 #
    366 #  Translate a function
    367 #
    368 
    369 class FunctionXlator(ast.NodeVisitor):
    370   def xlat(self, node):
    371     self.Ccode=""
    372     self.generic_visit(node)
    373     return self.Ccode
    374   def visit_FunctionDef(self, fndef):
    375     AugAssignRemoval().xfrm(fndef)
    376     Bitwise_to_SIMD().xfrm(fndef)
    377     self.Ccode += BitBlock_header_of_fn(fndef) + " {\n"
    378     self.Ccode += BitBlock_decls_of_fn(fndef)
    379     CarryIntro().xfrm_fndef(fndef)
    380     self.Ccode += Cgen.py2C().gen(fndef.body)
    381     self.Ccode += "\n}\n"
    382 
     357
     358class StreamFunctionCarryCounter(ast.NodeVisitor):
     359  def __init__(self):
     360        self.carry_count = {}
     361       
     362  def count(self, node):
     363        self.generic_visit(node)
     364        return self.carry_count
     365                                   
     366  def visit_FunctionDef(self, node):   
     367        type_name = node.name[0].upper() + node.name[1:]                       
     368        self.carry_count[type_name] = CarryCounter().count(node)
     369     
     370class StreamFunctionCallXlator(ast.NodeTransformer):
     371  def __init__(self):
     372        self.stream_function_type_names = []
     373       
     374  def xfrm(self, node, stream_function_type_names):
     375        self.stream_function_type_names = stream_function_type_names
     376        self.generic_visit(node)
     377       
     378  def visit_Call(self, node):   
     379        self.generic_visit(node)
     380
     381        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in self.stream_function_type_names:
     382                        node.func.id = node.func.id[0].lower() + node.func.id[1:] + ".do_block"
     383        return node     
     384               
     385class StreamFunctionVisitor(ast.NodeVisitor):
     386        def __init__(self,node):
     387                self.stream_function_node = {}
     388                self.generic_visit(node)
     389                                                             
     390        def visit_FunctionDef(self, node):                     
     391                key = node.name[0].upper() + node.name[1:]
     392                self.stream_function_node[key] = node
     393               
     394class StreamFunction():
     395        def __init__(self):
     396                self.carry_count = 0
     397                self.type_name = ""
     398                self.instance_name = ""
     399                self.parameters = ""
     400                self.declarations = ""
     401                self.initializations = ""
     402#
     403# TODO Consolidate *all* C code generation into the Emitter class.
     404# TODO Implement 'pretty print' indentation.
     405# TODO Migrate Emiter() class to Emitter module.               
     406class Emitter():
     407
     408        def definition(self, stream_function, icount=0):
     409               
     410                constructor = ""
     411                carry_declaration = ""
     412               
     413                if stream_function.carry_count > 0:
     414                        constructor = self.constructor(stream_function.type_name, stream_function.carry_count)
     415                        carry_declaration = self.carry_declare(stream_function.carry_count)
     416
     417                do_block_function = self.do_block(self.do_block_parameters(stream_function.parameters),
     418                                                stream_function.declarations,
     419                                                stream_function.initializations,
     420                                                stream_function.statements)             
     421                return self.indent(icount) + "struct " + stream_function.type_name + " {" \
     422                + "\n" + self.indent(icount) + constructor \
     423                + "\n" + self.indent(icount) + do_block_function \
     424                + "\n" + self.indent(icount) + carry_declaration \
     425                + "\n" + self.indent(icount) + "};\n\n"
     426
     427        def constructor(self, type_name, carry_count, icount=0):
     428                return self.indent(icount) + "%s() { """ % (type_name) + self.carry_init(carry_count) + " }"
     429                       
     430        def do_block(self, parameters, declarations, initializations, statements, icount=0):
     431                return self.indent(icount) + "void do_block(" + parameters + ") {" \
     432            + "\n" + self.indent(icount) + declarations \
     433                + "\n" + self.indent(icount) + initializations \
     434                + "\n" + self.indent(icount) + statements \
     435                + "\n" + self.indent(icount + 2) + "}"
     436
     437        def declaration(self, type_name, instance_name, icount=0):     
     438                return self.indent(icount) + type_name + " " + instance_name + ";\n"
     439               
     440        def carry_init(self, carry_count, icount=0):   
     441                return self.indent(icount) + "CarryInit(carryQ, %i);" % (carry_count)
     442       
     443        def carry_declare(self, carry_count, icount=0):
     444                return self.indent(icount) + "CarryDeclare(carryQ, %i);" % (carry_count)
     445
     446        def carry_test(self, carry_variable, carry_count, icount=0):
     447                return self.indent(icount) + "CarryTest(%s, 0, %i)" % (carry_variable, carry_count)             
     448               
     449        def indent(self, icount):
     450                s = ""
     451                for i in range(0,icount): s += " "
     452                return s       
     453               
     454        def do_block_parameters(self, parameters):
     455               
     456                do_block_parameters = []
     457               
     458                for name in parameters:
     459                        type_name = name[0].upper() + name[1:]
     460                        argument_name = name[0].lower() + name[1:]
     461                        do_block_parameters.append(type_name)
     462                        do_block_parameters.append(" & ")
     463                        do_block_parameters.append(argument_name)
     464                        do_block_parameters.append(", ")
     465                if len(do_block_parameters) > 0:               
     466                        do_block_parameters.pop()
     467                                       
     468                return "".join(do_block_parameters)             
     469               
    383470def main(infilename, outfile = sys.stdout):
    384471  t = ast.parse(file(infilename).read())
     
    387474
    388475#
    389 #
    390476#  Routines for compatibility with the old compiler/template.
    391477#  Quick and dirty hacks for now - Dec. 2010.
     
    393479
    394480class MainLoopTransformer:
    395   def __init__(self, main_module, carry_var = "carryQ"):
     481  def __init__(self, main_module, main_node_id='Main'):
    396482    self.main_module = main_module
    397     self.main_fn = main_module.body[-1]
    398     assert (isinstance(self.main_fn, ast.FunctionDef))
    399     self.carry_count = CarryCounter().count(self.main_fn)
    400     self.carry_var = carry_var
     483    self.main_node_id = main_node_id
     484   
     485        # Gather and partition function definition nodes.
     486    stream_function_visitor = StreamFunctionVisitor(self.main_module)
     487    self.stream_function_node = stream_function_visitor.stream_function_node
     488    self.main_node = self.stream_function_node[main_node_id]
     489    self.main_carry_count = CarryCounter().count(self.main_node)
     490    del self.stream_function_node[self.main_node_id]
     491   
     492    self.stream_functions = {}
     493    for key, node in self.stream_function_node.iteritems():
     494                stream_function = StreamFunction()
     495                stream_function.carry_count = CarryCounter().count(node)
     496                stream_function.type_name = node.name[0].upper() + node.name[1:]
     497                stream_function.instance_name = node.name[0].lower() + node.name[1:]
     498                stream_function.parameters = FunctionVars(node).params
     499                stream_function.declarations = BitBlock_decls_of_fn(node)
     500                stream_function.initializations = StreamInitializations().xfrm(node)
     501               
     502                AugAssignRemoval().xfrm(node)
     503                Bitwise_to_SIMD().xfrm(node)
     504                CarryIntro().xfrm_fndef(node)
     505               
     506                stream_function.statements = Cgen.py2C(4).gen(node.body)
     507                self.stream_functions[stream_function.type_name] = stream_function
     508       
     509    self.emitter = Emitter()
     510   
    401511  def any_carry_expr(self):
    402     if self.carry_count == 0: return "1"
    403     else: return "CarryTest(%s, 0, %i)\n" % (self.carry_var, self.carry_count)
    404   def gen_declarations(self):
     512       
     513        carry_test = []
     514       
     515        if self.main_carry_count > 0:
     516                        carry_test.append(self.emitter.carry_test('carryQ', self.main_carry_count))
     517                        carry_test.append(" || ")
     518
     519        for key in self.stream_functions.keys():               
     520                if self.stream_functions[key].carry_count > 0:
     521                        carry_test.append(self.emitter.carry_test(self.stream_functions[key].instance_name + ".carryQ", self.stream_functions[key].carry_count))
     522                        carry_test.append(" || ")
     523
     524        if len(carry_test) > 0:
     525                carry_test.pop()
     526        return "".join(carry_test)
     527   
     528        return "1"
     529
     530  def gen_globals(self):
    405531    self.Cglobals = StreamStructGen().gen_struct_types(self.main_module)
     532    for key in self.stream_functions.keys():
     533                self.Cglobals += Emitter().definition(self.stream_functions[key],2)       
     534                       
     535  def gen_declarations(self): 
    406536    self.Cdecls = StreamStructGen().gen_struct_vars(self.main_module)
    407     self.Cdecls += BitBlock_decls_of_fn(self.main_fn)
    408     if self.carry_count > 0: self.Cdecls += "CarryDeclare(%s, %i);\n" % (self.carry_var, self.carry_count)
     537    self.Cdecls += BitBlock_decls_of_fn(self.main_node)
     538    if self.main_carry_count > 0:
     539        self.Cdecls += self.emitter.carry_declare(self.main_carry_count)
     540               
    409541  def gen_initializations(self):
    410542    self.Cinits = ""
    411     if self.carry_count > 0: self.Cinits += "CarryInit(%s, %i);\n" % (self.carry_var, self.carry_count)
     543    if self.main_carry_count > 0:
     544        self.Cinits += self.emitter.carry_init(self.main_carry_count)
    412545    self.Cinits += StreamInitializations().xfrm(self.main_module)
     546   
     547    for key in self.stream_functions.keys():
     548                self.Cinits += self.emitter.declaration(self.stream_functions[key].type_name, self.stream_functions[key].instance_name, 2)
     549                       
    413550  def xfrm_block_stmts(self, add_dump_stmts=False):
    414     AugAssignRemoval().xfrm(self.main_fn)
    415     Bitwise_to_SIMD().xfrm(self.main_fn)
    416     CarryIntro().xfrm_fndef(self.main_fn)
    417     if add_dump_stmts: Add_SIMD_Register_Dump().xfrm(self.main_fn)
     551    AugAssignRemoval().xfrm(self.main_node)
     552    Bitwise_to_SIMD().xfrm(self.main_node)
     553    CarryIntro().xfrm_fndef(self.main_node)
     554    if add_dump_stmts:
     555        Add_SIMD_Register_Dump().xfrm(self.main_node)
     556               
     557    StreamFunctionCallXlator().xfrm(self.main_node, self.stream_function_node.keys())   
     558    self.Cstmts = Cgen.py2C().gen(self.main_node.body)
     559   
    418560  def add_loop_carryQ_adjust(self):
    419     self.main_fn.body += [mkCallStmt('CarryQ_Adjust', [ast.Name(self.carry_var, ast.Load()), ast.Num(self.carry_count)])]
    420   def getCglobals(self):
    421     return self.Cglobals
    422   def getCdecls(self):
    423     return self.Cdecls
    424   def getCinits(self):
    425     return self.Cinits
    426   def getCstmts(self):
    427     return Cgen.py2C().gen(self.main_fn.body)
    428 
     561    self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
     562   
    429563if __name__ == "__main__":
    430564                import doctest
    431565                doctest.testmod()
    432 
    433 
Note: See TracChangeset for help on using the changeset viewer.