source: proto/Compiler/CCGO_HMCPS.py @ 2787

Last change on this file since 2787 was 2787, checked in by cameron, 6 years ago

Update HMCGS_CCGO2 for @any_carry

File size: 21.4 KB
Line 
1#
2# CCGO_HMCPS.py
3#
4# Carry Code Generator Object using Hierarchical Merging Carry Pack Strategy
5#
6# Robert D. Cameron
7# November 26, 2012
8# Licensed under Open Software License 3.0
9#
10import ast
11import CCGO
12
13# Copyright 2012, Robert D. Cameron
14# All rights reserved.
15#
16# Helper functions
17#
def TestHelper_Bitblock_Or(testExpr, bitBlockExpr):
    """Fold bitBlockExpr into testExpr, yielding a combined bitblock 'any' test.

    If testExpr is already a bitblock::any(...) call, its argument is widened
    in place using the identity any(a) | any(b) == any(simd_or(a, b)).
    Otherwise the new block's any-test is OR'ed onto testExpr.
    """
    if not isinstance(testExpr, ast.Call):
      any_test = make_call('bitblock::any', [bitBlockExpr])
      return ast.BinOp(testExpr, ast.BitOr(), any_test)
    # Existing call must be bitblock::any; merge into its argument (mutates testExpr).
    assert isinstance(testExpr.func, ast.Name)
    assert testExpr.func.id == 'bitblock::any'
    testExpr.args[0] = make_call('simd_or', [bitBlockExpr, testExpr.args[0]])
    return testExpr
26
def TestHelper_Integer_Or(testExpr, intExpr):
    """Return an AST expression computing testExpr | intExpr."""
    combined = ast.BinOp(testExpr, ast.BitOr(), intExpr)
    return combined
29
def mk_var(var, mode=ast.Load()):
  """Coerce a string to an ast.Name in context `mode`; pass AST nodes through."""
  if not isinstance(var, str):
    return var
  return ast.Name(var, mode)
34 
def make_assign(var, expr):
   """Build the assignment statement `var = expr`; a string var becomes a
   Store-context Name target."""
   target = var if not isinstance(var, str) else ast.Name(var, ast.Store())
   return ast.Assign([target], expr)
39
def make_index(var, num, mode=ast.Load()):
  # Build the subscript expression var[num]; a string var is coerced to a
  # Load-context Name.  `mode` is the context of the subscript itself.
  # NOTE: uses the legacy ast.Index/ast.Num wrappers (Python 2 era AST API);
  # their node shape differs in modern Python -- do not port blindly.
  if isinstance(var, str): 
        var = ast.Name(var, ast.Load())
  return ast.Subscript(var, ast.Index(ast.Num(num)), mode)
44
def make_index_store(var, num):
  # Build var[num] in Store (assignment-target) context; the base var itself
  # is still read, hence its Load context.  Legacy ast.Index/ast.Num API.
  if isinstance(var, str): 
        var = ast.Name(var, ast.Load())
  return ast.Subscript(var, ast.Index(ast.Num(num)), ast.Store())
49
def make_att(var, att, mode=ast.Load()):
  """Build the attribute access `var.att`; a string var becomes a
  Load-context Name.  `mode` is the context of the attribute node."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, mode)
54
def make_att_store(var, att):
  """Build `var.att` in Store (assignment-target) context."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, ast.Store())
59
def make_call(fn_name, args):
  # Build the call expression fn_name(*args); a string fn_name is coerced to
  # a Load-context Name.  The five positional fields
  # (func, args, keywords, starargs, kwargs) are the Python 2 ast.Call
  # signature; Python 3's Call has only three fields, so keep this exact.
  if isinstance(fn_name, str): 
        fn_name = ast.Name(fn_name, ast.Load())
  return ast.Call(fn_name, args, [], None, None)
64
def make_callStmt(fn_name, args):
  # Like make_call, but wrapped in ast.Expr so the call stands as a statement.
  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
  return ast.Expr(ast.Call(fn_name, args, [], None, None))
