Changeset 2636


Ignore:
Timestamp:
Nov 14, 2012, 9:06:32 PM (6 years ago)
Author:
cameron
Message:

Experimental version progress: carryInfoSet, more ccgo methods

File:
1 edited

Legend:

Unmodified
Added
Removed
  • proto/Compiler/pablo.py

    r2631 r2636  
    1616
    1717def isCarryGenerating(builtin_fn):
    18    return builtin_fn in ['Advance', 'ScanThru', 'ScanTo', 'AdvanceThenScanThru', 'AdvanceThenScanTo', 'SpanUpTo', 'InclusiveSpan', 'ExclusiveSpan', 'ScanToFirst']
     18   return builtin_fn in ['ScanThru', 'ScanTo', 'AdvanceThenScanThru', 'AdvanceThenScanTo', 'SpanUpTo', 'InclusiveSpan', 'ExclusiveSpan', 'ScanToFirst']
    1919def usesCarryInit1(builtin_fn):
    2020   return builtin_fn in ['ScanToFirst']
    21 def usesCarryCountArgument(builtin_fn):
     21def isAdvance(builtin_fn):
    2222   return builtin_fn in ['Advance']
    2323
    24 def GetBuiltinFn(fncall, builtin_fnmod_noprefix='pablo'):
     24
     25def CheckForBuiltin(fncall, builtin_fnmod_noprefix='pablo'):
    2526  if not isinstance(fncall, ast.Call): return None
    2627  if isinstance(fncall.func, ast.Name): fn_name = fncall.func.id
     
    2930    else: return None
    3031  else: return None
    31   if isCarryGenerating(fn_name): return fn_name
     32  if isCarryGenerating(fn_name) or isAdvance(fn_name): return fn_name
    3233  else: return None
    33 
    3434
    3535def CarryCountOfFn(fncall, builtin_fnmod_noprefix='pablo'):
     
    4040    else: return 0
    4141  else: return 0
    42   if usesCarryCountArgument(fn_name):
     42  if isAdvance(fn_name):
    4343    if len(fncall.args) == 1: return 1
    4444    else: return 0 #  return fncall.args[1].n  # Possibly count Advance(m, n) as generating n carries.
     
    7373  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
    7474  return ast.Expr(ast.Call(fn_name, args, [], None, None))
     75 
     76 
     77#
     78# Carry Info Set
     79#
     80class CarryInfoSetVisitor(ast.NodeVisitor):
     81  def __init__(self):
     82    self.operation_no = 0
     83    self.block_no = 0
     84    self.next_block_no = 1
     85   
     86    self.parent_block = {}
     87    self.block_first_op = {}
     88    self.block_op_count = {}
     89    self.advance_amount = {}
     90    self.init_one_list = []
     91   
     92  def visit_Call(self, callnode):
     93    self.generic_visit(callnode)
     94    builtin = CheckForBuiltin(callnode)
     95    if builtin == None: return
     96    if isCarryGenerating(builtin):
     97      if usesCarryInit1(builtin): self.init_one_list.append(self.operation_no)
     98      self.operation_no += 1
     99    elif isAdvance(builtin):
     100      if len(callnode.args) > 1:
     101        adv_amount = callnode.args[1].n
     102      else: adv_amount = 1
     103      self.advance_amount[self.operation_no] = adv_amount
     104      self.operation_no += 1
     105    else: return
     106
     107  def visit_If(self, ifNode):
     108    this_block_no = self.next_block_no
     109    self.next_block_no += 1
     110    self.parent_block[this_block_no] = self.block_no
     111    self.block_no = this_block_no
     112    self.block_first_op[this_block_no] = self.operation_no
     113    self.generic_visit(ifNode)
     114   
     115    self.block_op_count[this_block_no] = self.operation_no - self.block_first_op[this_block_no]
     116
     117    # reset
     118    self.block_no = self.parent_block[this_block_no]
     119 
     120  def visit_While(self, whileNode):
     121    this_block_no = self.next_block_no
     122    self.next_block_no += 1
     123    self.parent_block[this_block_no] = self.block_no
     124    self.block_no = this_block_no
     125    self.block_first_op[this_block_no] = self.operation_no
     126    self.generic_visit(whileNode)
     127    self.block_op_count[this_block_no] = self.operation_no - self.block_first_op[this_block_no]
     128    # reset
     129    self.block_no = self.parent_block[this_block_no]
     130   
     131  def getInfoSet(self, nodeToVisit):
     132    self.operation_no = 0
     133    self.block_no = 0
     134    self.next_block_no = 1
     135   
     136    self.parent_block = {}
     137    self.block_first_op = {}
     138    self.block_op_count = {}
     139    self.advance_amount = {}
     140    self.init_one_list = []
     141   
     142    self.generic_visit(nodeToVisit)
     143    self.block_first_op[0] = 0
     144    self.block_op_count[0] = self.operation_no
     145    return self
     146
     147  def countBlockCarrysWithAdv1(self, blk):
     148    op_count = self.block_op_count[blk]
     149    if op_count == 0: return 0
     150    carries = 0
     151    for op in range(self.block_first_op[blk], self.block_first_op[blk] + op_count):
     152      if op not in self.advance_amount.keys(): carries += 1
     153      elif self.advance_amount[op] == 1: carries += 1
     154    return carries
    75155
    76156#
     
    406486
    407487
    408 
    409 
    410488class testCCGO():
    411     def GenerateCarryInAccess(self, carry_group_var, carry_index):
    412         return mkCall(carry_group_var + "." + 'get_carry_in', [ast.Num(carry_index)])
    413     def GenerateCarryOutStore(self, carry_group_var, carry_index, carry_out_expr):
    414         return ast.Assign([ast.Subscript(ast.Attribute(ast.Name(carry_group_var, ast.Load()), 'cq', ast.Load()), ast.Index(ast.Num(carry_index)), ast.Store())],
     489    def __init__(self, carryInfoSet, carryGroupVar='carryQ'):
     490        self.carryInfoSet = carryInfoSet
     491        self.carryGroupVar = carryGroupVar
     492        self.carryIndex = {}
     493        self.operation_offset = 0
     494        carry_counter = 0
     495        for op_no in range(carryInfoSet.operation_no):
     496          self.carryIndex[op_no] = carry_counter
     497          if not op_no in carryInfoSet.advance_amount.keys(): carry_counter += 1
     498          elif carryInfoSet.advance_amount[op_no] == 1: carry_counter += 1
     499        # Add a dummy entry for any possible final block that is empty.
     500        self.carryIndex[carryInfoSet.operation_no] = carry_counter
     501         
     502    def GenerateCarryInAccess(self, operation_no):
     503        carry_index = self.carryIndex[operation_no - self.operation_offset]
     504        return mkCall(self.carryGroupVar + "." + 'get_carry_in', [ast.Num(carry_index)])
     505    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
     506        carry_index = self.carryIndex[operation_no - self.operation_offset]
     507        return ast.Assign([ast.Subscript(ast.Attribute(ast.Name(self.carryGroupVar, ast.Load()), 'cq', ast.Load()), ast.Index(ast.Num(carry_index)), ast.Store())],
    415508                          mkCall("bitblock::srli<127>", [carry_out_expr]))
     509    def GenerateCarryIfTest(self, block_no, ifTest):
     510        carry_count = self.carryInfoSet.block_op_count[block_no]
     511        if carry_count == 0: return ifTest
     512        ifIndex = self.carryIndex[self.carryInfoSet.block_first_op[block_no]]       
     513        return ast.BoolOp(ast.Or(), [ifTest, mkCall(ast.Attribute(ast.Name(self.carryGroupVar, ast.Load()), 'CarryTest', ast.Load()), [ast.Num(ifIndex), ast.Num(carry_count)])])
     514    def GenerateCarryWhileTest(self, whileIndex): pass
     515    def EnterLocalWhileBlock(self, operation_offset): 
     516        self.carryGroupVar = "sub" + self.carryGroupVar
     517        self.operation_offset = operation_offset
     518    def ExitLocalWhileBlock(self): 
     519        self.operation_offset = 0
     520        self.carryGroupVar = self.carryGroupVar[3:]
     521    def GenerateStreamFunctionFinalization(self):
     522        carry_count = self.carryInfoSet.countBlockCarrysWithAdv1(0)
     523        if carry_count == 0: return []
     524        else: return [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(carry_count)])]
     525
     526def Strategic_CCGO_Factory(carryInfoSet):
     527    ccgo = testCCGO(carryInfoSet, 'carryQ')
     528    return ccgo
    416529
    417530
    418531class CarryIntro(ast.NodeTransformer):
    419   def __init__(self, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
     532  def __init__(self, ccgo, carryvar="carryQ", carryin = "_ci", carryout = "_co"):
    420533    self.carryvar = ast.Name(carryvar, ast.Load())
    421534    self.carryin = carryin
    422535    self.carryout = carryout
    423     self.ccgo = testCCGO()
     536    self.ccgo = ccgo
    424537  def xfrm_fndef(self, fndef):
     538    self.block_no = 0
     539    self.operation_no = 0
    425540    self.current_carry = 0
    426541    self.current_adv_n = 0
     
    429544    self.generic_visit(fndef)
    430545  def xfrm_fndef_final(self, fndef):
     546    self.block_no = 0
     547    self.operation_no = 0
    431548    self.carryout = ""
    432549    self.current_carry = 0
     
    439556    return fndef
    440557  def generic_xfrm(self, node):
     558    self.block_no = 0
     559    self.operation_no = 0
    441560    self.current_carry = 0
    442561    self.current_adv_n = 0
     
    450569   
    451570  def local_while_xfrm(self, local_carryvar, whileNode):
    452     saved_state = (self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n)
    453     (self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n) = (local_carryvar, '', self.carryout, 0, 0)
     571    saved_state = (self.block_no, self.operation_no, self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n)
     572    (self.carryvar, self.carryin, self.current_carry, self.current_adv_n) = (local_carryvar, '', 0, 0)
     573    self.ccgo.EnterLocalWhileBlock(self.operation_no);
    454574    inner_while = self.generic_visit(whileNode)
    455     (self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n) = saved_state
     575    self.ccgo.ExitLocalWhileBlock();
     576    (self.block_no, self.operation_no, self.carryvar, self.carryin, self.carryout, self.current_carry, self.current_adv_n) = saved_state
    456577    return inner_while
    457578   
     
    587708  def visit_Assign(self, assigNode):
    588709    self.last_stmt_carries = CarryCounter().count(assigNode)
    589     f = GetBuiltinFn(assigNode.value)
    590     if f == None:
     710    f = CheckForBuiltin(assigNode.value)
     711    if f == None or not experimentalMode:
    591712            self.generic_visit(assigNode)
    592713            self.last_stmt = assigNode
    593714            return assigNode
    594     elif isCarryGenerating(f) and CarryCountOfFn(assigNode.value) == 1 and experimentalMode:
     715    elif isCarryGenerating(f) or (isAdvance(f) and ((len(assigNode.value.args) == 1) or (assigNode.value.args[1].n==1))):
    595716    # We have an assignment v = pablo.SomeCarryGeneratingFunction()
    596717    #elif f == 'ScanThru':
    597718            if self.carryin == "_ci":
    598                 carry_in_expr = self.ccgo.GenerateCarryInAccess(self.carryvar.id, self.current_carry)
     719                carry_in_expr = self.ccgo.GenerateCarryInAccess(self.operation_no)
    599720            else:
    600721                carry_in_expr = mkCall('simd<1>::constant<0>', [])
    601722            callnode = assigNode.value
    602             if f in ['ScanTo', 'AdvanceThenScanTo']:
     723            if isAdvance(f):
     724               pablo_routine_call = mkCall('pablo_blk_' + f, [assigNode.value.args[0], carry_in_expr, assigNode.targets[0]])
     725            elif f in ['ScanTo', 'AdvanceThenScanTo']:
    603726               if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
    604727               else: scanclass = mkCall('simd_not', [callnode.args[1]])
     
    607730               pablo_routine_call = mkCall('pablo_blk_' + f, assigNode.value.args + [carry_in_expr, assigNode.targets[0]])
    608731            self.last_stmt = pablo_routine_call
    609             compiled = self.ccgo.GenerateCarryOutStore(self.carryvar.id,  self.current_carry, pablo_routine_call)
     732            compiled = self.ccgo.GenerateCarryOutStore(self.operation_no, pablo_routine_call)
     733            self.operation_no += 1
    610734            self.current_carry += 1
    611735            return compiled
     
    613737            self.generic_visit(assigNode)
    614738            self.last_stmt = assigNode
     739            self.operation_no += 1
    615740            return assigNode
    616 
     741           
     742           
    617743  def visit_If(self, ifNode):
     744    self.block_no += 1
     745    new_test = self.ccgo.GenerateCarryIfTest(self.block_no, ifNode.test)
    618746    carry_base = self.current_carry
    619747    carries = CarryCounter().count(ifNode)
     
    625753    #CARRYSET
    626754    carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
    627     new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
     755    #new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
    628756    new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
    629757    newIf = ast.If(new_test, ifNode.body, new_else_part)
     
    659787  def visit_While(self, whileNode):
    660788    # Determine the original test expression, now encloded in bitblock::any()
     789    self.block_no += 1
    661790    original_test_expr = whileNode.test.args[0]
    662791    if self.carryout == '':
     
    10841213                AugAssignRemoval().xfrm(node)
    10851214
     1215                carry_info_set = CarryInfoSetVisitor().getInfoSet(node)
     1216                ccgo = Strategic_CCGO_Factory(carry_info_set)
    10861217               
    10871218                Bitwise_to_SIMD().xfrm(node)
     
    10901221                        carryQname = stream_function.instance_name + ".carryQ"
    10911222                else: carryQname = "carryQ"
    1092                 CarryIntroVisitor = CarryIntro(carryQname)
     1223                CarryIntroVisitor = CarryIntro(ccgo, carryQname)
    10931224                CarryIntroVisitor.xfrm_fndef(node)
    10941225                CarryIntroVisitor.xfrm_fndef_final(final_block_node)
     
    11021233                        Add_Assert_BitBlock_Align().xfrm(final_block_node)
    11031234
    1104                 if stream_function.carry_count > 0:
    1105                         node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
     1235                #if stream_function.carry_count > 0:
     1236                #       node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(stream_function.carry_count)])]
     1237                node.body += ccgo.GenerateStreamFunctionFinalization()
     1238
    11061239               
    11071240                stream_function.statements = Cgen.py2C(4).gen(node.body)
     
    11651298    Bitwise_to_SIMD().xfrm(self.main_node)
    11661299    Bitwise_to_SIMD().xfrm(self.main_node)
    1167     AssertCompiler().xfrm(self.main_node)
    11681300    final_block_main = copy.deepcopy(self.main_node)
    1169     CarryIntroVisitor = CarryIntro()
     1301    carry_info_set = CarryInfoSetVisitor().getInfoSet(self.main_node)
     1302    ccgo = Strategic_CCGO_Factory(carry_info_set)
     1303    CarryIntroVisitor = CarryIntro(ccgo)
    11701304    CarryIntroVisitor.xfrm_fndef(self.main_node)
    11711305    CarryIntroVisitor.xfrm_fndef_final(final_block_main)
     1306    AssertCompiler().xfrm(self.main_node)
    11721307    if self.add_dump_stmts:
    11731308        Add_SIMD_Register_Dump().xfrm(self.main_node)
     
    11821317    StreamFunctionCallXlator('final').xfrm(final_block_main, self.stream_function_node.keys(), self.use_C_syntax)
    11831318   
    1184     if self.main_carry_count > 0:
     1319    #if self.main_carry_count > 0:
    11851320                #self.main_node.body += [mkCallStmt('CarryQ_Adjust', [ast.Name('carryQ', ast.Load()), ast.Num(self.main_carry_count)])]
    1186                 self.main_node.body += [mkCallStmt('carryQ.CarryQ_Adjust', [ast.Num(self.main_carry_count)])]
    1187     
     1321    self.main_node.body += ccgo.GenerateStreamFunctionFinalization()
     1322   
    11881323   
    11891324       
Note: See TracChangeset for help on using the changeset viewer.