source: proto/Compiler/CCGO_HMCPS.py @ 2701

Last change on this file since 2701 was 2701, checked in by cameron, 6 years ago

Various fixes

File size: 19.2 KB
Line 
1#
2# CCGO_HMCPS.py
3#
4# Carry Code Generator Object using Hierarchical Merging Carry Pack Strategy
5#
6# Robert D. Cameron
7# November 26, 2012
8# Licensed under Open Software License 3.0
9#
10import ast
11import CCGO
12
13# Copyright 2012, Robert D. Cameron
14# All rights reserved.
15#
16# Helper functions
17#
def TestHelper_Bitblock_Or(testExpr, bitBlockExpr):
    """OR bitBlockExpr into the argument of a bitblock::any(...) test call.

    Mutates testExpr in place: its first argument becomes
    simd_or(bitBlockExpr, <old argument>); the modified call is returned.
    """
    # The test expression must be a call of the form bitblock::any(x).
    assert isinstance(testExpr, ast.Call)
    assert isinstance(testExpr.func, ast.Name)
    assert testExpr.func.id == 'bitblock::any'
    original_arg = testExpr.args[0]
    testExpr.args[0] = make_call('simd_or', [bitBlockExpr, original_arg])
    return testExpr
24
def TestHelper_Integer_Or(testExpr, intExpr):
    """Combine two integer-valued test expressions with a bitwise-or node."""
    combined = ast.BinOp(testExpr, ast.BitOr(), intExpr)
    return combined
27
def mk_var(var, mode=None):
    """Coerce a string to an ast.Name node; pass AST nodes through unchanged.

    mode defaults to a fresh ast.Load() context.  (The original default of
    ast.Load() shared one mutable AST node across every call; a None
    sentinel avoids that while remaining backward compatible.)
    """
    if mode is None:
        mode = ast.Load()
    if isinstance(var, str):
        var = ast.Name(var, mode)
    return var
32 
def make_mergeh(fw, x, y):
    """Build a call node for esimd<fw>::mergeh(x, y)."""
    merge_fn = "esimd<%i>::mergeh" % fw
    return make_call(merge_fn, [mk_var(x), mk_var(y)])
36
def make_assign(var, expr):
    """Build an assignment statement node: var = expr.

    A string target is converted to an ast.Name in Store context.
    """
    target = ast.Name(var, ast.Store()) if isinstance(var, str) else var
    return ast.Assign([target], expr)
42
def make_zero(fw):
    """Build a call node for the all-zero constant simd<fw>::constant<0>()."""
    zero_fn = "simd<%i>::constant<0>" % fw
    return make_call(zero_fn, [])
46
def make_index_load(var, num):
    """Build a subscript node that reads var[num]."""
    base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
    return ast.Subscript(base, ast.Index(ast.Num(num)), ast.Load())
51
def make_index_store(var, num):
    """Build a subscript node that writes var[num].

    Note the subscripted variable itself is still in Load context; only
    the subscript as a whole is a Store target.
    """
    base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
    return ast.Subscript(base, ast.Index(ast.Num(num)), ast.Store())
56
def make_att_load(var, att):
    """Build an attribute node that reads var.att."""
    base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
    return ast.Attribute(base, att, ast.Load())
61
def make_att_store(var, att):
    """Build an attribute node that writes var.att.

    The base variable stays in Load context; the attribute as a whole is
    the Store target.
    """
    base = ast.Name(var, ast.Load()) if isinstance(var, str) else var
    return ast.Attribute(base, att, ast.Store())
66
def make_call(fn_name, args):
    """Build a call-expression node fn_name(args...).

    A string fn_name is wrapped in an ast.Name.  The original passed five
    positional arguments to ast.Call (the Python 2 field layout), which
    raises TypeError under Python 3 where Call has only (func, args,
    keywords).  Constructing with keyword fields and filling the Python 2
    starargs/kwargs slots by attribute assignment yields a node valid for
    compile() on both versions.
    """
    if isinstance(fn_name, str):
        fn_name = ast.Name(fn_name, ast.Load())
    call = ast.Call(func=fn_name, args=args, keywords=[])
    # Python 2's ast.Call requires these two extra fields for compile().
    call.starargs = None
    call.kwargs = None
    return call