68
def make_mergeh(fw, x, y):
  """Build a call to esimd<fw>::mergeh(x, y); x and y may be names or nodes."""
  fn = "esimd<%i>::mergeh" % fw
  return make_call(fn, [mk_var(x), mk_var(y)])
71
def make_zero(fw):
  """Build a call to simd<fw>::constant<0>(), an all-zero pack of width fw."""
  zero_fn = "simd<%i>::constant<0>" % fw
  return make_call(zero_fn, [])
74
75 
76#
77#
78# Carry Pack Assignment Strategy
79#
80# The hierarchical merging carry pack strategy packs carries
81# into groups of 2, 4, 8 and 16.   For example, to pack
82# 4 carries c0, c1, c2, and c3 into the 32-bit fields of
83# a 128-bit register, the following operations are used.
84#
85# c0 = pablo.SomeCarryGeneratingFn(...)
86# c1 = pablo.SomeCarryGeneratingFn(...)
87# c1_0 = esimd::mergeh<32>(c1, c0)
88# c2 = pablo.SomeCarryGeneratingFn(...)
89# c3 = pablo.SomeCarryGeneratingFn(...)
90# c3_2 = esimd::mergeh<32>(c3, c2)
91# c3_0 = esimd::mergeh<64>(c3_2, c1_0)
92#
93#
94# Packing operations are generated sequentially when
95# the appropriate individual carries or subpacks become
96# available.   
97#
98# Generate the packing operations assuming that the
99# carry_num carry has just been generated.
100#
101
def pow2ceil(n):
   """Smallest power of 2 that is >= n (returns 1 for n <= 1)."""
   p = 1
   while p < n:
      p += p
   return p
106
def pow2floor(n):
   """Largest power of 2 that is <= n (assumes n >= 1)."""
   c = 1
   while c <= n: c *= 2
   # Floor division keeps the result an int under Python 3 as well as
   # Python 2 (plain / would yield a float in Py3, corrupting the '%i'
   # template formatting that consumes these values).
   return c//2
111
def low_bit(n):
   """Isolate the lowest set bit of n (0 when n == 0)."""
   # n & -n is the classic two's-complement form of n - (n & (n-1)).
   return n & -n
