source: proto/Compiler/CCGO_HMCPS.py @ 2786

Last change on this file since 2786 was 2786, checked in by cameron, 7 years ago

Add CCGO GenerateTestAll method for @any_carry support

File size: 21.5 KB
Line 
1#
2# CCGO_HMCPS.py
3#
4# Carry Code Generator Object using Hierarchical Merging Carry Pack Strategy
5#
6# Robert D. Cameron
7# November 26, 2012
8# Licensed under Open Software License 3.0
9#
10import ast
11import CCGO
12
13# Copyright 2012, Robert D. Cameron
14# All rights reserved.
15#
16# Helper functions
17#
def TestHelper_Bitblock_Or(testExpr, bitBlockExpr):
    """Combine a bitblock test expression with an additional BitBlock operand.

    If testExpr is already a bitblock::any(...) call, fold bitBlockExpr into
    its argument with simd_or (mutates testExpr in place); otherwise OR the
    test with a fresh bitblock::any(bitBlockExpr) term.
    """
    if not isinstance(testExpr, ast.Call):
        return ast.BinOp(testExpr, ast.BitOr(), make_call('bitblock::any', [bitBlockExpr]))
    assert isinstance(testExpr.func, ast.Name)
    assert testExpr.func.id == 'bitblock::any'
    testExpr.args[0] = make_call('simd_or', [bitBlockExpr, testExpr.args[0]])
    return testExpr
26
def TestHelper_Integer_Or(testExpr, intExpr):
    """Build the AST expression (testExpr | intExpr) for an integer carry test."""
    combined = ast.BinOp(testExpr, ast.BitOr(), intExpr)
    return combined
29
def mk_var(var, mode=ast.Load()):
  """Coerce a string to an ast.Name in the given context; pass AST nodes through."""
  if not isinstance(var, str):
    return var
  return ast.Name(var, mode)
34 
def make_assign(var, expr):
   """Build the assignment statement var = expr; a string target becomes ast.Name(Store)."""
   target = ast.Name(var, ast.Store()) if isinstance(var, str) else var
   return ast.Assign([target], expr)
39
def make_index(var, num, mode=ast.Load()):
  """Build the subscript expression var[num] with the given load/store context."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Subscript(base, ast.Index(ast.Num(num)), mode)
44
def make_index_store(var, num):
  """Build var[num] in Store context (for use as an assignment target)."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Subscript(base, ast.Index(ast.Num(num)), ast.Store())
