source: proto/Compiler/CCGO_HMCPS.py @ 2703

Last change on this file since 2703 was 2703, checked in by cameron, 6 years ago

Various fixes

File size: 20.3 KB
Line 
1#
2# CCGO_HMCPS.py
3#
4# Carry Code Generator Object using Hierarchical Merging Carry Pack Strategy
5#
6# Robert D. Cameron
7# November 26, 2012
8# Licensed under Open Software License 3.0
9#
10import ast
11import CCGO
12
13# Copyright 2012, Robert D. Cameron
14# All rights reserved.
15#
16# Helper functions
17#
def TestHelper_Bitblock_Or(testExpr, bitBlockExpr):
  """OR bitBlockExpr into the argument of a bitblock::any(...) test call.

  testExpr must be a call node of the form bitblock::any(e); its single
  argument is rewritten in place to simd_or(bitBlockExpr, e) and the
  (mutated) call node is returned.
  """
  call_node = testExpr
  assert isinstance(call_node, ast.Call)
  func = call_node.func
  assert isinstance(func, ast.Name)
  assert func.id == 'bitblock::any'
  original_arg = call_node.args[0]
  call_node.args[0] = make_call('simd_or', [bitBlockExpr, original_arg])
  return call_node
24
def TestHelper_Integer_Or(testExpr, intExpr):
  """Combine an integer test expression with intExpr using a bitwise-or node."""
  return ast.BinOp(left=testExpr, op=ast.BitOr(), right=intExpr)
27
def mk_var(var, mode=None):
  """Coerce a name string to an ast.Name node; pass AST nodes through.

  var:  a variable name (str) or an already-built AST expression node.
  mode: expression context for a newly built Name; defaults to ast.Load().
  Returns an AST node usable as an expression.
  """
  # Build the default context per call rather than sharing one ast.Load()
  # instance across all callers (avoids a mutable default argument).
  if mode is None:
    mode = ast.Load()
  if isinstance(var, str):
    var = ast.Name(var, mode)
  return var
32 
def make_mergeh(fw, x, y):
  """Build a call node for esimd<fw>::mergeh(x, y); string args become Name loads."""
  merge_fn = "esimd<%i>::mergeh" % fw
  return make_call(merge_fn, [mk_var(x), mk_var(y)])
36
def make_assign(var, expr):
  """Build the assignment statement var = expr; a string var becomes a Name store."""
  target = ast.Name(var, ast.Store()) if isinstance(var, str) else var
  return ast.Assign([target], expr)
42
def make_zero(fw):
  """Build a zero-argument call node for simd<fw>::constant<0>()."""
  zero_fn = "simd<%i>::constant<0>" % fw
  return make_call(zero_fn, [])
46
def make_index_load(var, num):
  """Build the subscript expression var[num] in Load context."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Subscript(base, ast.Index(ast.Num(num)), ast.Load())
51
def make_index_store(var, num):
  """Build the subscript expression var[num] in Store context.

  Note: the base name itself is loaded; only the subscript is a store target.
  """
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Subscript(base, ast.Index(ast.Num(num)), ast.Store())
56
def make_att_load(var, att):
  """Build the attribute expression var.att in Load context."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, ast.Load())
61
def make_att_store(var, att):
  """Build the attribute expression var.att in Store context.

  Note: the base name itself is loaded; only the attribute is a store target.
  """
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, ast.Store())
66
def make_call(fn_name, args):
  """Build a call expression fn_name(*args) with no keyword arguments.

  Uses the Python 2 ast.Call positional layout:
  (func, args, keywords, starargs, kwargs).
  """
  func = ast.Name(fn_name, ast.Load()) if isinstance(fn_name, str) else fn_name
  return ast.Call(func, args, [], None, None)
71
def make_callStmt(fn_name, args):
  """Build an expression statement wrapping the call fn_name(*args).

  Uses the Python 2 ast.Call positional layout:
  (func, args, keywords, starargs, kwargs).
  """
  func = ast.Name(fn_name, ast.Load()) if isinstance(fn_name, str) else fn_name
  return ast.Expr(ast.Call(func, args, [], None, None))