71
def make_callStmt(fn_name, args):
    """Build a call statement node (an expression statement wrapping a call).

    Same Python 2/3 compatibility fix as make_call: ast.Call is built with
    keyword fields, and the Python-2-only starargs/kwargs slots are set by
    attribute assignment instead of passing five positional arguments
    (which raises TypeError on Python 3).
    """
    if isinstance(fn_name, str):
        fn_name = ast.Name(fn_name, ast.Load())
    call = ast.Call(func=fn_name, args=args, keywords=[])
    call.starargs = None
    call.kwargs = None
    return ast.Expr(call)
75 
76#
77#
78# Carry Pack Assignment Strategy
79#
80# The hierarchical merging carry pack strategy packs carries
81# into groups of 2, 4, 8 and 16.   For example, to pack
82# 4 carries c0, c1, c2, and c3 into the 32-bit fields of
83# a 128-bit register, the following operations are used.
84#
85# c0 = pablo.SomeCarryGeneratingFn(...)
86# c1 = pablo.SomeCarryGeneratingFn(...)
87# c1_0 = esimd::mergeh<32>(c1, c0)
88# c2 = pablo.SomeCarryGeneratingFn(...)
89# c3 = pablo.SomeCarryGeneratingFn(...)
90# c3_2 = esimd::mergeh<32>(c3, c2)
91# c3_0 = esimd::mergeh<64>(c3_2, c1_0)
92#
93#
94# Packing operations are generated sequentially when
95# the appropriate individual carries or subpacks become
96# available.   
97#
98# Generate the packing operations assuming that the
99# carry_num carry has just been generated.
100#
def gen_carry_pack(pack_fw, carry_num, temp_pfx):
    """Emit the merge operations made possible by producing carry carry_num.

    The number of carries now packable equals the count of rightmost
    contiguous 1 bits of carry_num: each time a power-of-2 group of
    carries completes, one more mergeh at double the field width applies.
    Returns a list of assignment statement nodes.
    """
    pack_limit = (carry_num + 1) & ~carry_num
    ops = []
    hi = temp_pfx + repr(carry_num)
    lo = temp_pfx + repr(carry_num ^ 1)
    span = 2
    width = pack_fw
    while span <= pack_limit:
        merged = '%s%i_%i' % (temp_pfx, carry_num, carry_num - span + 1)
        ops.append(make_assign(merged, make_mergeh(width, hi, lo)))
        # The freshly merged subpack joins with the preceding one of equal span.
        hi = merged
        lo = '%s%i_%i' % (temp_pfx, carry_num - span, carry_num - 2 * span + 1)
        span += span
        width += width
    return ops
118
119#
120# Pack in a zero carry value
121#
def gen_carry_zero_then_pack(pack_fw, carry_num, temp_pfx):
    """Assign a zero carry to position carry_num, then emit packing ops.

    The original duplicated gen_carry_pack's merge loop verbatim; the only
    difference is the initial assignment of simd<pack_fw>::constant<0>()
    to the carry temp, so delegate to gen_carry_pack for the rest.
    Returns a list of assignment statement nodes.
    """
    zero_assign = make_assign(temp_pfx + repr(carry_num), make_zero(pack_fw))
    return [zero_assign] + gen_carry_pack(pack_fw, carry_num, temp_pfx)