49
def make_att(var, att, mode=ast.Load()):
  """Build the attribute expression var.att with the given load/store context."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, mode)
54
def make_att_store(var, att):
  """Build var.att in Store context (for use as an assignment target)."""
  base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
  return ast.Attribute(base, att, ast.Store())
59
60def make_call(fn_name, args):
61  if isinstance(fn_name, str): 
62        fn_name = ast.Name(fn_name, ast.Load())
63  return ast.Call(fn_name, args, [], None, None)
64
65def make_callStmt(fn_name, args):
66  if isinstance(fn_name, str): fn_name = ast.Name(fn_name, ast.Load())
67  return ast.Expr(ast.Call(fn_name, args, [], None, None))
68
def make_mergeh(fw, x, y):
  """Build the call esimd<fw>::mergeh(x, y) for field width fw."""
  fn = "esimd<%i>::mergeh" % fw
  return make_call(fn, [mk_var(x), mk_var(y)])
71
def make_zero(fw):
  """Build the call simd<fw>::constant<0>() — an all-zero block of fw-bit fields."""
  fn = "simd<%i>::constant<0>" % fw
  return make_call(fn, [])
74
75 
76#
77#
78# Carry Pack Assignment Strategy
79#
80# The hierarchical merging carry pack strategy packs carries
81# into groups of 2, 4, 8 and 16.   For example, to pack
82# 4 carries c0, c1, c2, and c3 into the 32-bit fields of
83# a 128-bit register, the following operations are used.
84#
85# c0 = pablo.SomeCarryGeneratingFn(...)
86# c1 = pablo.SomeCarryGeneratingFn(...)
87# c1_0 = esimd::mergeh<32>(c1, c0)
88# c2 = pablo.SomeCarryGeneratingFn(...)
89# c3 = pablo.SomeCarryGeneratingFn(...)
90# c3_2 = esimd::mergeh<32>(c3, c2)
91# c3_0 = esimd::mergeh<64>(c3_2, c1_0)
92#
93#
94# Packing operations are generated sequentially when
95# the appropriate individual carries or subpacks become
96# available.   
97#
98# Generate the packing operations assuming that the
99# carry_num carry has just been generated.
100#
101
def pow2ceil(n):
   """Return the smallest power of 2 that is >= n (1 when n <= 1)."""
   p = 1
   while p < n:
      p <<= 1
   return p
106
def pow2floor(n):
   """Return the largest power of 2 that is <= n (assumes n >= 1).

   Uses floor division (//) so the result is always an int: identical to the
   original / on Python 2 ints, and avoids a float result under Python 3.
   """
   c = 1
   while c <= n: c *= 2 
   return c // 2
111
def low_bit(n):
   """Return the lowest set bit of n (0 when n == 0).

   n & -n isolates the lowest set bit; equivalent to n - (n & (n-1)).
   """
   return n & -n
114   
def align(n, align_base):
  """Round n up to the next multiple of align_base.

  Uses floor division (//) so the result is always an int: identical to the
  original / on Python 2 ints, and avoids a float result under Python 3.
  """
  return ((n + align_base - 1) // align_base) * align_base
117
def determine_aligned_block_sizes(pack_size, cis):
  """Compute the aligned carry-pack size for every block of the carry info set.

  pack_size -- number of carry fields available per 128-bit register.
  cis -- carry info set providing block/operation layout tables
         (block_count, block_first_op, block_op_count, containing_block,
         advance_amount, whileblock).
  Returns a dict mapping block number -> aligned carry count: a power of 2
  for small if-blocks, or a multiple of pack_size for while-blocks and
  blocks whose carries exceed one register.
  """
  aligned_size = {}
  for i in range(cis.block_count): aligned_size[i] = 0
  seen = []  # NOTE(review): never used below -- apparently dead code.
  for i in range(cis.block_count):
    # Work backwards to process all child blocks before the parent
    # so that the parent incorporates the updated child counts.
    b = cis.block_count - i - 1
    b_carries = 0
    op = cis.block_first_op[b]
    while op < cis.block_first_op[b] + cis.block_op_count[b]:
      sb = cis.containing_block[op]
      if sb == b:
        # Operations directly in b: plain carries and advance-by-1 each
        # consume one carry slot; longer advances consume none here.
        if op not in cis.advance_amount.keys(): b_carries += 1
        elif cis.advance_amount[op] == 1: b_carries += 1
        op += 1
      else: 
        # A nested block: align to its pack size (capped at pack_size),
        # then reserve its full aligned carry count and skip its ops.
        align_base = aligned_size[sb]
        if align_base > pack_size: align_base = pack_size
        b_carries = align(b_carries, align_base)
        b_carries += aligned_size[sb]
        op += cis.block_op_count[sb]
    # Force whiles to use full blocks; this possibly can be relaxed.
    if cis.whileblock[b] or b_carries > pack_size:
      aligned_size[b] = align(b_carries, pack_size)
    else:
      aligned_size[b] = pow2ceil(b_carries)
  return aligned_size
146 
# Wrap threshold (in characters) for generated C declaration lines.
MAX_LINE_LENGTH = 80
148
def BitBlock_decls_from_vars(varlist):
  """Render a C declaration string for the given BitBlock variable names,
  starting a fresh "BitBlock" declaration line as MAX_LINE_LENGTH is
  approached.  Returns "" for an empty list.

  Changes from the original: the needless `global MAX_LINE_LENGTH` is
  removed (reading a module global requires no declaration), and the
  emptiness test uses truthiness.  Output is byte-identical.
  """
  decls =  ""
  if varlist:
          decls = "             BitBlock"
          pending = ""
          linelgth = 10
          for v in varlist:
            # Continue the current line while it stays within the limit.
            if linelgth + len(v) + 2 <= MAX_LINE_LENGTH:
              decls += pending + " " + v
              linelgth += len(pending + v) + 1
            else:
              # Start a new declaration line for this variable.
              decls += ";\n             BitBlock " + v
              linelgth = 11 + len(v)
            pending = ","
          decls += ";"
  return decls
166 
def block_contains(b0, b1, parent_block_map):
  """Return True iff block b0 is b1 itself or an ancestor of b1.

  Walks the parent chain of b1 iteratively; block 0 is the outermost block.
  """
  while True:
    if b0 == b1: return True
    if b1 == 0: return False
    b1 = parent_block_map[b1]
171 
class HMCPS_CCGO(CCGO.CCGO):
    """Carry code generator using the Hierarchical Merging Carry Pack Strategy.

    Carries are packed into the fw-bit fields of 128-bit ubitblocks; packs
    are built incrementally with esimd mergeh operations as individual
    carries become available (see the module header comments for a worked
    example with 4 carries in 32-bit fields).
    """
    def __init__(self, fw, carryInfoSet, carryGroupVarName='carryG', temp_prefix='__c'):
        # fw: field width in bits of each packed carry entry.
        self.fw = fw
        # Fields per 128-bit register (Python 2 integer division; assumes fw divides 128).
        self.field_count = 128/fw
        self.carryInfoSet = carryInfoSet
        # Name of the generated carry-group array variable in the emitted code.
        self.carryGroupVar = carryGroupVarName
        # Prefix for generated carry-pack temporaries (__c0, __c3_0, ...).
        self.temp_prefix = temp_prefix
        self.aligned_size = determine_aligned_block_sizes(self.field_count, carryInfoSet)
        # ubitblocks needed for the outermost block's carries, rounded up.
        self.ubitblock_count = (self.aligned_size[0] + self.field_count - 1) / self.field_count
        # alloc_map: operation number -> packed carry slot position.
        self.alloc_map = {}
        self.alloc_map[0] = 0
        # block_base: block number -> carry position of the block's first operation.
        self.block_base = {}
        self.allocate_ops()
        # carry_offset is used within the inner body of while loops to access local carries.
        # The calculated (ub, rp) value is reduced by this amount for the local carry group(s).
        self.carry_offset = 0
#
# Carry Storage/Access
#
# Carries are stored in one or more ubitblocks as byte values.
# For each block, the carry count is rounded up to the nearest power of 2 ceiling P,
# so that the carry test for that block is accessible as a single value of P bytes.
# Packs of 1, 2, 4 or 8 carries are respectively represented
# as one or more _8, _16, _32 or _64 values.  (Members of ubitblock union.)
#
#
# Allocation phase determines the ubitblock_no and count for each block.

#  carry-in access is a byte load  carryG[packno]._8[offset]
#  carryout store is to a local pack var until we get to the final byte of a pack
#
#  if-test: let P be pack_size in {1,2,4,8,...}
#    if P <= 8, use an integer test expression cG[packno]._%i % (P * 8)[block_offset]
#
#  while test similar
#    local while decl: use a copy of carryGroup
#    while finalize  carry combine:   round up and |= into structure
#
    def carry_pack_full(self, ub, mode = ast.Load()):
       """AST for the whole 128-bit member of pack ub: carryG[ub]._128."""
       return make_att(make_index(self.carryGroupVar, ub), '_128', mode)

    def carry_pack_index(self, fw, ub, rp, mode = ast.Load()):
       """AST for field rp of the fw-bit member of pack ub: carryG[ub]._fw[rp]."""
       return make_index(make_att(make_index(self.carryGroupVar, ub), '_%i' % fw), rp, mode)

    def local_pack_full(self, ub, mode = ast.Load()):
       """As carry_pack_full, but for the while-loop local copy: subcarryG[ub]._128."""
       return make_att(make_index("sub" + self.carryGroupVar, ub), '_128', mode)



    def cg_temp(self, hi_carry, lo_carry = None):
      """Name of the temp holding carries hi_carry..lo_carry (e.g. __c3 or __c3_0)."""
      if lo_carry == None or hi_carry == lo_carry: return "%s%i" % (self.temp_prefix, hi_carry)
      else: return "%s%i_%i" % (self.temp_prefix, hi_carry, lo_carry)

    def local_temp(self, hi_carry, lo_carry = None):
      """While-loop local variant of cg_temp ("sub" prefix)."""
      if lo_carry == None or hi_carry == lo_carry: return "sub%s%i" % (self.temp_prefix, hi_carry)
      else: return "sub%s%i_%i" % (self.temp_prefix, hi_carry, lo_carry)

    def gen_merges(self, carry_last, carry_base):
      """Emit mergeh assignments combining the just-completed pack ending at
      carry_last with any equal-sized pending pack immediately below it,
      recursing binary-counter fashion until no further merge applies."""
      size = carry_last - carry_base + 1
      # carry_last & size is nonzero exactly when an adjacent pack of the
      # same size exists below this one and should be merged in.
      if carry_last & size: 
        v1 = mk_var(self.cg_temp(carry_last, carry_base))
        v0 = mk_var(self.cg_temp(carry_last - size, carry_base - size))
        v2 = mk_var(self.cg_temp(carry_last, carry_base - size), ast.Store())
        return [make_assign(v2, make_mergeh(self.fw * size, v1, v0))] + self.gen_merges(carry_last, carry_base - size)
      else: return []

    #
    #  Given that carry_num carries have been generated and packed,
    #  add zero_count additional carry zero values and pack.
    #  Use shifts to introduce multiple zeroes, where possible.
    #
    def gen_multiple_carry_zero_then_pack(self, carry_num, zero_count):
      """Pad the pack structure with zero_count zero carries after carry_num."""
      if zero_count == 0: return []
      pending_carry_pack_size = low_bit(carry_num)
      pending_carry_base = carry_num - pending_carry_pack_size
      # We may be able to fill zeroes by shifting.
      # But the shift is limited by any further pending carry pack and
      # the constraint that the result must produce a well-formed pack
      # having a power-of-2 entries.
      #
      final_num = carry_num + zero_count
      pack_size2 = low_bit(pending_carry_base)  # NOTE(review): unused below.
      if pending_carry_base == 0:
        shift = pow2floor(final_num) - pending_carry_pack_size
      else:
        shift = min(low_bit(pending_carry_base), low_bit(final_num)) - pending_carry_pack_size
      if pending_carry_pack_size == 0 or shift == 0:
        # There is either no pending pack or we are not generating enough
        # carry zeroes to combine into the pending pack, so we can only add new
        # packs.
        #
        if zero_count == 1:  return [make_assign(self.cg_temp(carry_num), make_zero(self.fw))]
        else: 
          zero_count_floor = pow2floor(zero_count)
          hi_num = carry_num + zero_count_floor
          a1 = make_assign(self.cg_temp(hi_num - 1, carry_num), make_zero(self.fw))
          remaining_zeroes = zero_count - zero_count_floor
          return [a1] + self.gen_multiple_carry_zero_then_pack(hi_num, remaining_zeroes) 
      #
      # Shift the pending pack left (srli on the register) to append zeroes.
      shift_result = self.cg_temp(carry_num + shift - 1, pending_carry_base)
      pending = self.cg_temp(carry_num - 1, pending_carry_base)
      #print shift_result, " by shift ", pending, shift
      a1 = make_assign(shift_result, make_call('bitblock::srli<%i>' % (self.fw * shift), [mk_var(pending)]))
      # Do any necessary merges
      m = self.gen_merges(carry_num + shift - 1,  pending_carry_base)
      return [a1] + m + self.gen_multiple_carry_zero_then_pack(carry_num + shift, zero_count - shift)


    def allocate_ops(self):
      """Assign a packed carry slot to every carry-generating operation,
      filling alloc_map (op -> position) and block_base (block -> position),
      with alignment padding at block entry and exit boundaries."""
      carry_count = 0
      for op in range(self.carryInfoSet.operation_count):
        b = self.carryInfoSet.containing_block[op]
        if op != 0: 
          # If we've just left a block, ensure that we are aligned.
          b_last = self.carryInfoSet.containing_block[op-1]
          if not block_contains(b_last, b, self.carryInfoSet.parent_block):
            # find the max-sized block just exited.
            while not block_contains(self.carryInfoSet.parent_block[b_last], b, self.carryInfoSet.parent_block):
              b_last = self.carryInfoSet.parent_block[b_last]
            align_base = self.aligned_size[b_last]
            if align_base > self.field_count: align_base = self.field_count
            carry_count = align(carry_count, align_base)
        if self.carryInfoSet.block_first_op[b] == op:
          # If we're just entering a block, ensure that we are aligned.
          align_base = self.aligned_size[b]
          if align_base > self.field_count: align_base = self.field_count
          carry_count = align(carry_count, align_base)
        # Plain carries and advance-by-1 ops each take one slot; longer
        # advances are not allocated here.
        if op not in self.carryInfoSet.advance_amount.keys():
          self.alloc_map[op] = carry_count
          carry_count += 1
        elif self.carryInfoSet.advance_amount[op] == 1: 
          self.alloc_map[op] = carry_count
          carry_count += 1
      # When processing the last operation, make sure that the "next" operation
      # appears to start a new pack.
      self.alloc_map[self.carryInfoSet.operation_count] = align(carry_count, self.field_count)
      for b in range(self.carryInfoSet.block_count): 
         self.block_base[b] = self.alloc_map[self.carryInfoSet.block_first_op[b]]

    def GenerateCarryDecls(self):
        """C declaration for the carry-group ubitblock array."""
        return "  ubitblock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)
    def GenerateInitializations(self):
        """C code zeroing every pack, then setting any init-to-one carries."""
        v = self.carryGroupVar       
        inits = ""
        for i in range(0, self.ubitblock_count):
          inits += "%s[%i]._128 = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
          if op_no in self.carryInfoSet.init_one_list: 
            posn = self.alloc_map[op_no]
            ub = posn/self.field_count   # which ubitblock
            rp = posn%self.field_count   # field position within it
            inits += "%s[%i]._%i[%i] = 1;\n" % (self.carryGroupVar, ub, self.fw, rp)
        return inits
    def GenerateStreamFunctionDecls(self):
        """Declarations for all carry-pack temporaries (__c0, __c1_0, __c3_0, ...)."""
        f = self.field_count
        s = 1
        decls = []
        while f > 0:
          decls += [self.cg_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        return BitBlock_decls_from_vars(decls)

    def GenerateCarryInAccess(self, operation_no):
        """AST reading the carry-in value for operation_no from its pack field."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        return make_call("convert", [self.carry_pack_index(self.fw, ub, rp)])
    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        """Statements storing a carry-out: assign the temp, run any pack
        merges, pad skipped slots with zeroes, and flush the completed
        128-bit pack to the carry group when the pack boundary is reached."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        # Only generate an actual store for the last carryout
        assigs = [make_assign(self.temp_prefix + repr(rp), carry_out_expr)] 
        assigs += self.gen_merges(rp, rp)
        next_posn = self.alloc_map[operation_no + 1] - self.carry_offset
        skip = next_posn - posn - 1
        if skip > 0: 
          assigs += self.gen_multiple_carry_zero_then_pack(rp+1, skip)
        #print (posn, skip)
        if next_posn % self.field_count == 0:
          # Pack complete: shift carries into bit 0 of each field and store.
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1, 0))])
          assigs.append(make_assign(self.carry_pack_full(ub, ast.Store()), storable_carry_in_form))
        return assigs
    def GenerateAdvanceInAccess(self, operation_no): pass  # long advances: not implemented (returns None)
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #return mkCall(self.carryGroupVar + "." + 'get_pending64', [ast.Num(adv_index)])
    def GenerateAdvanceOutStore(self, operation_no, adv_out_expr): pass  # long advances: not implemented (returns None)
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #cq_index = adv_index + self.carry_count
        #return [ast.Assign([ast.Subscript(self.CarryGroupAtt('cq'), ast.Index(ast.Num(cq_index)), ast.Store())],
                           #mkCall("bitblock::srli<64>", [adv_out_expr]))]
    def GenerateTestAll(self, instance_name):
        """AST testing whether any carry in the whole carry group is set
        (bitblock::any over the OR of all packs of instance_name)."""
        if self.ubitblock_count == 0: return ast.Num(0)
        else:
            t = make_att(make_index(make_att(instance_name, self.carryGroupVar), 0), '_128')
            for i in range(1, self.ubitblock_count): 
              t2 = make_att(make_index(make_att(instance_name, self.carryGroupVar), i), '_128')
              t = make_call('simd_or', [t, t2])
            return make_call('bitblock::any', [t])
    def GenerateTest(self, block_no, testExpr):
        """AST combining testExpr with the carry test for block_no: an integer
        subfield read when the block's carries fit in one register, otherwise
        an OR over its full 128-bit packs."""
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        width = count * self.fw
        if count < self.field_count:
            # Whole block's carries fit in one width-bit subfield.
            t = self.carry_pack_index(width, ub, rp/count)
            return TestHelper_Integer_Or(testExpr, t)
        else:
            t = self.carry_pack_full(ub)
            for i in range(1, count/self.field_count): 
              v2 = self.carry_pack_full(ub + i)
              t = make_call('simd_or', [t, v2])
            return TestHelper_Bitblock_Or(testExpr, t)
    def GenerateCarryIfTest(self, block_no, ifTest):
        """Carry test for an if block (same as GenerateTest)."""
        return self.GenerateTest(block_no, ifTest)

    def GenerateCarryElseFinalization(self, block_no):
        # if the block consists of full carry packs, then
        # no action need be taken: the corresponding carry-in packs
        # must already be zero, or the then branch would have been taken.
        count = self.aligned_size[block_no]
        if count % self.field_count == 0: return []
        # The block has half a carry-pack or less.
        assigs = []
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn / self.field_count
        rp = posn % self.field_count
        next_op = self.carryInfoSet.block_first_op[block_no] + self.carryInfoSet.block_op_count[block_no]
        end_pos = (self.alloc_map[next_op]  - self.carry_offset - 1) % self.field_count
        #print rp, next_op,self.alloc_map[next_op]
        #assigs = [make_assign(self.cg_temp(end_pos, rp), make_zero(self.fw))]
        assigs = self.gen_multiple_carry_zero_then_pack(rp, end_pos - rp + 1)
        if (end_pos + 1) % self.field_count == 0:
          # Pack boundary reached: flush the completed pack (as in GenerateCarryOutStore).
          shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
          storable_carry_in_form = make_call(shift_op, [mk_var(self.cg_temp(self.field_count - 1, 0))])
          assigs.append(make_assign(self.carry_pack_full(ub, ast.Store()), storable_carry_in_form))
        return assigs

    def GenerateLocalDeclare(self, block_no):
        """Declarations for a while block's local carry group and pack temps."""
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no] 
        if count >= self.field_count:
          ub_count = count / self.field_count
          decls = [make_callStmt('ubitblock_declare', [mk_var('sub' + self.carryGroupVar), ast.Num(ub_count)])]
          count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        f = count
        s = 1
        temps = []
        while f > 0:
          temps += [self.local_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]

    def GenerateCarryWhileTest(self, block_no, testExpr):
        """Carry test for a while block (same as GenerateTest)."""
        return self.GenerateTest(block_no, testExpr)

    def EnterLocalWhileBlock(self, operation_offset): 
        """Switch generation to the while-local carry group ("sub" names)."""
        self.carryGroupVar = "sub" + self.carryGroupVar
        self.temp_prefix = "sub" + self.temp_prefix
        self.carry_offset = self.alloc_map[operation_offset]
        #print "self.carry_offset = %i" % self.carry_offset
    def ExitLocalWhileBlock(self): 
        """Restore the enclosing (non-local) carry group names and offset."""
        self.carryGroupVar = self.carryGroupVar[3:]
        self.temp_prefix = self.temp_prefix[3:]
        self.carry_offset = 0

    def GenerateCarryWhileFinalization(self, block_no):
        """Statements ORing the while-local carries back into the enclosing
        carry structures after the loop completes."""
        posn = self.block_base[block_no]
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no] 
        if count < self.field_count:
          # Sub-pack case: fold the local temp into the enclosing temp.
          v0 = self.cg_temp(rp + count - 1, rp)
          lv0 = self.local_temp(count - 1, 0)
          return [make_assign(v0, make_call('simd_or', [mk_var(v0), mk_var(lv0)]))]
        n = (count+self.field_count-1)/self.field_count
        assigs = []
        for i in range(n):
          assigs.append(make_assign(self.carry_pack_full(ub + i, ast.Store()), make_call('simd_or', [self.carry_pack_full(ub + i), self.local_pack_full(i)])))
        return assigs
    def GenerateStreamFunctionFinalization(self):
        """No end-of-function carry work is required by this strategy."""
        return []