114   
def align(n, align_base):
  """Round n up to the next multiple of align_base (align_base >= 1)."""
  # Floor division keeps the result an int under Python 3 as well as
  # Python 2; with plain / the result would become a float in Py3 and
  # leak into array indexing and '%i' formatting downstream.
  return ((n + align_base - 1) // align_base) * align_base
117
def determine_aligned_block_sizes(pack_size, cis):
  """Compute the aligned carry-pack size for every block of carry info set cis.

  Each block's carry count (its own unit-advance/carry ops plus the aligned
  sizes of its nested blocks) is padded: while-blocks and oversize blocks are
  rounded up to a multiple of pack_size, others to a power-of-2 ceiling.
  Returns a dict mapping block number -> aligned carry count.
  """
  aligned_size = {}
  for i in range(cis.block_count): aligned_size[i] = 0
  for i in range(cis.block_count):
    # Work backwards to process all child blocks before the parent
    # so that the parent incorporates the updated child counts.
    b = cis.block_count - i - 1
    b_carries = 0
    op = cis.block_first_op[b]
    while op < cis.block_first_op[b] + cis.block_op_count[b]:
      sb = cis.containing_block[op]
      if sb == b:
        # Ops without an advance amount, or with unit advance, consume a carry.
        # (Direct dict membership avoids building the .keys() list on Python 2.)
        if op not in cis.advance_amount: b_carries += 1
        elif cis.advance_amount[op] == 1: b_carries += 1
        op += 1
      else: 
        # Entering nested block sb: align to its size (capped at pack_size),
        # absorb its aligned carry count, and skip over its ops.
        align_base = aligned_size[sb]
        if align_base > pack_size: align_base = pack_size
        b_carries = align(b_carries, align_base)
        b_carries += aligned_size[sb]
        op += cis.block_op_count[sb]
    # Force whiles to use full blocks; this possibly can be relaxed.
    if cis.whileblock[b] or b_carries > pack_size:
      aligned_size[b] = align(b_carries, pack_size)
    else:
      aligned_size[b] = pow2ceil(b_carries)
  return aligned_size
146 
# Generated declaration lines wrap onto a new line at this width.
MAX_LINE_LENGTH = 80
148
def BitBlock_decls_from_vars(varlist):
  """Render C declarations 'BitBlock v1, v2, ...;' for the names in varlist,
  wrapping onto a fresh BitBlock line when MAX_LINE_LENGTH would be exceeded.
  Returns "" for an empty varlist."""
  if not varlist:
    return ""
  decls = "             BitBlock"
  sep = ""
  linelgth = 10
  for v in varlist:
    if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
      # Continue the current declaration line.
      decls += sep + " " + v
      linelgth += len(sep + v) + 1
    else:
      # Start a fresh declaration line for this name.
      decls += ";\n             BitBlock " + v
      linelgth = 11 + len(v)
    sep = ","
  return decls + ";"
166 
def block_contains(b0, b1, parent_block_map):
  """True iff block b0 equals b1 or is an ancestor of b1.

  Walks b1 up the parent chain; block 0 is the outermost block, so reaching
  it without matching b0 means containment fails."""
  while True:
    if b1 == b0:
      return True
    if b1 == 0:
      return False
    b1 = parent_block_map[b1]
171 
class HMCPS_CCGO(CCGO.CCGO):
    """Carry code generator using the hierarchical merging carry pack strategy.

    Carries are packed into power-of-2 groups of fw-bit fields inside
    ubitblock unions (see the storage/access notes below), so a whole block's
    carry test can be read as one _8/_16/_32/_64 member value.
    """
    def __init__(self, fw, carryInfoSet, carryGroupVarName='carryG', temp_prefix='__c'):
        # fw: carry field width in bits; field_count: fields per 128-bit register.
        self.fw = fw
        self.field_count = 128/fw
        self.carryInfoSet = carryInfoSet
        self.carryGroupVar = carryGroupVarName
        self.temp_prefix = temp_prefix
        # Power-of-2 padded carry counts per block (block no -> count).
        self.aligned_size = determine_aligned_block_sizes(self.field_count, carryInfoSet)
        # Number of ubitblocks needed for the outermost block's carries.
        self.ubitblock_count = (self.aligned_size[0] + self.field_count - 1) / self.field_count
        self.alloc_map = {}
        self.alloc_map[0] = 0
        self.block_base = {}
        self.allocate_ops()
        # carry_offset is used within the inner body of while loops to access local carries.
        # The calculated (ub, rp) value is reduced by this amount for the local carry group(s).
        self.carry_offset = 0
#
# Carry Storage/Access
#
# Carries are stored in one or more ubitblocks as byte values.
# For each block, the carry count is rounded up to the nearest power of 2 ceiling P,
# so that the carry test for that block is accessible as a single value of P bytes.
# Packs of 1, 2, 4 or 8 carries are respectively represented
# as one or more _8, _16, _32 or _64 values.  (Members of ubitblock union.)
#
#
# Allocation phase determines the ubitblock_no and count for each block.

#  carry-in access is a byte load  carryG[packno]._8[offset]
#  carryout store is to a local pack var until we get to the final byte of a pack
#
#  if-test: let P be pack_size in {1,2,4,8,...}
#    if P <= 8, use an integer test expression cG[packno]._%i % (P * 8)[block_offset]
#
#  while test similar
#    local while decl: use a copy of carryGroup
#    while finalize  carry combine:   round up and |= into structure
#
    def carry_pack_full(self, ub, v = None, mode = ast.Load()):
       """AST for the full 128-bit value of carry pack ub: v[ub]._128."""
       if v == None: v = self.carryGroupVar
       return make_att(make_index(v, ub), '_128', mode)

    def carry_pack_index(self, fw, ub, rp, mode = ast.Load()):
       """AST for field rp of pack ub viewed at width fw: carryG[ub]._fw[rp]."""
       return make_index(make_att(make_index(self.carryGroupVar, ub), '_%i' % fw), rp, mode)

    def local_pack_full(self, ub, mode = ast.Load()):
       """Full-pack access into the local ('sub'-prefixed) carry group used
       inside while-loop bodies."""
       return self.carry_pack_full(ub, "sub" + self.carryGroupVar, mode)



    def cg_temp(self, hi_carry, lo_carry = None):
      """Name of the temp holding packed carries hi_carry..lo_carry,
      e.g. '__c3' or '__c3_0'."""
      if lo_carry == None or hi_carry == lo_carry: return "%s%i" % (self.temp_prefix, hi_carry)
      else: return "%s%i_%i" % (self.temp_prefix, hi_carry, lo_carry)

    def local_temp(self, hi_carry, lo_carry = None):
      """Local (while-body) variant of cg_temp, with a 'sub' prefix."""
      if lo_carry == None or hi_carry == lo_carry: return "sub%s%i" % (self.temp_prefix, hi_carry)
      else: return "sub%s%i_%i" % (self.temp_prefix, hi_carry, lo_carry)

    def gen_merges(self, carry_last, carry_base):
      """Emit mergeh assignments combining the just-completed pack
      carry_last..carry_base with its earlier sibling pack of equal size,
      doubling pack width repeatedly while the bit pattern of carry_last
      indicates a sibling is pending.  Returns a list of assignments."""
      size = carry_last - carry_base + 1
      if carry_last & size: 
        v1 = mk_var(self.cg_temp(carry_last, carry_base))
        v0 = mk_var(self.cg_temp(carry_last - size, carry_base - size))
        v2 = mk_var(self.cg_temp(carry_last, carry_base - size), ast.Store())
        return [make_assign(v2, make_mergeh(self.fw * size, v1, v0))] + self.gen_merges(carry_last, carry_base - size)
      else: return []

    #
    #  Given that carry_num carries have been generated and packed,
    #  add zero_count additional carry zero values and pack.
    #  Use shifts to introduce multiple zeroes, where possible.
    #
    def gen_multiple_carry_zero_then_pack(self, carry_num, zero_count):
      """Return assignments that append zero_count zero carries after the
      first carry_num carries, preferring a single shift of the pending pack
      over emitting individual zero packs."""
      if zero_count == 0: return []
      pending_carry_pack_size = low_bit(carry_num)
      pending_carry_base = carry_num - pending_carry_pack_size
      # We may be able to fill zeroes by shifting.
      # But the shift is limited by any further pending carry pack and
      # the constraint that the result must produce a well-formed pack
      # having a power-of-2 entries.
      #
      final_num = carry_num + zero_count
      pack_size2 = low_bit(pending_carry_base)
      if pending_carry_base == 0:
        shift = pow2floor(final_num) - pending_carry_pack_size
      else:
        shift = min(low_bit(pending_carry_base), low_bit(final_num)) - pending_carry_pack_size
      if pending_carry_pack_size == 0 or shift == 0:
        # There is either no pending pack or we are not generating enough
        # carry zeroes to combine into the pending pack, so we can only add new
        # packs.
        #
        if zero_count == 1:  return [make_assign(self.cg_temp(carry_num), make_zero(self.fw))]
        else: 
          zero_count_floor = pow2floor(zero_count)
          hi_num = carry_num + zero_count_floor
          a1 = make_assign(self.cg_temp(hi_num - 1, carry_num), make_zero(self.fw))
          remaining_zeroes = zero_count - zero_count_floor
          return [a1] + self.gen_multiple_carry_zero_then_pack(hi_num, remaining_zeroes) 
      #
      shift_result = self.cg_temp(carry_num + shift - 1, pending_carry_base)
      pending = self.cg_temp(carry_num - 1, pending_carry_base)
      #print shift_result, " by shift ", pending, shift
      a1 = make_assign(shift_result, make_call('bitblock::srli<%i>' % (self.fw * shift), [mk_var(pending)]))
      # Do any necessary merges
      m = self.gen_merges(carry_num + shift - 1,  pending_carry_base)
      return [a1] + m + self.gen_multiple_carry_zero_then_pack(carry_num + shift, zero_count - shift)


    def allocate_ops(self):
      """Assign every carry-consuming operation a position in the packed
      carry space (alloc_map), aligning at block entry/exit, and record each
      block's base position in block_base."""
      carry_count = 0
      for op in range(self.carryInfoSet.operation_count):
        b = self.carryInfoSet.containing_block[op]
        if op != 0: 
          # If we've just left a block, ensure that we are aligned.
          b_last = self.carryInfoSet.containing_block[op-1]
          if not block_contains(b_last, b, self.carryInfoSet.parent_block):
            # find the max-sized block just exited.
            while not block_contains(self.carryInfoSet.parent_block[b_last], b, self.carryInfoSet.parent_block):
              b_last = self.carryInfoSet.parent_block[b_last]
            align_base = self.aligned_size[b_last]
            if align_base > self.field_count: align_base = self.field_count
            carry_count = align(carry_count, align_base)         
        if self.carryInfoSet.block_first_op[b] == op:
          # If we're just entering a block, ensure that we are aligned.
          align_base = self.aligned_size[b]
          if align_base > self.field_count: align_base = self.field_count
          carry_count = align(carry_count, align_base)
        if op not in self.carryInfoSet.advance_amount.keys():
          self.alloc_map[op] = carry_count
          carry_count += 1
        elif self.carryInfoSet.advance_amount[op] == 1: 
          self.alloc_map[op] = carry_count
          carry_count += 1
      # When processing the last operation, make sure that the "next" operation
      # appears to start a new pack.
      self.alloc_map[self.carryInfoSet.operation_count] = align(carry_count, self.field_count)
      for b in range(self.carryInfoSet.block_count): 
         self.block_base[b] = self.alloc_map[self.carryInfoSet.block_first_op[b]]

    def GenerateCarryDecls(self):
        """C declaration of the ubitblock array holding all carry packs."""
        return "  ubitblock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)
    def GenerateInitializations(self):
        """C code zeroing every carry pack, then setting initial-1 carries."""
        v = self.carryGroupVar       
        inits = ""
        for i in range(0, self.ubitblock_count):
          inits += "%s[%i]._128 = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
          if op_no in self.carryInfoSet.init_one_list: 
            posn = self.alloc_map[op_no]
            ub = posn/self.field_count
            rp = posn%self.field_count
            inits += "%s[%i]._%i[%i] = 1;\n" % (self.carryGroupVar, ub, self.fw, rp)
        return inits
    def GenerateStreamFunctionDecls(self):
        """Declarations for the carry-pack temporaries (__c names) at every
        pack granularity from single fields up to a full register."""
        f = self.field_count
        s = 1
        decls = []
        while f > 0:
          decls += [self.cg_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        return BitBlock_decls_from_vars(decls)

    def GenerateCarryInAccess(self, operation_no):
        """AST expression loading the carry-in field for operation_no
        (converted to BitBlock form via 'convert')."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        return make_call("convert", [self.carry_pack_index(self.fw, ub, rp)])
    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        """Statements capturing carry_out_expr into a pack temp, merging
        completed subpacks, zero-filling skipped positions, and storing the
        full register once its final field is produced."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        # Only generate an actual store for the last carryout
        assigs = [make_assign(self.temp_prefix + repr(rp), carry_out_expr)] 
        assigs += self.gen_merges(rp, rp)
        next_posn = self.alloc_map[operation_no + 1] - self.carry_offset
        skip = next_posn - posn - 1
        if skip > 0: 
          assigs += self.gen_multiple_carry_zero_then_pack(rp+1, skip)
        #print (posn, skip)
        if next_posn % self.field_count == 0:
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1, 0))])
          assigs.append(make_assign(self.carry_pack_full(ub, mode = ast.Store()), storable_carry_in_form))
        return assigs
    # NOTE(review): advance in/out generation is disabled here (both methods
    # return None); the commented code shows a prior pending-64 approach.
    def GenerateAdvanceInAccess(self, operation_no): pass
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #return mkCall(self.carryGroupVar + "." + 'get_pending64', [ast.Num(adv_index)])
    def GenerateAdvanceOutStore(self, operation_no, adv_out_expr): pass
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #cq_index = adv_index + self.carry_count
        #return [ast.Assign([ast.Subscript(self.CarryGroupAtt('cq'), ast.Index(ast.Num(cq_index)), ast.Store())],
                           #mkCall("bitblock::srli<64>", [adv_out_expr]))]
    def GenerateTestAll(self, instance_name):
        """bitblock::any test over the OR of all carry packs of instance_name."""
        if self.ubitblock_count == 0: return ast.Num(0)
        else:
            v = make_att(instance_name, self.carryGroupVar)
            t = self.carry_pack_full(0, v)
            for i in range(1, self.ubitblock_count): 
              t2 = self.carry_pack_full(i, v)
              t = make_call('simd_or', [t, t2])
            return make_call('bitblock::any', [t])
    def GenerateTest(self, block_no, testExpr):
        """Combine block block_no's carry pack(s) into testExpr: an integer
        member read for sub-register packs, or an OR of full registers."""
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        width = count * self.fw
        if count < self.field_count:
            t = self.carry_pack_index(width, ub, rp/count)
            return TestHelper_Integer_Or(testExpr, t)
        else:
            t = self.carry_pack_full(ub)
            for i in range(1, count/self.field_count): 
              v2 = self.carry_pack_full(ub + i)
              t = make_call('simd_or', [t, v2])
            return TestHelper_Bitblock_Or(testExpr, t)
    def GenerateCarryIfTest(self, block_no, ifTest):
        """If-statement carry test: delegates to GenerateTest."""
        return self.GenerateTest(block_no, ifTest)

    def GenerateCarryElseFinalization(self, block_no):
        """Zero out a partial carry pack when the else branch is taken."""
        # if the block consists of full carry packs, then
        # no action need be taken: the corresponding carry-in packs
        # must already be zero, or the then branch would have been taken.
        count = self.aligned_size[block_no]
        if count % self.field_count == 0: return []
        # The block has half a carry-pack or less.
        assigs = []
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn / self.field_count
        rp = posn % self.field_count
        next_op = self.carryInfoSet.block_first_op[block_no] + self.carryInfoSet.block_op_count[block_no]
        end_pos = (self.alloc_map[next_op]  - self.carry_offset - 1) % self.field_count
        #print rp, next_op,self.alloc_map[next_op]
        #assigs = [make_assign(self.cg_temp(end_pos, rp), make_zero(self.fw))]
        assigs = self.gen_multiple_carry_zero_then_pack(rp, end_pos - rp + 1)
        if (end_pos + 1) % self.field_count == 0:
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1, 0))])
          assigs.append(make_assign(self.carry_pack_full(ub, mode = ast.Store()), storable_carry_in_form))
        return assigs

    def GenerateLocalDeclare(self, block_no):
        """Declarations for the local carry group and pack temps used inside
        a while-loop body for block_no."""
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no] 
        if count >= self.field_count:
          ub_count = count / self.field_count
          decls = [make_callStmt('ubitblock_declare', [mk_var('sub' + self.carryGroupVar), ast.Num(ub_count)])]
          count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        f = count
        s = 1
        temps = []
        while f > 0:
          temps += [self.local_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]

    def GenerateCarryWhileTest(self, block_no, testExpr):
        """While-condition carry test: delegates to GenerateTest."""
        return self.GenerateTest(block_no, testExpr)

    def EnterLocalWhileBlock(self, operation_offset): 
        """Switch generation to the local ('sub') carry group for a while body;
        carry positions become relative to operation_offset."""
        self.carryGroupVar = "sub" + self.carryGroupVar
        self.temp_prefix = "sub" + self.temp_prefix
        self.carry_offset = self.alloc_map[operation_offset]
        #print "self.carry_offset = %i" % self.carry_offset
    def ExitLocalWhileBlock(self): 
        """Restore the enclosing carry group names after a while body
        (strips the 'sub' prefix added by EnterLocalWhileBlock)."""
        self.carryGroupVar = self.carryGroupVar[3:]
        self.temp_prefix = self.temp_prefix[3:]
        self.carry_offset = 0

    def GenerateCarryWhileFinalization(self, block_no):
        """OR the accumulated local while-loop carries back into the
        enclosing carry group."""
        posn = self.block_base[block_no]
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        if count < self.field_count:
          v0 = self.cg_temp(rp + count - 1, rp)
          lv0 = self.local_temp(count - 1, 0)
          return [make_assign(v0, make_call('simd_or', [mk_var(v0), mk_var(lv0)]))]
        n = (count+self.field_count-1)/self.field_count
        assigs = []
        for i in range(n):
          assigs.append(make_assign(self.carry_pack_full(ub + i, mode = ast.Store()), make_call('simd_or', [self.carry_pack_full(ub + i), self.local_pack_full(i)])))
        return assigs
    def GenerateStreamFunctionFinalization(self):
        """No end-of-function carry actions are required in this strategy."""
        return []
