source: proto/Compiler/CCGO_While.py @ 3695

Last change on this file since 3695 was 3444, checked in by cameron, 6 years ago

Change hard constant 127 to BLOCK_SIZE-1

File size: 9.8 KB
Line 
1#
2# CCGO_While.py
3#
4# Copyright 2013, Robert D. Cameron
5# All rights reserved.
6#
7# Carry Code Generator Objects (CCGOs) for General While Processing
8#
9# General While Processing refers to the implementation of while
10# loops with one carry per iteration for each operation contained
11# within the while loop.
12#
13import ast, mkast
14import CCGO
15class Unimplemented(Exception): pass
16
17#
18#  The testCCGO is a legacy class designed to duplicate the previous Pablo
19#  functionality of introducing a set of abstract macros into generated
20#  code (CarryArray, CarryTest, CarryDequeueEnqueue, CarryAdjust, etc.)
21#
22class CCGO_While1(CCGO.CCGO):
23    def __init__(self, BLOCK_SIZE, carryInfoSet):
24        self.BLOCK_SIZE = BLOCK_SIZE
25        self.carryInfoSet = carryInfoSet
26
27
28    def GenerateCarryDecls(self):
29        decls = []
30        carry_counter = 0
31        adv_n_counter = 0
32        for op_no in range(self.carryInfoSet.block_op_count[0]):
33          blk = self.carryInfoSet.containing_block[op_no]
34          dpth = self.carryInfoSet.whileDepth(blk)
35          if dpth == 0:
36            if not op_no in self.carryInfoSet.advance_amount.keys(): decls.append("Carry%i" % op_no)
37            elif self.carryInfoSet.advance_amount[op_no] == 1: decls.append("Carry%i" % op_no)
38            else: decls.append("Advance%i" % op_no)
39          elif dpth == 1:
40            if not op_no in self.carryInfoSet.advance_amount.keys(): decls.append("CarryVector%i" % op_no)
41            elif self.carryInfoSet.advance_amount[op_no] == 1: decls.append("CarryVector%i" % op_no)
42            else: raise Unimplemented("Advance n inside while unimplemented.")
43          else: raise Unimplemented("Nested whiles unimplemented.")
44        for b in range(1, self.carryInfoSet.block_count):
45          dpth = self.carryInfoSet.whileDepth(b)
46          if dpth == 0:decls.append("Test%i" % b)
47          else: decls.append("TestVector%i" % b)
48        return "".join(["BitBlock %s;\n" % d  for d in decls])
49    def GenerateInitializations(self):
50        inits = []
51        for op_no in range(self.carryInfoSet.block_op_count[0]):
52          blk = self.carryInfoSet.containing_block[op_no]
53          dpth = self.carryInfoSet.whileDepth(blk)
54          if dpth == 0:
55            if op_no in self.carryInfoSet.init_one_list: inits.append("Carry%i = simd_const_1" % op_no)       
56            elif not op_no in self.carryInfoSet.advance_amount.keys(): inits.append("Carry%i = simd<1>::constant<0>()" % op_no)
57            elif self.carryInfoSet.advance_amount[op_no] == 1: inits.append("Carry%i = simd<1>::constant<0>()" % op_no)
58            else: inits.append("Advance%i = simd<1>::constant<0>()" % op_no)
59          elif dpth == 1:
60            if not op_no in self.carryInfoSet.advance_amount.keys(): inits.append("CarryVector%i = simd<1>::constant<0>()" % op_no)
61            elif self.carryInfoSet.advance_amount[op_no] == 1: inits.append("CarryVector%i = simd<1>::constant<0>()" % op_no)
62            else: raise Unimplemented("Advance n inside while unimplemented.")
63          else: raise Unimplemented("Nested whiles unimplemented.")
64        for b in range(1, self.carryInfoSet.block_count):
65          dpth = self.carryInfoSet.whileDepth(b)
66          if dpth == 0: inits.append("Test%i = simd<1>::constant<0>()" % b)
67          else: inits.append("TestVector%i = simd<1>::constant<0>()" % b)
68        return "".join(["%s;\n" % i  for i in inits])
69
70    def GenerateStreamFunctionDecls(self):
71        ctrs = []
72        shfts = []
73        for b in range(1, self.carryInfoSet.block_count):
74          if self.carryInfoSet.whileblock[b]:
75            ctrs.append("iterCount%i" % b)
76            shfts.append("shift%i" % b)
77        return "".join(["int %s = 0;\n" % d  for d in ctrs] + ["BitBlock %s;\n" % d  for d in shfts]) 
78
79
80    def GenerateCarryInAccess(self, op_no):
81          blk = self.carryInfoSet.containing_block[op_no]
82          dpth = self.carryInfoSet.whileDepth(blk)
83          if dpth == 0:
84            #return mkast.call('simd_and', [mkast.var("Carry%i" % op_no), mkast.var("simd_const_1")])
85            return mkast.var("Carry%i" % op_no)
86          elif dpth == 1:
87            return mkast.call('simd_and', [mkast.var("CarryVector%i" % op_no), mkast.var("simd_const_1")])
88          else: raise Unimplemented("Nested whiles unimplemented.")
89
90    def GenerateCarryOutStore(self, op_no, carry_out_expr):
91          blk = self.carryInfoSet.containing_block[op_no]
92          dpth = self.carryInfoSet.whileDepth(blk)
93          if dpth == 0:
94            shift127 = mkast.call("bitblock::srli<BLOCK_SIZE-1>", [carry_out_expr])
95            return mkast.assign(mkast.var("Carry%i" % op_no), shift127)
96          elif dpth == 1:
97            v = mkast.var("CarryVector%i" % op_no)           
98            shift1 = mkast.call("simd<64>::srli<1>", [v])
99            carry1 = mkast.call("simd_and", [carry_out_expr, mkast.var("simd_sign_bit")])
100            return mkast.assign(v, mkast.call("simd_or", [carry1, shift1]))
101          else: raise Unimplemented("Nested whiles unimplemented.")
102
103    def GenerateAdvanceInAccess(self, op_no):
104        return mkast.var("Advance%i" % op_no)
105    def GenerateAdvanceOutStore(self, op_no, adv_out_expr):
106        return [ast.Assign([mkast.var("Advance%i" % op_no, mode=ast.Store())], 
107                           mkast.call("bitblock::srli<64>", [adv_out_expr]))]
108
109    def GenerateCarryIfTest(self, block_no, ifTest): 
110        if self.carryInfoSet.whileDepth(block_no) == 0: v = mkast.var("Test%i" % block_no)
111        else: v = mkast.call("Dequeue_bit", [mkast.var("TestVector%i" % block_no)])
112        return mkast.TestHelper_Bitblock_Or(ifTest, v)
113
114
115    def GenerateCarryThenFinalization(self, block_no): 
116          if self.carryInfoSet.whileDepth(block_no) == 0:
117              return [mkast.assign([mkast.var("Test%i" % block_no)], self.GenerateTestExpression(block_no))]
118          else: return []
119
120    def GenerateCarryElseFinalization(self, block_no): 
121          dpth = self.carryInfoSet.whileDepth(block_no)
122          if dpth == 0: return []
123          else: 
124             op1 = self.carryInfoSet.block_first_op[block_no]
125             op_count = self.carryInfoSet.block_op_count[block_no]
126             shift_stmts = []
127             for i in range(op_count):
128               v = mkast.var("CarryVector%i" % (op1 + i))
129               shift_stmts.append(mkast.assign([v], mkast.call("simd<64>::srli<1>", [v])))
130             return shift_stmts
131
132    def GenerateLocalDeclare(self, block_no): 
133       return [mkast.assign([mkast.var("iterCount%i" % block_no)], 
134                            ast.BinOp(mkast.var("iterCount%i" % block_no), ast.Add(), ast.Num(1))),
135               mkast.assign([mkast.var("TestVector%i" % block_no)], 
136                            mkast.call("simd<64>::srli<1>", [mkast.var("TestVector%i" % block_no)]))]
137    def GenerateCarryWhileTest(self, block_no, testExpr): 
138        return mkast.TestHelper_Bitblock_Or(testExpr, mkast.var("TestVector%i" % block_no))
139
140    def GenerateCarryWhileFinalization(self, block_no): 
141          op1 = self.carryInfoSet.block_first_op[block_no]
142          op_count = self.carryInfoSet.block_op_count[block_no]
143          # Prepare the shift amount by which all carry vectors must be shifted
144          # to move accumulated carries from the high end to the low end of the
145          # vector.
146          shiftv = mkast.var("shift%i" % block_no)
147          stmts = [mkast.assign([shiftv], 
148                                 mkast.call("convert", [ast.BinOp(ast.Num(64), ast.Sub(), mkast.var("iterCount%i" % block_no))]))]
149          #
150          # Shift all contained vectors in the while block and sub-blocks.
151          for i in range(op_count):
152             op = op1 + i
153             v = mkast.var("CarryVector%i" % op)
154             stmts.append(mkast.assign([v], mkast.call("bitblock::srli<64>", [mkast.call("simd<64>::srl", [v, shiftv])])))
155          #
156          # For this and all contained sub-blocks, compute the test vectors.
157          stmts += self.GenerateContainedTestVectors(block_no)
158          return stmts
159
160    def GenerateContainedTestVectors(self, block_no):
161          stmts = []
162          for c in self.carryInfoSet.children[block_no]:
163            stmts += self.GenerateContainedTestVectors(c)
164          stmts.append(mkast.assign([mkast.var("TestVector%i" % block_no)], self.GenerateTestExpression(block_no)))
165          return stmts
166
167    def GenerateTestExpression(self, block_no):
168          dpth = self.carryInfoSet.whileDepth(block_no)
169          op1 = self.carryInfoSet.block_first_op[block_no]
170          op_count = self.carryInfoSet.block_op_count[block_no]
171          test_list = []
172          children_seen = []
173          b_last = block_no
174          for i in range(op_count):
175             op = op1 + i
176             b = self.carryInfoSet.containing_block[op]
177             if b == block_no:
178               if dpth == 0: 
179                 if op not in self.carryInfoSet.advance_amount.keys():
180                   test_list.append(mkast.var("Carry%i" % op))
181                 elif self.carryInfoSet.advance_amount[op] == 1:
182                   test_list.append(mkast.var("Carry%i" % op))
183                 else: test_list.append(mkast.var("Advance%i" % op))
184               else: test_list.append(mkast.var("CarryVector%i" % op))
185             elif b in self.carryInfoSet.children[block_no] and not b in children_seen:
186               if dpth == 0 and not self.carryInfoSet.whileblock[b]: test_list.append(mkast.var("Test%i" % b))
187               else: test_list.append(mkast.var("TestVector%i" % b))
188               children_seen.append(b)
189          if len(test_list) == 0: return ast.Num(0)
190          while len(test_list) > 1:
191             next_list = [mkast.call("simd_or", [test_list[2*i], test_list[2*i+1]]) for i in range(len(test_list)/2)]
192             if len(test_list) % 2 == 1: next_list.append(test_list[-1])
193             test_list = next_list
194          return test_list[0]
195
196    def GenerateStreamFunctionFinalization(self): return []
197
198    def GenerateTestAll(self, instance_name): 
199      # Needs to be modified for external access
200      #return self.GenerateTestExpression(0)
201      return ast.Num(1)
202
203
204
205
206
207
208
209
210
211
212
213
214
215
Note: See TracBrowser for help on using the repository browser.