75 
76#
77#
78# Carry Pack Assignment Strategy
79#
80# The hierarchical merging carry pack strategy packs carries
81# into groups of 2, 4, 8 and 16.   For example, to pack
82# 4 carries c0, c1, c2, and c3 into the 32-bit fields of
83# a 128-bit register, the following operations are used.
84#
85# c0 = pablo.SomeCarryGeneratingFn(...)
86# c1 = pablo.SomeCarryGeneratingFn(...)
87# c1_0 = esimd::mergeh<32>(c1, c0)
88# c2 = pablo.SomeCarryGeneratingFn(...)
89# c3 = pablo.SomeCarryGeneratingFn(...)
90# c3_2 = esimd::mergeh<32>(c3, c2)
91# c3_0 = esimd::mergeh<64>(c3_2, c1_0)
92#
93#
94# Packing operations are generated sequentially when
95# the appropriate individual carries or subpacks become
96# available.   
97#
98# Generate the packing operations assuming that the
99# carry_num carry has just been generated.
100#
def gen_carry_pack(pack_fw, carry_num, temp_pfx):
  """Emit the mergeh packing assignments enabled once carry carry_num exists.

  The number of merge steps is determined by the rightmost run of 1 bits in
  carry_num: each step doubles the pack span and the merge field width,
  combining the just-completed high pack with its lower sibling pack.
  Returns a list of ast assignment statements (possibly empty).
  """
  # Power-of-2 span covered by the rightmost contiguous 1 bits of carry_num.
  run_len = (carry_num + 1) & ~carry_num
  stmts = []
  hi = temp_pfx + repr(carry_num)
  lo = temp_pfx + repr(carry_num ^ 1)
  width = pack_fw
  span = 2
  while span <= run_len:
    merged = '%s%i_%i' % (temp_pfx, carry_num, carry_num - span + 1)
    stmts.append(make_assign(merged, make_mergeh(width, hi, lo)))
    hi = merged
    lo = '%s%i_%i' % (temp_pfx, carry_num - span, carry_num - 2*span + 1)
    span *= 2
    width *= 2
  return stmts
118
119#
120# Pack in a zero carry value
121#
def gen_carry_zero_then_pack(pack_fw, carry_num, temp_pfx):
  """Like gen_carry_pack, but first assigns a zero value to carry carry_num.

  The initial assignment sets the carry_num temp to simd<pack_fw> zero; the
  subsequent merge steps are identical to gen_carry_pack's, driven by the
  rightmost run of 1 bits in carry_num.  Returns a list of ast statements.
  """
  run_len = (carry_num + 1) & ~carry_num
  hi = temp_pfx + repr(carry_num)
  lo = temp_pfx + repr(carry_num ^ 1)
  width = pack_fw
  stmts = [make_assign(hi, make_zero(width))]
  span = 2
  while span <= run_len:
    merged = '%s%i_%i' % (temp_pfx, carry_num, carry_num - span + 1)
    stmts.append(make_assign(merged, make_mergeh(width, hi, lo)))
    hi = merged
    lo = '%s%i_%i' % (temp_pfx, carry_num - span, carry_num - 2*span + 1)
    span *= 2
    width *= 2
  return stmts
139
140
141
142#
143# Carry Storage/Access
144#
145# Carries are stored in one or more ubitblocks as byte values.
146# For each block, the carry count is rounded up to the nearest power of 2 ceiling P,
147# so that the carry test for that block is accessible as a single value of P bytes.
148# Packs of 1, 2, 4 or 8 carries are respectively represented
149# as one or more _8, _16, _32 or _64 values.  (Members of ubitblock union.)
150#
151#
152# Allocation phase determines the ubitblock_no and count for each block.
153
154#  carry-in access is a byte load  carryG[packno]._8[offset]
155#  carryout store is to a local pack var until we get to the final byte of a pack
156#
157#  if-test: let P be pack_size in {1,2,4,8,...}
158#    if P <= 8, use an integer test expression cG[packno]._%i % (P * 8)[block_offset]
159#     
160#  while test similar
161#    local while decl: use a copy of carryGroup
162#    while finalize  carry combine:   round up and |= into structure
163#
def pow2ceil(n):
  """Return the smallest power of 2 that is >= n (1 for n <= 1)."""
  p = 1
  while p < n:
    p += p
  return p