465
466#
467#  A version of HMCPS_CCGO eliminating ubitblocks
468#
class HMCPS_CCGO2(HMCPS_CCGO):
    """A version of HMCPS_CCGO eliminating ubitblocks.

    Carry packs are stored as plain BitBlocks; field access uses
    mvmd extract/shift operations together with a simd_const_1 mask.
    """

    def carry_pack_full(self, ub, v = None, mode = ast.Load()):
       """Access carry pack ub as a whole BitBlock: v[ub]."""
       if v == None: v = self.carryGroupVar
       return make_index(v, ub, mode)

    def carry_pack_index(self, fw, ub, rp, mode = ast.Load()):
       """Extract field rp of pack ub via mvmd<fw>::extract<rp>(...)."""
       return make_call("mvmd<%i>::extract<%i>" % (fw, rp), [self.carry_pack_full(ub)])

    def GenerateCarryDecls(self):
        """Declare the simd_const_1 helper and the BitBlock carry array."""
        # Bug fix: the ';' terminating simd_const_1 was previously emitted
        # after the '\n' ("...simd_const_1\n;..."), producing a stray
        # ';'-prefixed line in the generated C source.
        return "  BitBlock simd_const_1;\n    BitBlock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)

    def GenerateInitializations(self):
        """Initialize simd_const_1, zero all packs, set initial-1 carries."""
        v = self.carryGroupVar       
        inits = "simd_const_1 = mvmd<16>::srli<7>(simd<16>::constant<1>());\n"
        for i in range(0, self.ubitblock_count):
          inits += "%s[%i] = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
          if op_no in self.carryInfoSet.init_one_list: 
            posn = self.alloc_map[op_no]
            ub = posn/self.field_count
            rp = posn%self.field_count
            v = "%s[%i]" % (self.carryGroupVar, ub)
            inits += "%s = simd_or(%s, mvmd<%i>::slli<%i>(simd_const_1)) ;\n" % (v, v, self.fw, rp)
        return inits

    def GenerateCarryInAccess(self, operation_no):
        """Load the carry-in for operation_no: shift field rp to position 0
        and mask with simd_const_1 (mask unneeded for the top field)."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        #return make_call("convert", [self.carry_pack_index(self.fw, ub, rp)])
        if rp == 0: e = self.carry_pack_full(ub)
        else: e = make_call("mvmd<%i>::srli<%i>" %(self.fw, rp), [self.carry_pack_full(ub)])
        if rp == self.field_count - 1:
          return e
        else: return make_call('simd_and', [e, mk_var("simd_const_1")])

    def GenerateLocalDeclare(self, block_no):
        """Declare and zero the local (while-body) carry packs and temps."""
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no] 
        if count >= self.field_count:
          ub_count = count / self.field_count
          # NOTE(review): the subscript here uses ub_count (the pack count),
          # apparently to declare the whole sub-group array through the
          # BitBlock_declare macro -- verify against the macro definition.
          decls = [make_callStmt('BitBlock_declare', [self.local_pack_full(ub_count)])]
          decls += [make_assign(self.local_pack_full(i, ast.Store()), make_zero(self.fw)) for i in range(ub_count)]
          count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        f = count
        s = 1
        temps = []
        while f > 0:
          temps += [self.local_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]
526   
527
Note: See TracBrowser for help on using the repository browser.