source: proto/pablo/src/compiler/backend/visitors/CarryIntroXFormer.java @ 2710

Last change on this file since 2710 was 2710, checked in by ksherdy, 6 years ago

General refactoring.

File size: 8.0 KB
Line 
1package compiler.backend.visitors;
2
3import ast.*;
4import tokens.LextantToken;
5import lexicalAnalyzer.Lextant;
6import compiler.ast.Accessors;
7import compiler.ast.Generators;
8import compiler.lang.carryset.CarrySet;
9import compiler.lang.carryset.CarrySet2Lang;
10import compiler.lang.pablo.BuiltinsUtil;
11
12// TODO - Add while type to switch on/off carry in mode. Carry in is an internal to this visitor.
13
14public class CarryIntroXFormer {
15               
16        private ASTNode ASTTree; 
17        private CarrySet2Lang carrySet2Lang;
18               
19    public CarryIntroXFormer(ASTNode node, CarrySet2Lang carrySet2Lang) {
20        this.ASTTree = node; 
21        this.carrySet2Lang = carrySet2Lang;
22    }
23
24    public void XForm(/*boolean ci , boolean co*/) { 
25                XFormer visitor = new XFormer(/*ci , co*/);
26                ASTTree.accept(visitor);
27    }                   
28       
29        private class XFormer extends VoidVisitor.Default {
30
31                //private final String ciSuffix = "_ci";
32                //private final String coSuffix = "_co";
33                //private String [] pendindCarryQName = {CarryQ.CarryQ_PACKAGE_NAME, CarryQ.GETCARRYIN.cPPCode()};
34                //private String [] pending64QName = {CarryQ.CarryQ_PACKAGE_NAME, CarryQ.GETPENDING64.cPPCode()};
35               
36                private boolean isFinalBlock;
37               
38                private boolean ciMode;
39                //private boolean coMode;
40                               
41                private int currentCarry;
42                //private int currentAdvN;
43                //private int lastStmtCarries;
44               
45                XFormer(/*boolean ciMode , boolean coMode */) {
46                        this.ciMode = true;
47                        //this.coMode = coMode;
48                        this.currentCarry = 0;
49                        //this.currentAdvN = 0;
50                        //this.lastStmtCarries = 0;
51                }
52               
53                //              def xfrm_fndef(self, fndef):
54                //          self.current_carry = 0
55                //          self.current_adv_n = 0
56                //          carry_count = CarryCounter().count(fndef)
57                //          if carry_count == 0: return fndef
58                //          self.generic_visit(fndef)
59                //      #   
60                //      #    fndef.body.insert(0, mkCallStmt('CarryDeclare', [self.carryvar, ast.Num(carry_count)]))
61                //          return fndef
62                public void visitEnter(FuncDefNode node) { 
63                        this.currentCarry = 0;
64                        //this.currentAdvN = 0;
65                        //this.lastStmtCarries = 0;
66                       
67                        int carryCount = (new CarryCounterVisitor(node)).count();
68                        if(carryCount > 0) { 
69                       
70                                IntegerConstantNode carryCountNode =  Generators.makeIntegerConstantNode(carryCount, node.getToken());
71                               
72                                FuncCallNode carryAdjustFuncCall = (FuncCallNode) Generators.makeFuncCallNode(
73                                                new String [] {CarrySet.CarryQ_IDENTIFIER, carrySet2Lang.getCode(CarrySet.CARRYQADJUST)},
74                                                node.getToken(),
75                                                new ASTNode [] {carryCountNode});
76                               
77                                BlockStmtNode blockStmtNode = Accessors.blockStmtNode(node);
78                                blockStmtNode.appendChild(carryAdjustFuncCall);
79                        }
80                }               
81               
82                public void visitLeave(FuncCallNode node) {
83                        if(BuiltinsUtil.isCarryOne(node)) {
84                                this.currentCarry += 1;
85                        }                                       
86                }
87               
88                //def visit_If(self, ifNode):
89                //carry_base = self.current_carry
90                //carries = CarryCounter().count(ifNode)
91                //assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
92                //self.generic_visit(ifNode)
93                //if carries == 0 or self.carryin == "": return ifNode
94                //#CARRYSET
95                //carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
96                //new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
97                //new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
98                //return ast.If(new_test, ifNode.body, new_else_part)
99
100                public void visitEnter(IfStmtNode node) { // Current if strategy does not test any_carry() on n bits
101                        assert (new AdvanceNCounterVisitor(node).count() == 0): "Advance(x,n) within if: illegal\n";
102                       
103                        int carryBase = this.currentCarry;
104                        int carryCount = (new CarryCounterVisitor(node)).count();
105                                               
106                        if(carryCount == 0) { 
107                                return;
108                        }
109                       
110                        if(!this.ciMode) { // while loop body
111                                return;
112                        }
113                               
114                        // if test, replace if test
115                        String lexeme = Lextant.OR.getPrimaryLexeme();
116                        LextantToken binaryOperatorToken = LextantToken.make(node.getToken().getLocation(), lexeme, Lextant.OR);
117                       
118                        ASTNode lhs = Accessors.ifTest(node);
119
120                        IntegerConstantNode carryBaseNode = Generators.makeIntegerConstantNode(carryBase, node.getToken());
121                        IntegerConstantNode carryCountNode =  Generators.makeIntegerConstantNode(carryCount, node.getToken());
122               
123                        FuncCallNode rhs = (FuncCallNode) Generators.makeFuncCallNode(
124                                        new String [] {CarrySet.CarryQ_IDENTIFIER, carrySet2Lang.getCode(CarrySet.CARRYTEST)},
125                                        node.getToken(),
126                                        new ASTNode [] {carryBaseNode, carryCountNode});
127                       
128                        BinaryOperatorNode replacementIfTestNode = Generators.makeBinaryOperatorNode(lhs, 
129                                                                                                                                        rhs,
130                                                                                                                                        binaryOperatorToken);
131                       
132                        node.replaceChild(Accessors.ifTest(node), replacementIfTestNode);
133                       
134                        // else part, append CarryDequeueEnqueue call
135                        FuncCallNode carryDequeueEnqueue = (FuncCallNode) Generators.makeFuncCallNode(
136                                        new String [] {CarrySet.CarryQ_IDENTIFIER, carrySet2Lang.getCode(CarrySet.CARRYDEQUEUEENQUEUE)},
137                                        node.getToken(),
138                                        new ASTNode [] {carryBaseNode, carryCountNode});
139
140                       
141                        if (Accessors.hasElseBlockStmt(node)) { 
142                                Accessors.elseBlockStmt(node).appendChild(carryDequeueEnqueue);
143                        } else {
144                                BlockStmtNode blockStmtNode = 
145                                                Generators.makeBlockStmtNode(LextantToken.make(node.getToken().getLocation(),
146                                                                                                                        Lextant.LCURLY.getPrimaryLexeme(),
147                                                                                                                        Lextant.LCURLY));
148                                blockStmtNode.appendChild(carryDequeueEnqueue);
149                                node.appendChild(blockStmtNode); 
150                        }
151                                               
152                }
153        }
154}
155
156//class CarryIntro(ast.NodeXFormer):
157
158//        def generic_xfrm(self, node):
159//          self.current_carry = 0
160//          self.current_adv_n = 0
161//          carry_count = CarryCounter().count(node)
162//          adv_n_count = adv_nCounter().count(node)
163//          if carry_count == 0 and adv_n_count == 0: return node
164//          self.generic_visit(node)
165//          return node
166
167//def visit_If(self, ifNode):
168//carry_base = self.current_carry
169//carries = CarryCounter().count(ifNode)
170//assert adv_nCounter().count(ifNode) == 0, "Advance(x,n) within if: illegal\n"
171//self.generic_visit(ifNode)
172//if carries == 0 or self.carryin == "": return ifNode
173//#CARRYSET
174//carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
175//new_test = ast.BoolOp(ast.Or(), [ifNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
176//new_else_part = ifNode.orelse + [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]
177//return ast.If(new_test, ifNode.body, new_else_part)
178
179//        def visit_While(self, whileNode):
180//          if self.carryout == '':
181//            whileNode.test.args[0] = mkCall("simd_and", [whileNode.test.args[0], ast.Name('EOF_mask', ast.Load())])
182//          carry_base = self.current_carry
183//          assert adv_nCounter().count(whileNode) == 0, "Advance(x,n) within while: illegal\n"
184//          carries = CarryCounter().count(whileNode)
185//          #CARRYSET
186//          if carries == 0: return whileNode
187//          carry_arglist = [ast.Num(carry_base), ast.Num(carries)]
188//          local_carryvar = 'sub' + self.carryvar.id
189//          inner_while = CarryIntro(local_carryvar, '', self.carryout).generic_xfrm(copy.deepcopy(whileNode))
190//          self.generic_visit(whileNode)
191//          local_carry_decl = mkCallStmt('LocalCarryDeclare', [ast.Name(local_carryvar, ast.Load()), ast.Num(carries)])
192//          inner_while.body.insert(0, local_carry_decl)
193//          final_combine = mkCallStmt(ast.Attribute(self.carryvar, 'CarryCombine', ast.Load()), [ast.Attribute(ast.Name(local_carryvar, ast.Load()), 'cq', ast.Load()),ast.Num(carry_base), ast.Num(carries)])
194//          inner_while.body.append(final_combine)
195//          #CARRYSET
196//          if self.carryin == '': new_test = whileNode.test
197//          else: new_test = ast.BoolOp(ast.Or(), [whileNode.test, mkCall(ast.Attribute(self.carryvar, 'CarryTest', ast.Load()), carry_arglist)])
198//          else_part = [mkCallStmt(ast.Attribute(self.carryvar, 'CarryDequeueEnqueue', ast.Load()), carry_arglist)]   
199//          return ast.If(new_test, whileNode.body + [inner_while], else_part)
Note: See TracBrowser for help on using the repository browser.