168
def pow2floor(n):
  """Return the largest power of 2 that is <= n (0 for n < 1).

  Doubles a counter until it passes n, then halves it.  Floor division
  (//) keeps the result an int under both Python 2 and Python 3; the
  previous true-division `/2` would yield a float on Python 3.
  """
  c = 1
  while c <= n: c *= 2
  return c // 2
173
def low_bit(n):
  """Return the lowest set bit of n (0 when n == 0)."""
  # n & -n isolates the least-significant 1 bit in two's complement,
  # which equals the original n - (n & (n-1)) formulation.
  return n & -n
176   
def align(n, align_base):
  """Round n up to the nearest multiple of align_base (align_base > 0).

  Floor division (//) keeps the arithmetic integral under both Python 2
  and Python 3; the previous true-division `/` would produce a float
  (and a wrong, unrounded result) on Python 3.
  """
  return ((n + align_base - 1) // align_base) * align_base
179
def determine_aligned_block_sizes(pack_size, cis):
  """Compute the aligned carry-region size for every block described by cis.

  Blocks are processed from last to first so that every nested block's final
  (aligned) size is known before its parent is sized.  A block's raw carry
  count is one carry per directly-contained non-advance op and per
  advance-by-1 op; each nested block contributes its whole aligned region,
  preceded by alignment padding (capped at pack_size).  The block's final
  size is rounded up to a multiple of pack_size when it exceeds pack_size,
  otherwise to the next power of 2.

  Returns a dict mapping block number -> aligned carry count.
  """
  aligned_size = {}
  for i in range(cis.block_count): aligned_size[i] = 0
  for i in range(cis.block_count):
    # Work backwards to process all child blocks before the parent
    # so that the parent incorporates the updated child counts.
    b = cis.block_count - i - 1
    b_carries = 0
    op = cis.block_first_op[b]
    while op < cis.block_first_op[b] + cis.block_op_count[b]:
      sb = cis.containing_block[op]
      if sb == b:
        # Direct op: non-advance ops and advance-by-1 ops each use one carry.
        if op not in cis.advance_amount: b_carries += 1
        elif cis.advance_amount[op] == 1: b_carries += 1
        op += 1
      else:
        # Nested block: pad to its alignment boundary, then add its region
        # and skip past all of its operations.
        align_base = aligned_size[sb]
        if align_base > pack_size: align_base = pack_size
        b_carries = align(b_carries, align_base)
        b_carries += aligned_size[sb]
        op += cis.block_op_count[sb]
    if b_carries > pack_size:
      aligned_size[b] = align(b_carries, pack_size)
    else:
      aligned_size[b] = pow2ceil(b_carries)
  return aligned_size
208 
MAX_LINE_LENGTH = 80  # column budget for generated BitBlock declaration lines

def BitBlock_decls_from_vars(varlist):
  """Render a C++ declaration string "BitBlock v1, v2, ...;" for varlist.

  Variable names are packed onto lines limited by MAX_LINE_LENGTH; when a
  name would overflow, a new indented "BitBlock" declaration line is started.
  Returns the empty string for an empty varlist.

  (The previous version declared `global MAX_LINE_LENGTH`, which is
  unnecessary for a read-only module constant, and tested emptiness with
  `not len(varlist) == 0`.)
  """
  if not varlist:
    return ""
  decls = "             BitBlock"
  pending = ""
  linelgth = 10
  for v in varlist:
    if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
      decls += pending + " " + v
      linelgth += len(pending + v) + 1
    else:
      decls += ";\n             BitBlock " + v
      linelgth = 11 + len(v)
    pending = ","
  decls += ";"
  return decls
228 
def block_contains(b0, b1, parent_block_map):
  """Return True iff block b0 is b1 itself or an ancestor of b1.

  Walks b1 up the parent_block_map chain; block 0 is the root, so the walk
  terminates there with False when b0 was never encountered.
  """
  while True:
    if b0 == b1:
      return True
    if b1 == 0:
      return False
    b1 = parent_block_map[b1]
233 
class HMCPS_CCGO(CCGO.CCGO):
    """Carry code generator using the hierarchical merging carry pack strategy.

    Carries are packed into fw-bit fields of 128-bit ubitblock registers via
    successive esimd::mergeh operations (see the module header comment).
    Carry positions are allocated up front by allocate_ops(); the Generate*
    methods then emit C++ code (as strings or ast nodes) for carry access,
    storage, and the if/while carry tests.
    """
    def __init__(self, fw, carryInfoSet, carryGroupVarName='carryG', temp_prefix='__c'):
        # fw: carry field width in bits; field_count: carries per 128-bit register.
        self.fw = fw
        self.field_count = 128/fw
        self.carryInfoSet = carryInfoSet
        self.carryGroupVar = carryGroupVarName
        self.temp_prefix = temp_prefix
        self.aligned_size = determine_aligned_block_sizes(self.field_count, carryInfoSet)
        # Number of ubitblocks needed for the outermost (block 0) carry region.
        self.ubitblock_count = (self.aligned_size[0] + self.field_count - 1) / self.field_count
        self.alloc_map = {}
        self.alloc_map[0] = 0
        self.block_base = {}
        self.allocate_ops()
        # carry_offset is used within the inner body of while loops to access local carries.
        # The calculated (ub, rp) value is reduced by this amount for the local carry group(s).
        self.carry_offset = 0

    def cg_temp(self, hi_carry, lo_carry = None):
      """Name of the temp holding carries hi_carry..lo_carry (e.g. __c3, __c3_0)."""
      if lo_carry == None or hi_carry == lo_carry: return "%s%i" % (self.temp_prefix, hi_carry)
      else: return "%s%i_%i" % (self.temp_prefix, hi_carry, lo_carry)
   
    def local_temp(self, hi_carry, lo_carry = None):
      """'sub'-prefixed name for a while-loop-local carry pack temp.

      NOTE(review): the pair format 'sub%s_%i_%i' inserts an underscore that
      cg_temp's '%s%i_%i' does not, so these names may not match the names
      cg_temp produces after EnterLocalWhileBlock prepends 'sub' to
      temp_prefix -- verify against generated output.
      """
      if lo_carry == None or hi_carry == lo_carry: return "sub%s%i" % (self.temp_prefix, hi_carry)
      else: return "sub%s_%i_%i" % (self.temp_prefix, hi_carry, lo_carry)
   
    def gen_merges(self, carry_last, carry_base):
      """Recursively emit mergeh assignments combining the pack ending at
      carry_last with its lower sibling packs, continuing while the current
      pack-size bit is set in carry_last (the rightmost-ones packing rule)."""
      size = carry_last - carry_base + 1
      if carry_last & size: 
        v1 = mk_var(self.cg_temp(carry_last, carry_base))
        v0 = mk_var(self.cg_temp(carry_last - size, carry_base - size))
        v2 = mk_var(self.cg_temp(carry_last, carry_base - size), ast.Store())
        return [make_assign(v2, make_mergeh(self.fw * size, v1, v0))] + self.gen_merges(carry_last, carry_base - size)
      else: return []

    #
    #  Given that carry_num carries have been generated and packed,
    #  add zero_count additional carry zero values and pack.
    #  Use shifts to introduce multiple zeroes, where possible.
    #
    def gen_multiple_carry_zero_then_pack(self, carry_num, zero_count):
      """Emit statements introducing zero_count zero carries after carry_num
      existing carries, plus the packing merges they enable; where the
      power-of-2 pack constraints allow, a single mvmd slli shift extends a
      pending partial pack with several zero fields at once."""
      if zero_count == 0: return []
      pending_carry_pack_size = low_bit(carry_num)
      pending_carry_base = carry_num - pending_carry_pack_size
      # We may be able to fill zeroes by shifting.
      # But the shift is limited by any further pending carry pack and
      # the constraint that the result must produce a well-formed pack
      # having a power-of-2 entries.
      #
      final_num = carry_num + zero_count
      pack_size2 = low_bit(pending_carry_base)  # NOTE(review): unused
      if pending_carry_base == 0:
        shift = pow2floor(final_num) - pending_carry_pack_size
      else:
        shift = min(low_bit(pending_carry_base), low_bit(final_num)) - pending_carry_pack_size
      if pending_carry_pack_size == 0 or shift == 0:
        # There is either no pending pack or we are not generating enough
        # carry zeroes to combine into the pending pack, so we can only add new
        # packs.
        #
        if zero_count == 1:  return [make_assign(self.cg_temp(carry_num), make_zero(self.fw))]
        else: 
          zero_count_floor = pow2floor(zero_count)
          hi_num = carry_num + zero_count_floor
          a1 = make_assign(self.cg_temp(hi_num - 1, carry_num), make_zero(self.fw))
          remaining_zeroes = zero_count - zero_count_floor
          return [a1] + self.gen_multiple_carry_zero_then_pack(hi_num, remaining_zeroes) 
      #
      shift_result = self.cg_temp(carry_num + shift - 1, pending_carry_base)
      pending = self.cg_temp(carry_num - 1, pending_carry_base)
      #print shift_result, " by shift ", pending, shift
      a1 = make_assign(shift_result, make_call('mvmd<%i>::slli<%i>' % (self.fw, shift), [mk_var(pending)]))
      # Do any necessary merges
      m = self.gen_merges(carry_num + shift - 1,  pending_carry_base)
      return [a1] + m + self.gen_multiple_carry_zero_then_pack(carry_num + shift, zero_count - shift)


    def allocate_ops(self):
      """Assign a packed carry position (alloc_map) to every carry-using
      operation, padding at block entry and exit so each block's carries
      occupy an aligned region; also records each block's base position in
      block_base."""
      carry_count = 0
      for op in range(self.carryInfoSet.operation_count):
        b = self.carryInfoSet.containing_block[op]
        if op != 0: 
          # If we've just left a block, ensure that we are aligned.
          b_last = self.carryInfoSet.containing_block[op-1]
          if not block_contains(b_last, b, self.carryInfoSet.parent_block):
            # find the max-sized block just exited.
            while not block_contains(self.carryInfoSet.parent_block[b_last], b, self.carryInfoSet.parent_block):
              b_last = self.carryInfoSet.parent_block[b_last]
            align_base = self.aligned_size[b_last]
            if align_base > self.field_count: align_base = self.field_count
            carry_count = align(carry_count, align_base)         
        if self.carryInfoSet.block_first_op[b] == op:
          # If we're just entering a block, ensure that we are aligned.
          align_base = self.aligned_size[b]
          if align_base > self.field_count: align_base = self.field_count
          carry_count = align(carry_count, align_base)
          self.block_base[b] = carry_count
        # Only non-advance ops and advance-by-1 ops consume a carry slot.
        if op not in self.carryInfoSet.advance_amount.keys():
          self.alloc_map[op] = carry_count
          carry_count += 1
        elif self.carryInfoSet.advance_amount[op] == 1: 
          self.alloc_map[op] = carry_count
          carry_count += 1
      # When processing the last operation, make sure that the "next" operation
      # appears to start a new pack.
      self.alloc_map[self.carryInfoSet.operation_count] = align(carry_count, self.field_count)
     
    def GenerateCarryDecls(self):
        """C++ declaration for the ubitblock array holding all carries."""
        return "  ubitblock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)
    def GenerateInitializations(self):
        """C++ statements zeroing every carry ubitblock, then setting to 1 any
        block-0 carries listed in init_one_list."""
        v = self.carryGroupVar       
        #const_0 = make_zero(self.fw)
        #inits = [make_assign(make_index_store(v, i), const_0) for i in range(0, self.ubitblock_count)]
        inits = ""
        for i in range(0, self.ubitblock_count):
          inits += "%s[%i]._128 = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
          if op_no in self.carryInfoSet.init_one_list: 
            posn = self.alloc_map[op_no]
            ub = posn/self.field_count
            rp = posn%self.field_count
            inits += "%s[%i]._%i[%i] = 1;\n" % (self.carryGroupVar, ub, self.fw, rp)
            #v_ub = make_index_load(self.carryGroupVar, ub)
            #v_ub_fw = make_att_load(v_ub, '_%i' % self.fw)
            #inits.append(make_assign(make_index_store(v_ub_fw, rp), ast.Num(1)))
        return inits
    def GenerateStreamFunctionDecls(self):
        """Declaration string for all carry pack temps (__c0, ..., __c1_0, ...)
        needed to pack one full register's worth of carries."""
        f = self.field_count
        s = 1
        decls = []
        while f > 0:
          decls += [self.cg_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        return BitBlock_decls_from_vars(decls)

    def GenerateCarryInAccess(self, operation_no):
        """ast expression loading the carry-in for operation_no:
        convert(carryG[ub]._fw[rp])."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        v_ub = make_index_load(self.carryGroupVar, ub)
        v_ub_fw = make_att_load(v_ub, '_%i' % self.fw)
        return make_call("convert", [make_index_load(v_ub_fw, rp)])
    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        """Statements assigning carry_out_expr to operation_no's carry temp,
        packing it (and any skipped positions, filled with zeroes) and,
        when the pack completes a register, shifting to carry-in form and
        storing into the carry group."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        # Only generate an actual store for the last carryout
        assigs = [make_assign(self.temp_prefix + repr(rp), carry_out_expr)] 
        assigs += self.gen_merges(rp, rp)
        next_posn = self.alloc_map[operation_no + 1] - self.carry_offset
        skip = next_posn - posn - 1
        if skip > 0: 
          assigs += self.gen_multiple_carry_zero_then_pack(rp+1, skip)
        #print (posn, skip)
        if next_posn % self.field_count == 0:
          # Pack complete: shift carry-out bits down to carry-in form and store.
          v_ub = make_index_load(self.carryGroupVar, ub)
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1))])
          assigs.append(make_assign(make_att_store(v_ub, '_128'), storable_carry_in_form))
        return assigs
    def GenerateAdvanceInAccess(self, operation_no): pass  # advance (amount > 1) carries unimplemented; returns None
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #return mkCall(self.carryGroupVar + "." + 'get_pending64', [ast.Num(adv_index)])
    def GenerateAdvanceOutStore(self, operation_no, adv_out_expr): pass  # advance (amount > 1) carries unimplemented; returns None
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #cq_index = adv_index + self.carry_count
        #return [ast.Assign([ast.Subscript(self.CarryGroupAtt('cq'), ast.Index(ast.Num(cq_index)), ast.Store())],
                           #mkCall("bitblock::srli<64>", [adv_out_expr]))]
    def GenerateTest(self, block_no, testExpr):
        """Fold block_no's stored carries into testExpr: an integer bitwise-or
        when the block's pack fits in <= 64 bits, otherwise a simd_or over
        the _128 value(s) of the covering ubitblock(s)."""
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        width = count * self.fw
        v_ub = make_index_load(self.carryGroupVar, ub)
        if width <= 64:
            t = make_index_load(make_att_load(v_ub, '_%i' % width), rp/count)
            return TestHelper_Integer_Or(testExpr, t)
        else:
            t = make_att_load(v_ub, '_128')
            for i in range(1, count/self.field_count): 
              v2 = make_att_load(make_index_load(self.carryGroupVar, ub + i), '_128')
              t = make_call('simd_or', [t, v2])
            return TestHelper_Bitblock_Or(testExpr, t)
    def GenerateCarryIfTest(self, block_no, ifTest):
        """If-statement test: or block_no's carries into ifTest."""
        return self.GenerateTest(block_no, ifTest)

    def GenerateCarryElseFinalization(self, block_no):
        """On the else branch of a partial-pack block, pack zero carries over
        the block's positions so later packing stays well formed."""
        # if the block consists of full carry packs, then
        # no action need be taken: the corresponding carry-in packs
        # must already be zero, or the then branch would have been taken.
        count = self.aligned_size[block_no]
        if count % self.field_count == 0: return []
        # The block has half a carry-pack or less.
        assigs = []
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn / self.field_count
        rp = posn % self.field_count
        next_op = self.carryInfoSet.block_first_op[block_no] + self.carryInfoSet.block_op_count[block_no]
        end_pos = (self.alloc_map[next_op]  - self.carry_offset - 1) % self.field_count
        #print rp, next_op,self.alloc_map[next_op]
        #assigs = [make_assign(self.cg_temp(end_pos, rp), make_zero(self.fw))]
        assigs = self.gen_multiple_carry_zero_then_pack(rp, end_pos - rp + 1)
        if (end_pos + 1) % self.field_count == 0:
          # The zero fill completed a register: store it in carry-in form.
          v_ub = make_index_load(self.carryGroupVar, ub)
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1))])
          assigs.append(make_assign(make_att_store(v_ub, '_128'), storable_carry_in_form))

        return assigs

    def GenerateLocalDeclare(self, block_no):
        """Declarations for a while loop's local carry group (subcarryG) and
        its local carry pack temps."""
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no] 
        if count >= self.field_count:
          ub_count = count / self.field_count
          decls = [make_callStmt('ubitblock_declare', [mk_var('sub' + self.carryGroupVar), ast.Num(ub_count)])]
          count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        f = count
        s = 1
        temps = []
        while f > 0:
          temps += [self.local_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]
   
    def GenerateCarryWhileTest(self, block_no, testExpr):
        """While-statement test: or block_no's carries into testExpr."""
        return self.GenerateTest(block_no, testExpr)

    def EnterLocalWhileBlock(self, operation_offset): 
        """Switch generation to the 'sub'-prefixed local carry group/temps and
        rebase carry positions at the loop's first operation."""
        self.carryGroupVar = "sub" + self.carryGroupVar
        self.temp_prefix = "sub" + self.temp_prefix
        self.carry_offset = self.alloc_map[operation_offset]
        #print "self.carry_offset = %i" % self.carry_offset
    def ExitLocalWhileBlock(self): 
        """Undo EnterLocalWhileBlock: strip the 'sub' prefixes and clear the
        local carry position offset."""
        self.carryGroupVar = self.carryGroupVar[3:]
        self.temp_prefix = self.temp_prefix[3:]
        self.carry_offset = 0
       
    def GenerateCarryWhileFinalization(self, block_no):
        """OR the while-body's local carries back into the enclosing group:
        a single temp-to-temp or for sub-register packs, otherwise one
        simd_or per covering ubitblock."""
        posn = self.block_base[block_no]
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        v = self.carryGroupVar
        lv = "sub" + v
        if count < self.field_count:
          v0 = self.cg_temp(rp + count - 1, rp)
          lv0 = self.local_temp(count - 1, 0)
          return [make_assign(v0, make_call('simd_or', [mk_var(v0), mk_var(lv0)]))]
        n = (count+self.field_count-1)/self.field_count
        assigs = []
        for i in range(n):
          v_ub_i = make_index_load(v, ub + i)
          assigs.append(make_assign(make_att_store(v_ub_i, '_128'), make_call('simd_or', [make_att_load(v_ub_i, '_128'), make_att_load(make_index_load(lv, i), '_128')])))
        return assigs
    def GenerateStreamFunctionFinalization(self):
        """No end-of-function work: carry-out to carry-in shifting now happens
        at each final pack store (see GenerateCarryOutStore), so this always
        returns an empty statement list."""
        if self.carryInfoSet.carry_count == 0: return []
        # Generate statements to shift all carries from carry-out form to carry-in form.
        #v = self.carryGroupVar
        #n = (self.aligned_size[0] + self.field_count - 1)/self.field_count
        #shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
        #return [make_assign(make_att_store(make_index_load(v, i), '_128'), make_call(shift_op, [make_att_load(make_index_load(v, i), '_128')])) for i in range(n)]
        #
        # Now arranging shifts with original stores.
        return []
506
Note: See TracBrowser for help on using the repository browser.