463
464#
465#  A version of HMCPS_CCGO eliminating ubitblocks
466#
class HMCPS_CCGO2(HMCPS_CCGO):
    """Variant of HMCPS_CCGO that stores carry packs in plain BitBlocks
    (no ubitblock union); individual carry fields are accessed with
    mvmd shift/extract operations instead of union member indexing."""

    def carry_pack_full(self, ub, mode = ast.Load()):
       """AST for pack ub as a whole BitBlock: carryG[ub]."""
       return make_index(self.carryGroupVar, ub, mode)

    def carry_pack_index(self, fw, ub, rp, mode = ast.Load()):
       """AST extracting field rp of pack ub: mvmd<fw>::extract<rp>(carryG[ub])."""
       return make_call("mvmd<%i>::extract<%i>" % (fw, rp), [self.carry_pack_full(ub)])

    def local_pack_full(self, ub, mode = ast.Load()):
       """While-loop local pack: subcarryG[ub]."""
       return make_index("sub" + self.carryGroupVar, ub, mode)

    def GenerateCarryDecls(self):
        """C declarations for the constant-1 helper and the carry array."""
        # NOTE(review): the "\n;" after simd_const_1 looks transposed
        # (probably intended ";\n"); verify the emitted C still compiles.
        return "  BitBlock simd_const_1\n;    BitBlock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)

    def GenerateInitializations(self):
        """C code building simd_const_1, zeroing all packs, and setting
        any init-to-one carries via shifted ORs."""
        v = self.carryGroupVar       
        inits = "simd_const_1 = mvmd<16>::srli<7>(simd<16>::constant<1>());\n"
        for i in range(0, self.ubitblock_count):
          inits += "%s[%i] = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
          if op_no in self.carryInfoSet.init_one_list: 
            posn = self.alloc_map[op_no]
            ub = posn/self.field_count
            rp = posn%self.field_count
            v = "%s[%i]" % (self.carryGroupVar, ub)
            inits += "%s = simd_or(%s, mvmd<%i>::slli<%i>(simd_const_1)) ;\n" % (v, v, self.fw, rp)
        return inits

    def GenerateCarryInAccess(self, operation_no):
        """AST reading the carry-in of operation_no: shift field rp down to
        position 0 (when needed) and mask with simd_const_1."""
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        #return make_call("convert", [self.carry_pack_index(self.fw, ub, rp)])
        if rp == 0: e = self.carry_pack_full(ub)
        else: e = make_call("mvmd<%i>::srli<%i>" %(self.fw, rp), [self.carry_pack_full(ub)])
        if rp == self.field_count - 1:
          # Highest field: the shift already isolates it; no mask needed.
          return e
        else: return make_call('simd_and', [e, mk_var("simd_const_1")])

    def GenerateLocalDeclare(self, block_no):
        """Declarations and zero-initialization for a while block's local
        carry packs and pack temps."""
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no] 
        if count >= self.field_count:
          ub_count = count / self.field_count
          # NOTE(review): passes ub_count (a count, not an index) to
          # local_pack_full -- confirm this declares the intended variable.
          decls = [make_callStmt('BitBlock_declare', [self.local_pack_full(ub_count)])]
          decls += [make_assign(self.local_pack_full(i, ast.Store()), make_zero(self.fw)) for i in range(ub_count)]
          count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        f = count
        s = 1
        temps = []
        while f > 0:
          temps += [self.local_temp(s*(i+1)-1, s*i) for i in range(f)]
          f = f/2
          s = s * 2
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]
526   
527
Note: See TracBrowser for help on using the repository browser.