139
140#
141# Generate multiple zero carries to complete a carry pack.
142#
143#
def gen_multiple_carry_zero_then_pack(pack_fw, carry_num, carry_count, temp_pfx):
    # Emit assignments that inject carry_count zero carries starting at
    # position carry_num, merging subpacks as each power-of-2 group
    # completes.  Returns a list of assignment statement nodes.
    assign_list = []
    last = carry_num + carry_count
    p2f = pow2floor(last)
    if carry_num == 0:
        # Nothing pending: zero a whole power-of-2 subpack with a single
        # assignment to the combined temp (e.g. __c3_0).
        assign_list.append(make_assign('%s%i_0' % (temp_pfx, p2f-1), make_zero(pack_fw)))
        carry_num = p2f
        carry_count -= p2f
    else:
        # A partial pack ending at carry_num - 1 is pending.  Shift it left
        # (zero-filling the low positions) and merge with the neighbouring
        # subpack until the pow2floor boundary is reached.
        low_bit = carry_num &~ (carry_num - 1)
        base = carry_num - low_bit
        # Name of the pending temp: a single carry if only one is pending,
        # otherwise a merged subpack temp.
        if low_bit == 1: pending = '%s%i' % (temp_pfx, carry_num - 1)
        else: pending = '%s%i_%i' % (temp_pfx, carry_num - 1, base)
        while base != 0 and carry_num <= p2f:
            next_bit = base &~ (base - 1)
            shift = next_bit - low_bit
            shift_result = '%s%i_%i' % (temp_pfx, carry_num - 1 + shift, base)
            # mvmd<fw>::slli<shift> shifts the pending subpack up by `shift`
            # fields, leaving zero fields below it.
            assign_list.append(make_assign(shift_result, make_call('mvmd<%i>::slli<%i>' % (pack_fw, shift), [mk_var(pending)])))
            pending2 = '%s%i_%i' % (temp_pfx, base - 1, base - next_bit)
            merge_result = '%s%i_%i' % (temp_pfx, carry_num - 1 + shift, base - next_bit)
            assign_list.append(make_assign(merge_result, make_mergeh(pack_fw * next_bit, shift_result, pending2)))
            carry_count -= shift
            carry_num += shift
            low_bit = carry_num &~ (carry_num - 1)
            base = carry_num - low_bit
            pending = merge_result
        # Final shift up to the pow2floor boundary, if any gap remains.
        shift = p2f - low_bit
        if shift != 0:
            shift_result = '%s%i_%i' % (temp_pfx, carry_num - 1 + shift, base)
            assign_list.append(make_assign(shift_result, make_call('mvmd<%i>::slli<%i>' % (pack_fw, shift), [mk_var(pending)])))
            carry_count -= shift
            carry_num += shift

    # Any remaining zero carries complete packs one position at a time.
    for i in range(carry_count):
        assign_list += gen_carry_zero_then_pack(pack_fw, carry_num + i, temp_pfx)
    return assign_list
180
181
182
183#
184# Carry Storage/Access
185#
186# Carries are stored in one or more ubitblocks as byte values.
187# For each block, the carry count is rounded up to the nearest power of 2 ceiling P,
188# so that the carry test for that block is accessible as a single value of P bytes.
189# Packs of 1, 2, 4 or 8 carries are respectively represented
190# as one or more _8, _16, _32 or _64 values.  (Members of ubitblock union.)
191#
192#
193# Allocation phase determines the ubitblock_no and count for each block.
194
195#  carry-in access is a byte load  carryG[packno]._8[offset]
196#  carryout store is to a local pack var until we get to the final byte of a pack
197#
198#  if-test: let P be pack_size in {1,2,4,8,...}
199#    if P <= 8, use an integer test expression cG[packno]._%i % (P * 8)[block_offset]
200#     
201#  while test similar
202#    local while decl: use a copy of carryGroup
203#    while finalize  carry combine:   round up and |= into structure
204#
def pow2ceil(n):
    """Return the smallest power of 2 that is >= n (returns 1 for n <= 1)."""
    result = 1
    while result < n:
        result += result
    return result
209
def pow2floor(n):
    """Return the largest power of 2 that is <= n (returns 0 for n < 1).

    Uses floor division (//) so the result remains an int under Python 3;
    the original `c / 2` is integral on Python 2 but yields a float on
    Python 3.  `//` behaves identically on both versions for positive ints.
    """
    c = 1
    while c <= n:
        c *= 2
    return c // 2
