source: proto/pablo/src/compiler/backend/visitors/Pablo2CarryXFormer.java @ 2710

Last change on this file since 2710 was 2710, checked in by ksherdy, 6 years ago

General refactoring.

File size: 17.3 KB
Line 
1package compiler.backend.visitors;
2
3import java.util.ArrayList;
4import java.util.List;
5
6import ast.*;
7import compiler.ast.Accessors;
8import compiler.ast.Generators;
9import compiler.lang.carryset.CarrySet;
10import compiler.lang.carryset.CarrySet2Lang;
11import compiler.lang.idisa.SIMD;
12import compiler.lang.pablo.*;
13
14//TODO - Add while type to switch on/off carry in mode. Carry in is an internal to this visitor.
15
16public class Pablo2CarryXFormer {
17               
18        private ASTNode ASTTree;       
19       
20        private Builtins2Lang builtins2Lang;
21        private CarrySet2Lang carrySet2Lang;
22               
23    public Pablo2CarryXFormer(ASTNode node, Builtins2Lang builtins2Lang, CarrySet2Lang carrySet2Lang) {
24        this.ASTTree = node; 
25        this.builtins2Lang = builtins2Lang;
26        this.carrySet2Lang = carrySet2Lang;
27    }
28
29    public void XForm(boolean isFinalBlock/* boolean ci, boolean co*/) {
30                XFormer visitor = new XFormer(isFinalBlock/*ci , co*/);
31                ASTTree.accept(visitor);
32    }                   
33       
34        private class XFormer extends VoidVisitor.Default {
35
36                private boolean finalBlockMode; 
37                private boolean ciMode;
38                //private boolean coMode;
39                               
40                private int currentCarry;
41                private int currentAdvN;
42                //private int lastStmtCarries;
43               
44                XFormer(boolean isFinalBlock) { // TODO - ciMode parameter
45                        this.finalBlockMode = isFinalBlock;
46                        this.ciMode = true;
47                        //this.coMode = coMode;
48                        this.currentCarry = 0;
49                        this.currentAdvN = 0;
50                        //this.lastStmtCarries = 0;
51                }
52               
53                //              def xfrm_fndef(self, fndef):
54                //          self.current_carry = 0
55                //          self.current_adv_n = 0
56                //          carry_count = CarryCounter().count(fndef)
57                //          if carry_count == 0: return fndef
58                //          self.generic_visit(fndef)
59                //      #   
60                //      #    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
61                //          return fndef
62                public void visitEnter(FuncDefNode node) { // TODO - CarryIf duplicates
63                        this.currentCarry = 0;
64                        this.currentAdvN = 0;
65                        //this.lastStmtCarries = 0;                     
66                }               
67               
68                public void visitLeave(FuncCallNode node) {
69                       
70                        ASTNode replacementNode;
71                       
72                        ASTNode carryCall;
73                        IntegerConstantNode currentCarry = Generators.makeIntegerConstantNode(this.currentCarry, node.getToken());
74                       
75                        ASTNode advNCall;
76                        IntegerConstantNode currentAdvN = Generators.makeIntegerConstantNode(this.currentAdvN, node.getToken());
77
78//                  if self.carryin == "_ci":
79//              carry_args = [mkCall(self.carryvar.id + "." + 'get_carry_in', [ast.Num(self.current_carry)]), ast.Num(self.current_carry)]
80//              adv_n_args = [mkCall(self.carryvar.id + "." + 'get_pending64', [ast.Num(self.current_adv_n)]), ast.Num(self.current_adv_n)]
81//          else:
82//              carry_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_carry)]
83//              adv_n_args = [mkCall('simd<1>::constant<0>', []), ast.Num(self.current_adv_n)]
84                       
85                        if(ciMode) {
86                                carryCall = Generators.makeFuncCallNode(
87                                                new String [] {CarrySet.CarryQ_IDENTIFIER, carrySet2Lang.getCode(CarrySet.GETCARRYIN)}, 
88                                                node.getToken(),
89                                                new ASTNode [] {currentCarry});
90                               
91                                advNCall = Generators.makeFuncCallNode(
92                                                new String [] {CarrySet.CarryQ_IDENTIFIER, carrySet2Lang.getCode(CarrySet.GETPENDING64)}, 
93                                                node.getToken(),
94                                                new ASTNode [] {currentAdvN});
95                        } else {
96                                carryCall = Generators.makeIntegerConstantNode(0, node.getToken());
97                                advNCall = Generators.makeIntegerConstantNode(0, node.getToken());
98                        }
99                                               
100        //                  if is_BuiltIn_Call(callnode, 'Advance', 1):         
101        //                    #CARRYSET
102        //                    rtn = self.carryvar.id + "." + "BitBlock_advance_ci_co"
103        //                    c = mkCall(rtn, callnode.args + carry_args)
104        //                    self.current_carry += 1
105        //                    return c                 
106                        if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ADVANCE.pabloName(), Builtins.ADVANCE.argCount())) {         
107                                if(finalBlockMode) {
108                                       
109                                }
110                               
111                                replaceFuncCallNode(node, 
112                                                CarrySet.CarryQ_IDENTIFIER, 
113                                                builtins2Lang.getCode(Builtins.ADVANCE), 
114                                                carryCall, 
115                                                currentCarry);
116                               
117                                this.currentCarry += 1;
118                        }
119       
120
121        //            #CARRYSET
122        //            rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
123        //            c = mkCall(rtn, callnode.args + carry_args)
124        //            self.current_carry += 1
125        //            return c         
126                       
127                                       
128                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.SCANTHRU.pabloName(), Builtins.SCANTHRU.argCount())) {                           
129                                if(finalBlockMode) {
130                                       
131                                }
132                               
133                                replaceFuncCallNode(node, 
134                                                CarrySet.CarryQ_IDENTIFIER, 
135                                                builtins2Lang.getCode(Builtins.SCANTHRU), 
136                                                carryCall, 
137                                                currentCarry);
138                                this.currentCarry += 1;
139                        }
140                               
141        //                  elif is_BuiltIn_Call(callnode, 'AdvanceThenScanThru', 2):
142        //                  #CARRYSET
143        //                  rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
144        //                  c = mkCall(rtn, callnode.args + carry_args)
145        //                  self.current_carry += 1
146        //                  return c                   
147                       
148                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ADVANCETHENSCANTHRU.pabloName(), Builtins.ADVANCETHENSCANTHRU.argCount())) {
149                                if(finalBlockMode) {
150                                                                       
151                                }
152                               
153                                replaceFuncCallNode(node, 
154                                                CarrySet.CarryQ_IDENTIFIER, 
155                                                builtins2Lang.getCode(Builtins.ADVANCETHENSCANTHRU), 
156                                                carryCall, 
157                                                currentCarry);
158                                this.currentCarry += 1;
159                        }               
160       
161        //            elif is_BuiltIn_Call(callnode, 'SpanUpTo', 2):
162        //            #CARRYSET
163        //            rtn = self.carryvar.id + "." + "BitBlock_span_upto"
164        //            c = mkCall(rtn, callnode.args + carry_args)
165        //            self.current_carry += 1
166        //            return c         
167                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.SPANUPTO.pabloName(), Builtins.SPANUPTO.argCount())) {
168                                if(finalBlockMode) {
169                                               
170                                }
171       
172                                replaceFuncCallNode(node, 
173                                                CarrySet.CarryQ_IDENTIFIER, 
174                                                builtins2Lang.getCode(Builtins.SPANUPTO), 
175                                                carryCall, 
176                                                currentCarry);                         
177                                this.currentCarry += 1;
178                        }               
179                       
180        //              elif is_BuiltIn_Call(callnode, 'AdvanceThenScanTo', 2):
181        //      #CARRYSET
182        //      rtn = self.carryvar.id + "." + "BitBlock_advance_then_scanthru"
183        //      if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
184        //      else: scanclass = mkCall('simd_not', [callnode.args[1]])
185        //      c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
186        //      self.current_carry += 1
187        //          return c                   
188                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ADVANCETHENSCANTO.pabloName(), Builtins.ADVANCETHENSCANTO.argCount())) {
189                                if(finalBlockMode) {
190                                       
191                                }
192                               
193                                replaceFuncCallNode(node, 
194                                                CarrySet.CarryQ_IDENTIFIER, 
195                                                builtins2Lang.getCode(Builtins.ADVANCETHENSCANTO), 
196                                                carryCall, 
197                                                currentCarry);
198                                this.currentCarry += 1;
199                        }               
200                       
201        //          elif is_BuiltIn_Call(callnode, 'InclusiveSpan', 2):
202        //            #CARRYSET
203        //      #      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
204        //      #      c = mkCall('simd_or', [mkCall(rtn, callnode.args + carry_args), callnode.args[1]])
205        //            rtn = self.carryvar.id + "." + "BitBlock_inclusive_span"
206        //            c = mkCall(rtn, callnode.args + carry_args)
207        //            self.current_carry += 1
208        //            return c
209       
210                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.INCLUSIVESPAN.pabloName(), Builtins.INCLUSIVESPAN.argCount())) {
211                                if(finalBlockMode) {
212                                       
213                                }
214
215                                replaceFuncCallNode(node, 
216                                                CarrySet.CarryQ_IDENTIFIER, 
217                                                builtins2Lang.getCode(Builtins.INCLUSIVESPAN), 
218                                                carryCall, 
219                                                currentCarry);
220                                this.currentCarry += 1;
221                        }                               
222                       
223        //          elif is_BuiltIn_Call(callnode, 'ExclusiveSpan', 2):
224        //            #CARRYSET
225        //      #      rtn = self.carryvar.id + "." + "BitBlock_span_upto"
226        //      #      c = mkCall('simd_andc', [mkCall(rtn, callnode.args + carry_args), callnode.args[0]])
227        //            rtn = self.carryvar.id + "." + "BitBlock_exclusive_span"
228        //            c = mkCall(rtn, callnode.args + carry_args)
229        //            self.current_carry += 1
230        //            return c
231       
232                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.EXCLUSIVESPAN.pabloName(), Builtins.EXCLUSIVESPAN.argCount())) {
233                                if(finalBlockMode) {
234                                       
235                                }
236
237                                replaceFuncCallNode(node, 
238                                                CarrySet.CarryQ_IDENTIFIER, 
239                                                builtins2Lang.getCode(Builtins.EXCLUSIVESPAN), 
240                                                carryCall, 
241                                                currentCarry);
242                                this.currentCarry += 1;
243                        }                                               
244                       
245        //          elif is_BuiltIn_Call(callnode, 'ScanTo', 2):
246        //            # Modified Oct. 9, 2011 to directly use BitBlock_scanthru, eliminating duplication
247        //            # in having a separate BitBlock_scanto routine.
248        //            #CARRYSET
249        //            rtn = self.carryvar.id + "." + "BitBlock_scanthru_ci_co"
250        //            if self.carryout == "":  scanclass = mkCall('simd_andc', [ast.Name('EOF_mask', ast.Load()), callnode.args[1]])
251        //            else: scanclass = mkCall('simd_not', [callnode.args[1]])
252        //            c = mkCall(rtn, [callnode.args[0], scanclass] + carry_args)
253        //            self.current_carry += 1
254        //            return c
255       
256                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.SCANTO.pabloName(), Builtins.SCANTO.argCount())) {
257                                if(finalBlockMode) {
258                                       
259                                }
260
261                                replaceFuncCallNode(node, 
262                                                CarrySet.CarryQ_IDENTIFIER, 
263                                                builtins2Lang.getCode(Builtins.SCANTO), 
264                                                carryCall, 
265                                                currentCarry);
266                                this.currentCarry += 1;
267                        }                                                               
268                       
269        //          elif is_BuiltIn_Call(callnode, 'ScanToFirst', 1):
270        //            #CARRYSET
271        //            rtn = self.carryvar.id + "." + "BitBlock_scantofirst"
272        //            #if self.carryout == "":  carry_args = [ast.Name('EOF_mask', ast.Load())] + carry_args
273        //            c = mkCall(rtn, callnode.args + carry_args)
274        //            self.current_carry += 1
275        //            return c
276                       
277
278       
279                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.SCANTOFIRST.pabloName(), Builtins.SCANTOFIRST.argCount())) {
280                                if(finalBlockMode) {
281                                       
282                                }
283
284                                replaceFuncCallNode(node, 
285                                                CarrySet.CarryQ_IDENTIFIER, 
286                                                builtins2Lang.getCode(Builtins.SCANTOFIRST), 
287                                                carryCall, 
288                                                currentCarry);
289                                this.currentCarry += 1;
290                        }                                                                               
291       
292        //          elif is_BuiltIn_Call(callnode, 'Advance32', 1):     
293        //            #CARRYSET
294        //            rtn = self.carryvar.id + "." + "BitBlock_advance_n_<32>"
295        //            c = mkCall(rtn, callnode.args + adv_n_args)
296        //            self.current_adv_n += 1
297        //            return c
298                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ADVANCE32.pabloName(), Builtins.ADVANCE32.argCount())) {
299                                //replaceFuncCallNode(node, CarryQ.CarryQ_PACKAGE_NAME, BuiltinOperations.ADVANCE32.cPPCode(), carryCall, currentAdvN);
300                                // TODO ==> Verify implementation.
301                                if(finalBlockMode) {
302                                       
303                                }
304                               
305                                replaceFuncCallNode(node, 
306                                                CarrySet.CarryQ_IDENTIFIER, 
307                                                builtins2Lang.getCode(Builtins.ADVANCE32), 
308                                                advNCall, 
309                                                currentAdvN);
310                                this.currentAdvN += 1;
311                        }                       
312                       
313        //          if is_BuiltIn_Call(callnode, 'Advance', 2):         
314        //            #CARRYSET
315        //            rtn = self.carryvar.id + "." + "BitBlock_advance_n_<%i>" % callnode.args[1].n
316        //            c = mkCall(rtn, [callnode.args[0]] + adv_n_args)
317        //            self.current_adv_n += 1
318        //            return c         
319                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ADVANCEN.pabloName(), Builtins.ADVANCEN.argCount())) {
320                                //replaceFuncCallNode(node, CarryQ.CarryQ_PACKAGE_NAME, BuiltinOperations.ADVANCEN.cPPCode(), carryCall, currentAdvN);                                 
321                                // TODO - Verify implementation.
322                                if(finalBlockMode) {
323                                       
324                                }
325                                                               
326                                replaceFuncCallNode(node, 
327                                                CarrySet.CarryQ_IDENTIFIER, 
328                                                builtins2Lang.getCode(Builtins.ADVANCE32), 
329                                                advNCall, 
330                                                currentAdvN);           
331                                this.currentAdvN += 1;                 
332                        }                                       
333               
334                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.ATEOF.pabloName(), Builtins.ATEOF.argCount())) {
335                                 
336                                if(finalBlockMode) { 
337                                        ASTNode arg0 = Accessors.funcCallArg(node, 0);
338                                        replacementNode = Generators.makeSIMDAndCEOFMaskFuncCall(arg0, builtins2Lang);
339                                } else {
340                                        replacementNode = Generators.makeFuncCallNode(SIMD.CONSTANT.idisaConstantName(1, 0), node.getToken());                                 
341                                }
342                                node.updateSelf(replacementNode);
343                        }                                       
344       
345        //          elif is_BuiltIn_Call(callnode, 'inFile', 1):
346        //            if self.carryout != "": 
347        //              # Non final block: inFile(x) = x.
348        //              return callnode.args[0]
349        //            else: return mkCall('simd_and', [callnode.args[0], ast.Name('EOF_mask', ast.Load())])
350
351                        else if (BuiltinsUtil.isCall(node, BuiltinsUtil.BUILTIN_PACKAGE_NAME, Builtins.INFILE.pabloName(), Builtins.INFILE.argCount())) {
352                               
353                                if(finalBlockMode) {                                   
354                                        ASTNode arg0 = Accessors.funcCallArg(node, 0);
355                                        replacementNode = Generators.makeSIMDAndEOFMaskFuncCall(arg0, builtins2Lang);
356                                } else {
357                                        replacementNode = Accessors.funcCallArg(node,0);
358                                }
359                                node.updateSelf(replacementNode);
360                        }                       
361                       
362//                      Deprecated - New bitblock iterators replace StreamScan.
363//                     
364//              elif is_BuiltIn_Call(callnode, 'StreamScan', 2):
365//                      rtn = "StreamScan"           
366//                      c = mkCall(rtn, [ast.Name('(ScanBlock *) &' + callnode.args[0].id, ast.Load()),
367//                                                 ast.Name('sizeof(BitBlock)/sizeof(ScanBlock)', ast.Load()),
368//                                                 ast.Name(callnode.args[1].id, ast.Load())])
369//              return c                                       
370//
371//                      else if (Builtins.isCall(node, Builtins.BUILTIN_PACKAGE_NAME, NoCarry.STREAMSCAN.pabloName(), NoCarry.STREAMSCAN.argCount())) {
372//                              replacementNode = Generators.makeFuncCallNode(NoCarry.STREAMSCAN.cPPCode(), node.getToken());
373//                             
374//                              ASTNode arg0 = Generators.makeIdentifierNode("(ScanBlock *) &" + Accessors.identifierLexeme(Accessors.FuncCallArg(node, 0)), node.getToken());
375//                              ASTNode arg1 = Generators.makeIdentifierNode("sizeof(BitBlock)/sizeof(ScanBlock)", node.getToken());
376//                              ASTNode arg2 = Accessors.FuncCallArg(node, 1);
377//                             
378//                              replacementNode.appendChild(arg0);
379//                              replacementNode.appendChild(arg1);
380//                              replacementNode.appendChild(arg2);
381//                      }
382                       
383//                  else:
384//                    #dump_Call(callnode)
385//                    return callnode
386                        else {
387                                // do nothing // TODO - Dump: allow Func calls to pass through the compiler.
388                        }
389                       
390                }
391                                       
392                // Helpers             
393                private void replaceFuncCallNode(FuncCallNode node, String targetPackage, String targetName,
394                                ASTNode call, IntegerConstantNode carry) {
395                        FuncCallNode replacementNode;
396
397                        List<ASTNode> args = new ArrayList<ASTNode>();
398                                               
399                        for(ASTNode arg : Accessors.funcCallArgsListNode(node).getChildren()) {
400                                args.add(arg);
401                        }       
402                        args.add(call);
403                        args.add(carry);
404                       
405                        replacementNode = Generators.makeFuncCallNode(
406                                        new String [] {targetPackage, targetName}, 
407                                        node.getToken(),
408                                        args);
409                       
410                        node.updateSelf(replacementNode);
411                }
412        }
413}
414
415//class CarryIntro(ast.NodeXFormer):
416
417//        def generic_xfrm(self, node):
418//          self.current_carry = 0
419//          self.current_adv_n = 0
420//          carry_count = CarryCounter().count(node)
421//          adv_n_count = adv_nCounter().count(node)
422//          if carry_count == 0 and adv_n_count == 0: return node
423//          self.generic_visit(node)
424//          return node
425
426//def visit_If(self, ifNode):
427//carry_base = self.current_carry
428//carries = CarryCounter().count(ifNode)
429//assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
430//self.generic_visit(ifNode)
431//if carries == 0 or self.carryin == "": return ifNode
432//#CARRYSET
433//carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
434//new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
435//new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
436//return ast.If(new_test, ifNode.body, new_else_part)
437
438//        def visit_While(self, whileNode):
439//          if self.carryout == '':
440//            whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
441//          carry_base = self.current_carry
442//          assert adv_nCounter().count(whileNode) == 0, "Advance(x,n) within while: illegal\n"
443//          carries = CarryCounter().count(whileNode)
444//          #CARRYSET
445//          if carries == 0: return whileNode
446//          carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
447//          local_carryvar = 'sub' + self.carryvar.id
448//          inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
449//          self.generic_visit(whileNode)
450//          local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
451//          inner_while.body.insert(0, local_carry_decl)
452//          final_combine = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine', ast.Load()), [ast.Attribute(ast.Name(local_carryvar, ast.Load()), 'cq', ast.Load()),ast.Num(carry_base), ast.Num(carries)])
453//          inner_while.body.append(final_combine)
454//          #CARRYSET
455//          if self.carryin == '': new_test = whileNode.test
456//          else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
457//          else_part = [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]   
458//          return ast.If(new_test, whileNode.body + [inner_while], else_part)
Note: See TracBrowser for help on using the repository browser.