214   
def align(n, align_base):
    """Round n up to the next multiple of align_base.

    Uses floor division (//) so the result remains an int under Python 3;
    the original true division is integral only on Python 2.  `//` is
    equivalent on Python 2 for the integer arguments used here.
    """
    return ((n + align_base - 1) // align_base) * align_base
217
def determine_aligned_block_sizes(pack_size, cis):
  # Compute, for every block described by the carry info set `cis`, the
  # number of carry slots it occupies after alignment.  pack_size is the
  # number of carry fields per pack.  Returns {block_no: aligned_count}.
  aligned_size = {}
  for i in range(cis.block_count): aligned_size[i] = 0
  seen = []  # NOTE(review): unused; presumably leftover from an earlier version.
  for i in range(cis.block_count):
    # Work backwards to process all child blocks before the parent
    # so that the parent incorporates the updated child counts.
    b = cis.block_count - i - 1
    b_carries = 0
    op = cis.block_first_op[b]
    while op < cis.block_first_op[b] + cis.block_op_count[b]:
      sb = cis.containing_block[op]
      if sb == b:
        # An operation directly in this block consumes one carry slot,
        # unless it is an advance of amount != 1.
        if op not in cis.advance_amount.keys(): b_carries += 1
        elif cis.advance_amount[op] == 1: b_carries += 1
        op += 1
      else: 
        # A nested sub-block: align to its size (capped at pack_size) and
        # then add the sub-block's full aligned count, skipping its ops.
        align_base = aligned_size[sb]
        if align_base > pack_size: align_base = pack_size
        b_carries = align(b_carries, align_base)
        b_carries += aligned_size[sb]
        op += cis.block_op_count[sb]
 #   if cis.whileblock[b] or aligned_size[b] > pack_size:
    if b_carries > pack_size:
      aligned_size[b] = align(b_carries, pack_size)
    else:
      # Small blocks round up to a power of 2 so the block's carry test
      # can read its pack as a single value.
      aligned_size[b] = pow2ceil(b_carries)
  # Python 2 debug output -- presumably left in from development.
  print aligned_size
  return aligned_size
247 
MAX_LINE_LENGTH = 80

def BitBlock_decls_from_vars(varlist):
    """Render a C declaration string "BitBlock v1, v2, ...;" for varlist.

    Output is indented and wrapped so that no line exceeds
    MAX_LINE_LENGTH; an empty varlist yields an empty string.
    """
    if len(varlist) == 0:
        return ""
    decls = "             BitBlock"
    line_len = 10
    sep = ""
    for name in varlist:
        if line_len + len(name) + 2 <= MAX_LINE_LENGTH:
            decls += sep + " " + name
            line_len += len(sep + name) + 1
        else:
            # Line full: close it and start a fresh declaration line.
            decls += ";\n             BitBlock " + name
            line_len = 11 + len(name)
        sep = ","
    return decls + ";"
267 
def block_contains(b0, b1, parent_block_map):
    """Return True iff block b0 equals b1 or is an ancestor of b1.

    Iterates up the parent chain from b1; block 0 is the root, so the
    walk terminates there.
    """
    current = b1
    while True:
        if current == b0:
            return True
        if current == 0:
            return False
        current = parent_block_map[current]
272 
class HMCPS_CCGO(CCGO.CCGO):
    # Carry code generator object implementing the hierarchical merging
    # carry pack strategy: carries are packed into fw-bit fields of
    # ubitblocks, merged pairwise at doubling field widths (see the
    # module header comment).
    def __init__(self, fw, carryInfoSet, carryGroupVarName='carryG', temp_prefix='__c'):
        # fw: carry field width in bits; field_count: fields per 128-bit block.
        self.fw = fw
        self.field_count = 128/fw
        self.carryInfoSet = carryInfoSet
        self.carryGroupVar = carryGroupVarName
        self.temp_prefix = temp_prefix
        # Aligned carry-slot counts per block, and the number of ubitblocks
        # needed to hold the outermost block's carries.
        self.aligned_size = determine_aligned_block_sizes(self.field_count, carryInfoSet)
        self.ubitblock_count = (self.aligned_size[0] + self.field_count - 1) / self.field_count
        # alloc_map maps operation number -> allocated carry position.
        self.alloc_map = {}
        self.alloc_map[0] = 0
        # block_base maps block number -> first carry position of the block.
        self.block_base = {}
        self.allocate_ops()
        # carry_offset is used within the inner body of while loops to access local carries.
        # The calculated (ub, rp) value is reduced by this amount for the local carry group(s).
        self.carry_offset = 0

    def allocate_ops(self):
        # Assign each carry-generating operation a carry position
        # (alloc_map) and record each block's base position (block_base),
        # inserting alignment padding on block entry and exit.
        carry_count = 0
        for op in range(self.carryInfoSet.operation_count):
            b = self.carryInfoSet.containing_block[op]
            if op != 0:
                # If we've just left a block, ensure that we are aligned.
                b_last = self.carryInfoSet.containing_block[op-1]
                if not block_contains(b_last, b, self.carryInfoSet.parent_block):
                    # find the max-sized block just exited.
                    while not block_contains(self.carryInfoSet.parent_block[b_last], b, self.carryInfoSet.parent_block):
                        b_last = self.carryInfoSet.parent_block[b_last]
                    align_base = self.aligned_size[b_last]
                    if align_base > self.field_count: align_base = self.field_count
                    carry_count = align(carry_count, align_base)
            if self.carryInfoSet.block_first_op[b] == op:
                # If we're just entering a block, ensure that we are aligned.
                align_base = self.aligned_size[b]
                if align_base > self.field_count: align_base = self.field_count
                carry_count = align(carry_count, align_base)
                self.block_base[b] = carry_count
            # Each operation consumes one slot, except advances of amount != 1.
            if op not in self.carryInfoSet.advance_amount.keys():
                self.alloc_map[op] = carry_count
                carry_count += 1
            elif self.carryInfoSet.advance_amount[op] == 1:
                self.alloc_map[op] = carry_count
                carry_count += 1
        # When processing the last operation, make sure that the "next" operation
        # appears to start a new pack.
        self.alloc_map[self.carryInfoSet.operation_count] = align(carry_count, self.field_count)
        # Python 2 debug output -- presumably left in from development.
        print self.alloc_map

    def GenerateCarryDecls(self):
        # Declare the carry group as an array of ubitblocks.
        return "  ubitblock %s [%i];\n" % (self.carryGroupVar, self.ubitblock_count)
    def GenerateInitializations(self):
        # Zero every ubitblock, then set to 1 any top-level carries listed
        # in init_one_list.  Returns C source text.
        v = self.carryGroupVar
        #const_0 = make_zero(self.fw)
        #inits = [make_assign(make_index_store(v, i), const_0) for i in range(0, self.ubitblock_count)]
        inits = ""
        for i in range(0, self.ubitblock_count):
            inits += "%s[%i]._128 = simd<%i>::constant<0>();\n" % (v, i, self.fw)
        for op_no in range(self.carryInfoSet.block_op_count[0]):
            if op_no in self.carryInfoSet.init_one_list:
                # Map the carry position to (ubitblock index, field index).
                posn = self.alloc_map[op_no]
                ub = posn/self.field_count
                rp = posn%self.field_count
                inits += "%s[%i]._%i[%i] = 1;\n" % (self.carryGroupVar, ub, self.fw, rp)
                #v_ub = make_index_load(self.carryGroupVar, ub)
                #v_ub_fw = make_att_load(v_ub, '_%i' % self.fw)
                #inits.append(make_assign(make_index_store(v_ub_fw, rp), ast.Num(1)))
        return inits
    def GenerateStreamFunctionDecls(self):
        # Declare the individual carry temps (__c0 ...) plus all merged
        # subpack temps (__c1_0, __c3_0, ...) used by the packing operations.
        f = self.field_count
        decls = [self.temp_prefix + repr(i) for i in range(self.field_count)]
        while f > 1:
            f = f/2
            s = self.field_count/f
            decls += [self.temp_prefix + "%i_%i" % (s*(i+1)-1, s*i) for i in range(f)]
        return BitBlock_decls_from_vars(decls)

    def GenerateCarryInAccess(self, operation_no):
        # Build an AST expression loading carryG[ub]._fw[rp] (a byte-sized
        # carry-in) and converting it to BitBlock form.
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        v_ub = make_index_load(self.carryGroupVar, ub)
        v_ub_fw = make_att_load(v_ub, '_%i' % self.fw)
        return make_call("convert", [make_index_load(v_ub_fw, rp)])
    def GenerateCarryOutStore(self, operation_no, carry_out_expr):
        # Assign carry_out_expr to its carry temp, emit any packing merges
        # it enables, zero-fill skipped positions, and store the completed
        # pack (in carry-in form) back to the carry group when a pack ends.
        block_no = self.carryInfoSet.containing_block[operation_no]
        posn = self.alloc_map[operation_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        # Only generate an actual store for the last carryout
        assigs = [make_assign(self.temp_prefix + repr(rp), carry_out_expr)]
        assigs += gen_carry_pack(self.fw, rp, self.temp_prefix)
        next_posn = self.alloc_map[operation_no + 1] - self.carry_offset
        skip = next_posn - posn - 1
        if skip > 0:
            # Positions with no carry op get zero carries to complete the pack.
            assigs += gen_multiple_carry_zero_then_pack(self.fw, rp+1, skip, self.temp_prefix)
        #print (posn, skip)
        if next_posn % self.field_count == 0:
            # Pack complete: shift carry-out bits into carry-in form
            # (srli by fw-1) and store the full 128-bit value.
            v_ub = make_index_load(self.carryGroupVar, ub)
            shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
            storable_carry_in_form = make_call(shift_op, [mk_var(self.temp_prefix + '%i_0' % (self.field_count - 1))])
            assigs.append(make_assign(make_att_store(v_ub, '_128'), storable_carry_in_form))
        return assigs
    def GenerateAdvanceInAccess(self, operation_no): pass
        # Not implemented for this strategy.
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #return mkCall(self.carryGroupVar + "." + 'get_pending64', [ast.Num(adv_index)])
    def GenerateAdvanceOutStore(self, operation_no, adv_out_expr): pass
        # Not implemented for this strategy.
        #adv_index = self.advIndex[operation_no - self.operation_offset]
        #cq_index = adv_index + self.carry_count
        #return [ast.Assign([ast.Subscript(self.CarryGroupAtt('cq'), ast.Index(ast.Num(cq_index)), ast.Store())],
                           #mkCall("bitblock::srli<64>", [adv_out_expr]))]
    def GenerateTest(self, block_no, testExpr):
        # Or this block's packed carries into testExpr: as a single integer
        # value when the pack fits in 64 bits, otherwise as a simd_or chain
        # over whole 128-bit ubitblocks.
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no]
        width = count * self.fw
        v_ub = make_index_load(self.carryGroupVar, ub)
        if width <= 64:
            t = make_index_load(make_att_load(v_ub, '_%i' % width), rp/count)
            return TestHelper_Integer_Or(testExpr, t)
        else:
            t = make_att_load(v_ub, '_128')
            for i in range(1, count/self.field_count):
                v2 = make_att_load(make_index_load(self.carryGroupVar, ub + i), '_128')
                t = make_call('simd_or', [t, v2])
            return TestHelper_Bitblock_Or(testExpr, t)
    def GenerateCarryIfTest(self, block_no, ifTest):
        return self.GenerateTest(block_no, ifTest)

    def GenerateCarryElseFinalization(self, block_no):
        # if the block consists of full carry packs, then
        # no action need be taken: the corresponding carry-in packs
        # must already be zero, or the then branch would have been taken.
        count = self.aligned_size[block_no]
        if count % self.field_count == 0: return []
        # The block has half a carry-pack or less.
        assigs = []
        posn = self.block_base[block_no] - self.carry_offset
        ub = posn / self.field_count
        rp = posn % self.field_count
        next_op = self.carryInfoSet.block_first_op[block_no] + self.carryInfoSet.block_op_count[block_no]
        end_pos = (self.alloc_map[next_op] - 1) % self.field_count
        # Python 2 debug output -- presumably left in from development.
        print rp, next_op,self.alloc_map[next_op]
        # Zero the subpack temp covering this block's carry range.
        if rp == end_pos: v = mk_var('%s%i' % (self.temp_prefix, rp))
        else: v = mk_var('%s%i_%i' % (self.temp_prefix, end_pos, rp))
        assigs = [make_assign(v, make_zero(self.fw))]
        #assigs = gen_multiple_carry_zero_then_pack(self.fw, rp, end_pos - rp + 1, self.temp_prefix)
        return assigs

    def GenerateLocalDeclare(self, block_no):
        # Declare the local "sub" carry group (for while-loop bodies) and
        # the corresponding local carry pack temps.
        if self.carryInfoSet.block_op_count[block_no] == 0: return []
        count = self.aligned_size[block_no]
        if count >= self.field_count:
            ub_count = count / self.field_count
            decls = [make_callStmt('ubitblock_declare', [mk_var('sub' + self.carryGroupVar), ast.Num(ub_count)])]
            count = self.field_count
        else: decls = []
        # Generate carry pack temps.
        temps = ["sub" + self.temp_prefix + repr(i) for i in range(count)]
        f = count
        while f > 1:
            f = f/2
            s = count/f
            temps += ["sub" + self.temp_prefix + "%i_%i" % (s*(i+1)-1, s*i) for i in range(f)]
        #return BitBlock_decls_from_vars(decls)
        return decls + [make_callStmt('BitBlock_declare', [mk_var(t)]) for t in temps]

    def GenerateCarryWhileTest(self, block_no, testExpr):
        return self.GenerateTest(block_no, testExpr)

    def EnterLocalWhileBlock(self, operation_offset):
        # Switch code generation to the local "sub" carry group; carries
        # inside the loop are addressed relative to operation_offset.
        self.carryGroupVar = "sub" + self.carryGroupVar
        self.temp_prefix = "sub" + self.temp_prefix
        self.carry_offset = self.alloc_map[operation_offset]
        # Python 2 debug output -- presumably left in from development.
        print "self.carry_offset = %i" % self.carry_offset
    def ExitLocalWhileBlock(self):
        # Restore the outer carry group names (strip the "sub" prefix).
        self.carryGroupVar = self.carryGroupVar[3:]
        self.temp_prefix = self.temp_prefix[3:]
        self.carry_offset = 0

    def GenerateCarryWhileFinalization(self, block_no):
        # Merge local ("sub") carries accumulated during a while loop back
        # into the enclosing carry group via simd_or.
        posn = self.block_base[block_no]
        ub = posn/self.field_count
        rp = posn%self.field_count
        count = self.aligned_size[block_no]
        v = self.carryGroupVar
        lv = "sub" + v
        if count < self.field_count:
            # Less than a full pack: combine the subpack temps directly.
            if count == 1:
                v0 = '%s%i' % (self.temp_prefix, rp)
                # NOTE(review): '%s%0' below looks like a malformed format
                # string (perhaps '%s0' was intended); it raises ValueError
                # if this count == 1 path is taken -- confirm.
                lv0 = '%s%0' % ("sub" + self.temp_prefix)
            else:
                v0 = '%s%i_%i' % (self.temp_prefix, rp + count - 1, rp)
                lv0 = '%s%i_0' % ("sub" + self.temp_prefix, count - 1)
            return [make_assign(v0, make_call('simd_or', [mk_var(v0), mk_var(lv0)]))]
        # One or more full ubitblocks: or each 128-bit value in place.
        n = (count+self.field_count-1)/self.field_count
        assigs = []
        for i in range(n):
            v_ub_i = make_index_load(v, ub + i)
            assigs.append(make_assign(make_att_store(v_ub_i, '_128'), make_call('simd_or', [make_att_load(v_ub_i, '_128'), make_att_load(make_index_load(lv, i), '_128')])))
        return assigs
    def GenerateStreamFunctionFinalization(self):
        # Carry-out to carry-in conversion is now performed at each pack
        # store (see GenerateCarryOutStore), so nothing remains to do here.
        if self.carryInfoSet.carry_count == 0: return []
        # Generate statements to shift all carries from carry-out form to carry-in form.
        #v = self.carryGroupVar
        #n = (self.aligned_size[0] + self.field_count - 1)/self.field_count
        #shift_op = "simd<%i>::srli<%i>" % (self.fw, self.fw-1)
        #return [make_assign(make_att_store(make_index_load(v, i), '_128'), make_call(shift_op, [make_att_load(make_index_load(v, i), '_128')])) for i in range(n)]
        #
        # Now arranging shifts with original stores.
        return []
485
Note: See TracBrowser for help on using the repository browser.