Ignore:
Timestamp:
Sep 23, 2016, 4:12:41 PM (3 years ago)
Author:
nmedfort
Message:

Initial work for incorporating Types into Pablo AST.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/pablo/optimizers/booleanreassociationpass.cpp

    r5157 r5160  
    302302 ** ------------------------------------------------------------------------------------------------------------- */
    303303inline bool BooleanReassociationPass::processScopes(PabloFunction & function) {
    304     mRefs.clear();
    305304    CharacterizationMap C;
    306305    PabloBlock * const entry = function.getEntryBlock();
     
    313312    mInFile = makeVar();
    314313    processScopes(entry, C);
     314    for (auto i = mRefs.begin(); i != mRefs.end(); ++i) {
     315        Z3_dec_ref(mContext, *i);
     316    }
     317    mRefs.clear();
    315318    return mModified;
    316319}
     
    345348            }
    346349        } else { // characterize this statement then check whether it is equivalent to any existing one.
    347             stmt = characterize(stmt, C);
     350            PabloAST * const folded = Simplifier::fold(stmt, block);
     351            if (LLVM_UNLIKELY(folded != nullptr)) {
     352                stmt = stmt->replaceWith(folded);
     353            } else {
     354                stmt = characterize(stmt, C);
     355            }
    348356        }
    349357    }   
     
    359367 ** ------------------------------------------------------------------------------------------------------------- */
    360368inline Statement * BooleanReassociationPass::characterize(Statement * const stmt, CharacterizationMap & C) {
     369
    361370    Z3_ast node = nullptr;
    362371    const size_t n = stmt->getNumOperands(); assert (n > 0);
     372    bool use_expensive_simplification = false;
    363373    if (isa<Variadic>(stmt)) {
    364374        Z3_ast operands[n];
    365375        for (size_t i = 0; i < n; ++i) {
    366             operands[i] = C.get(stmt->getOperand(i)); assert (operands[i]);
     376            PabloAST * const op = stmt->getOperand(i);
     377            if (isa<Not>(op)) {
     378                use_expensive_simplification = true;
     379            }
     380            operands[i] = C.get(op); assert (operands[i]);
    367381        }
    368382        if (isa<And>(stmt)) {
     
    395409        check[1] = isa<InFile>(stmt) ? mInFile : Z3_mk_not(mContext, mInFile); assert (check[1]);
    396410        node = Z3_mk_and(mContext, 2, check);
    397     } else {
    398         if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
    399             Z3_ast op = C.get(stmt->getOperand(0)); assert (op);
    400             C.add(stmt, op, true);
    401         } else {
    402             C.add(stmt, makeVar());
    403         }
     411    } else if (LLVM_UNLIKELY(isa<Assign>(stmt) || isa<Next>(stmt))) {
    404412        return stmt->getNextNode();
    405     }
    406     Z3_inc_ref(mContext, node);
    407     node = simplify(node);
     413    }  else {
     414        C.add(stmt, makeVar());
     415        return stmt->getNextNode();
     416    }
     417    node = simplify(node, use_expensive_simplification);
    408418    PabloAST * const replacement = C.findKey(node);
    409419    if (LLVM_LIKELY(replacement == nullptr)) {
     
    440450 * are "flattened" (i.e., allowed to have any number of inputs.)
    441451 ** ------------------------------------------------------------------------------------------------------------- */
     452Vertex BooleanReassociationPass::transcribeSel(Sel * const stmt, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
     453
     454    Z3_ast args[2];
     455
     456    const Vertex c = makeVertex(TypeId::Var, stmt->getCondition(), C, S, M, G);
     457    const Vertex t = makeVertex(TypeId::Var, cast<Sel>(stmt)->getTrueExpr(), C, S, M, G);
     458    const Vertex f = makeVertex(TypeId::Var, cast<Sel>(stmt)->getFalseExpr(), C, S, M, G);
     459
     460    args[0] = getDefinition(G[c]);
     461    args[1] = getDefinition(G[t]);
     462
     463    Z3_ast trueExpr = Z3_mk_and(mContext, 2, args);
     464    Z3_inc_ref(mContext, trueExpr);
     465    mRefs.push_back(trueExpr);
     466
     467    const Vertex x = makeVertex(TypeId::And, nullptr, G, trueExpr);
     468    add_edge(nullptr, c, x, G);
     469    add_edge(nullptr, t, x, G);
     470
     471    Z3_ast notCond = Z3_mk_not(mContext, args[0]);
     472    Z3_inc_ref(mContext, notCond);
     473    mRefs.push_back(notCond);
     474
     475    args[0] = notCond;
     476    args[1] = getDefinition(G[f]);
     477
     478    Z3_ast falseExpr = Z3_mk_and(mContext, 2, args);
     479    Z3_inc_ref(mContext, falseExpr);
     480    mRefs.push_back(falseExpr);
     481
     482    const Vertex n = makeVertex(TypeId::Not, nullptr, G, notCond);
     483
     484    add_edge(nullptr, c, n, G);
     485
     486    const Vertex y = makeVertex(TypeId::And, nullptr, G, falseExpr);
     487    add_edge(nullptr, n, y, G);
     488    add_edge(nullptr, f, y, G);
     489
     490    const Vertex u = makeVertex(TypeId::Or, stmt, C, S, M, G);
     491    add_edge(nullptr, x, u, G);
     492    add_edge(nullptr, y, u, G);
     493
     494    return u;
     495}
     496
     497/** ------------------------------------------------------------------------------------------------------------- *
     498 * @brief summarizeAST
     499 *
     500 * This function scans through a scope block and computes a DAG G in which any sequences of AND, OR or XOR functions
     501 * are "flattened" (i.e., allowed to have any number of inputs.)
     502 ** ------------------------------------------------------------------------------------------------------------- */
    442503void BooleanReassociationPass::transformAST(CharacterizationMap & C, Graph & G) {
    443504
    444505    StatementMap S;
     506
     507    VertexMap M;
    445508
    446509    // Compute the base def-use graph ...
    447510    for (Statement * stmt : *mBlock) {
    448         const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, S, G, C.get(stmt));
    449         for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
    450             PabloAST * const op = stmt->getOperand(i);
    451             if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
    452                 add_edge(op, makeVertex(TypeId::Var, op, C, S, G), u, G);
    453             }
    454         }
    455         if (LLVM_UNLIKELY(isa<If>(stmt))) {
    456             for (Assign * def : cast<const If>(stmt)->getDefined()) {
    457                 const Vertex v = makeVertex(TypeId::Var, def, C, S, G);
    458                 add_edge(def, u, v, G);
    459                 resolveNestedUsages(def, v, C, S, G, stmt);
    460             }
    461         } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
    462             // To keep G a DAG, we need to do a bit of surgery on loop variants because
    463             // the next variables it produces can be used within the condition. Instead,
    464             // we make the loop dependent on the original value of each Next node and
    465             // the Next node dependent on the loop.
    466             for (Next * var : cast<const While>(stmt)->getVariants()) {
    467                 const Vertex v = makeVertex(TypeId::Var, var, C, S, G);
    468                 assert (in_degree(v, G) == 1);
    469                 auto e = first(in_edges(v, G));
    470                 add_edge(G[e], source(e, G), u, G);
    471                 remove_edge(v, u, G);
    472                 add_edge(var, u, v, G);
    473                 resolveNestedUsages(var, v, C, S, G, stmt);
    474             }
     511        if (LLVM_UNLIKELY(isa<Sel>(stmt))) {
     512
     513            const Vertex u = transcribeSel(cast<Sel>(stmt), C, S, M, G);
     514
     515            resolveNestedUsages(stmt, u, C, S, M, G, stmt);
     516
    475517        } else {
    476             resolveNestedUsages(stmt, u, C, S, G, stmt);
    477         }
    478     }
    479 
    480 //    printGraph(G, "G");
    481 
    482     VertexMap M;
     518
     519
     520            const Vertex u = makeVertex(stmt->getClassTypeId(), stmt, C, S, M, G);
     521            for (unsigned i = 0; i < stmt->getNumOperands(); ++i) {
     522                PabloAST * const op = stmt->getOperand(i);
     523                if (LLVM_LIKELY(isa<Statement>(op) || isa<Var>(op))) {
     524                    add_edge(op, makeVertex(TypeId::Var, op, C, S, M, G), u, G);
     525                }
     526            }
     527            if (LLVM_UNLIKELY(isa<If>(stmt))) {
     528                for (Assign * def : cast<const If>(stmt)->getDefined()) {
     529                    const Vertex v = makeVertex(TypeId::Var, def, C, S, M, G);
     530                    add_edge(def, u, v, G);
     531                    resolveNestedUsages(def, v, C, S, M, G, stmt);
     532                }
     533                continue;
     534            } else if (LLVM_UNLIKELY(isa<While>(stmt))) {
     535                // To keep G a DAG, we need to do a bit of surgery on loop variants because
     536                // the next variables it produces can be used within the condition. Instead,
     537                // we make the loop dependent on the original value of each Next node and
     538                // the Next node dependent on the loop.
     539                for (Next * var : cast<const While>(stmt)->getVariants()) {
     540                    const Vertex v = makeVertex(TypeId::Var, var, C, S, M, G);
     541                    assert (in_degree(v, G) == 1);
     542                    auto e = first(in_edges(v, G));
     543                    add_edge(G[e], source(e, G), u, G);
     544                    remove_edge(v, u, G);
     545                    add_edge(var, u, v, G);
     546                    resolveNestedUsages(var, v, C, S, M, G, stmt);
     547                }
     548                continue;
     549            } else {
     550                resolveNestedUsages(stmt, u, C, S, M, G, stmt);
     551            }
     552        }
     553
     554    }
     555
    483556    if (redistributeGraph(C, M, G)) {
    484557        factorGraph(G);
    485 
    486 //        printGraph(G, "H");
    487 
    488         rewriteAST(C, M, G);
     558        rewriteAST(G);
    489559        mModified = true;
    490560    }
     
    496566 ** ------------------------------------------------------------------------------------------------------------- */
    497567void BooleanReassociationPass::resolveNestedUsages(PabloAST * const expr, const Vertex u,
    498                                                    CharacterizationMap & C, StatementMap & S, Graph & G,
     568                                                   CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G,
    499569                                                   const Statement * const ignoreIfThis) const {
    500570    assert ("Cannot resolve nested usages of a null expression!" && expr);
     
    504574            if (LLVM_UNLIKELY(parent != mBlock)) {
    505575                for (;;) {
    506                     if (parent->getParent() == mBlock) {
     576                    if (parent->getPredecessor () == mBlock) {
    507577                        Statement * const branch = parent->getBranch();
    508578                        if (LLVM_UNLIKELY(branch != ignoreIfThis)) {
    509579                            // Add in a Var denoting the user of this expression so that it can be updated if expr changes.
    510                             const Vertex v = makeVertex(TypeId::Var, user, C, S, G);
     580                            const Vertex v = makeVertex(TypeId::Var, user, C, S, M, G);
    511581                            add_edge(expr, u, v, G);
    512582                            const Vertex w = makeVertex(branch->getClassTypeId(), branch, S, G);
     
    515585                        break;
    516586                    }
    517                     parent = parent->getParent();
     587                    parent = parent->getPredecessor ();
    518588                    if (LLVM_UNLIKELY(parent == nullptr)) {
    519589                        assert (isa<Assign>(expr) || isa<Next>(expr));
     
    783853
    784854/** ------------------------------------------------------------------------------------------------------------- *
     855 * @brief recomputeDefinition
     856 ** ------------------------------------------------------------------------------------------------------------- */
     857inline Z3_ast BooleanReassociationPass::computeDefinition(const TypeId typeId, const Vertex u, Graph & G, const bool use_expensive_minimization) const {
     858    const unsigned n = in_degree(u, G);
     859    Z3_ast operands[n];
     860    unsigned k = 0;
     861    for (const auto e : make_iterator_range(in_edges(u, G))) {
     862        const auto v = source(e, G);
     863        if (LLVM_UNLIKELY(getDefinition(G[v]) == nullptr)) {
     864            throw std::runtime_error("No definition for " + std::to_string(v));
     865        }
     866        operands[k++] = getDefinition(G[v]);
     867    }
     868    assert (k == n);
     869    Z3_ast const node = (typeId == TypeId::And) ? Z3_mk_and(mContext, n, operands) : Z3_mk_or(mContext, n, operands);
     870    return simplify(node, use_expensive_minimization);
     871}
     872
     873/** ------------------------------------------------------------------------------------------------------------- *
     874 * @brief updateDefinition
     875 *
     876 * Apply the distribution law to reduce computations whenever possible.
     877 ** ------------------------------------------------------------------------------------------------------------- */
     878Vertex BooleanReassociationPass::updateIntermediaryDefinition(const TypeId typeId, const Vertex u, VertexMap & M, Graph & G) {
     879
     880    Z3_ast def = computeDefinition(typeId, u, G);
     881    Z3_ast orig = getDefinition(G[u]); assert (orig);
     882
     883    Z3_dec_ref(mContext, orig);
     884
     885    const auto g = M.find(orig);
     886    if (LLVM_LIKELY(g != M.end())) {
     887        M.erase(g);
     888    }
     889
     890    const auto f = std::find(mRefs.rbegin(), mRefs.rend(), orig);
     891    assert (f != mRefs.rend());
     892    *f = def;
     893
     894    const auto h = M.find(def);
     895    if (LLVM_UNLIKELY(h != M.end())) {
     896        const auto v = h->second;
     897        if (v != u) {
     898            for (auto e : make_iterator_range(out_edges(u, G))) {
     899                add_edge(G[e], v, target(e, G), G);
     900            }
     901            removeVertex(u, G);
     902            return v;
     903        }
     904    }
     905
     906    getDefinition(G[u]) = def;
     907    M.emplace(def, u);
     908    return u;
     909}
     910
     911/** ------------------------------------------------------------------------------------------------------------- *
     912 * @brief updateDefinition
     913 *
     914 * Apply the distribution law to reduce computations whenever possible.
     915 ** ------------------------------------------------------------------------------------------------------------- */
     916Vertex BooleanReassociationPass::updateSinkDefinition(const TypeId typeId, const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G) {
     917
     918    Z3_ast const def = computeDefinition(typeId, u, G);
     919
     920    auto f = M.find(def);
     921
     922    if (LLVM_UNLIKELY(f != M.end())) {
     923        Z3_dec_ref(mContext, def);
     924        Vertex v = f->second; assert (v != u);
     925        for (auto e : make_iterator_range(out_edges(u, G))) {
     926            add_edge(G[e], v, target(e, G), G);
     927        }
     928        removeVertex(u, G);
     929        return v;
     930    } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
     931        PabloAST * const factor = C.predecessor()->findKey(def);
     932        if (LLVM_UNLIKELY(factor != nullptr)) {
     933            getValue(G[u]) = factor;
     934            getType(G[u]) = TypeId::Var;
     935            clear_in_edges(u, G);
     936        }
     937    }
     938
     939    getDefinition(G[u]) = def;
     940    mRefs.push_back(def);
     941
     942    graph_traits<Graph>::in_edge_iterator begin, end;
     943
     944restart:
     945
     946    if (in_degree(u, G) > 1) {
     947        std::tie(begin, end) = in_edges(u, G);
     948        for (auto i = begin; ++i != end; ) {
     949            const auto v = source(*i, G);
     950            for (auto j = begin; j != i; ++j) {
     951                const auto w = source(*j, G);
     952                Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
     953                Z3_ast test = nullptr;
     954                switch (typeId) {
     955                    case TypeId::And:
     956                        test = Z3_mk_and(mContext, 2, operands); break;
     957                    case TypeId::Or:
     958                        test = Z3_mk_or(mContext, 2, operands); break;
     959                    case TypeId::Xor:
     960                        test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
     961                    default:
     962                        llvm_unreachable("impossible type id");
     963                }
     964                test = simplify(test, true);
     965
     966                bool replacement = false;
     967                Vertex x = 0;
     968                const auto f = M.find(test);
     969                if (LLVM_UNLIKELY(f != M.end())) {
     970                    x = f->second;
     971                    Z3_ast orig = getDefinition(G[x]);
     972                    if (LLVM_UNLIKELY(orig != test)) {
     973                        std::string tmp;
     974                        raw_string_ostream out(tmp);
     975                        out << "vertex " << x << " is mapped to:\n"
     976                            << Z3_ast_to_string(mContext, test)
     977                            << "\n\nBut is recorded as:\n\n";
     978                        if (orig) {
     979                            out << Z3_ast_to_string(mContext, orig);
     980                        } else {
     981                            out << "<null>";
     982                        }
     983                        throw std::runtime_error(out.str());
     984                    }
     985                    Z3_dec_ref(mContext, test);
     986                    replacement = true;
     987                } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
     988                    PabloAST * const factor = C.predecessor()->findKey(test);
     989                    if (LLVM_UNLIKELY(factor != nullptr)) {
     990                        x = makeVertex(TypeId::Var, factor, G, test);
     991                        M.emplace(test, x);
     992                        replacement = true;
     993                        mRefs.push_back(test);
     994                    }
     995                }
     996
     997                if (LLVM_UNLIKELY(replacement)) {
     998
     999                    assert (G[*i] == nullptr);
     1000                    assert (G[*j] == nullptr);
     1001
     1002                    remove_edge(*i, G);
     1003                    remove_edge(*j, G);
     1004
     1005                    add_edge(nullptr, x, u, G);
     1006
     1007                    goto restart;
     1008                }
     1009
     1010                Z3_dec_ref(mContext, test);
     1011            }
     1012        }
     1013    }
     1014
     1015    M.emplace(def, u);
     1016
     1017    return u;
     1018}
     1019
     1020/** ------------------------------------------------------------------------------------------------------------- *
    7851021 * @brief redistributeAST
    7861022 *
     
    7911027    bool modified = false;
    7921028
     1029//    errs() << "=====================================================\n";
     1030
    7931031    DistributionGraph H;
    7941032
    795     contractGraph(G);
    796 
    7971033    for (;;) {
    7981034
    799         for (;;) {
    800 
    801             generateDistributionGraph(G, H);
    802 
    803             // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
    804             if (num_vertices(H) == 0) {
    805                 break;
    806             }
    807 
    808             const DistributionSets distributionSets = safeDistributionSets(G, H);
    809 
    810             if (LLVM_UNLIKELY(distributionSets.empty())) {
    811                 break;
    812             }
    813 
    814             modified = true;
    815 
    816             for (const DistributionSet & set : distributionSets) {
    817 
    818                 // Each distribution tuple consists of the sources, intermediary, and sink nodes.
    819                 const VertexSet & sources = std::get<0>(set);
    820                 const VertexSet & intermediary = std::get<1>(set);
    821                 const VertexSet & sinks = std::get<2>(set);
    822 
    823                 const TypeId outerTypeId = getType(G[H[sinks.front()]]);
    824                 assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
    825                 const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
    826 
    827                 const Vertex x = makeVertex(outerTypeId, nullptr, G);
    828                 const Vertex y = makeVertex(innerTypeId, nullptr, G);
    829 
    830                 // Update G to reflect the distributed operations (including removing the subgraph of
    831                 // the to-be distributed edges.)
    832 
    833                 add_edge(nullptr, x, y, G);
    834 
    835                 for (const Vertex i : sources) {
    836                     const auto u = H[i];
    837                     for (const Vertex j : intermediary) {
    838                         const auto v = H[j];
    839                         const auto e = edge(u, v, G); assert (e.second);
    840                         remove_edge(e.first, G);
    841                     }
    842                     add_edge(nullptr, u, y, G);
    843                 }
    844 
    845                 for (const Vertex i : intermediary) {
    846                     const auto u = H[i];
    847                     for (const Vertex j : sinks) {
    848                         const auto v = H[j];
    849                         const auto e = edge(u, v, G); assert (e.second);
    850                         add_edge(G[e.first], y, v, G);
    851                         remove_edge(e.first, G);
    852                     }
    853                     add_edge(nullptr, u, x, G);
    854                     getDefinition(G[u]) = nullptr;
    855                 }
    856 
    857             }
    858 
    859             H.clear();
    860 
    861             contractGraph(G);
    862         }
    863 
    864         // Although exceptionally unlikely, it's possible that if we can reduce the graph, we could
    865         // further simplify it. Restart the process if and only if we succeed.
    866         if (reduceGraph(C, M, G)) {
    867             if (LLVM_UNLIKELY(contractGraph(G))) {
    868                 H.clear();
    869                 continue;
    870             }
    871         }
    872 
    873         break;
     1035        contractGraph(M, G);
     1036
     1037//        printGraph(G, "G");
     1038
     1039        generateDistributionGraph(G, H);
     1040
     1041        // If we found no potential opportunities then we cannot apply the distribution law to any part of G.
     1042        if (num_vertices(H) == 0) {
     1043            break;
     1044        }
     1045
     1046        const DistributionSets distributionSets = safeDistributionSets(G, H);
     1047
     1048        if (LLVM_UNLIKELY(distributionSets.empty())) {
     1049            break;
     1050        }
     1051
     1052        modified = true;
     1053
     1054        mRefs.reserve(distributionSets.size() * 2);
     1055
     1056        for (const DistributionSet & set : distributionSets) {
     1057
     1058            // Each distribution tuple consists of the sources, intermediary, and sink nodes.
     1059            const VertexSet & sources = std::get<0>(set);
     1060            const VertexSet & intermediary = std::get<1>(set);
     1061            const VertexSet & sinks = std::get<2>(set);
     1062
     1063            const TypeId outerTypeId = getType(G[H[sinks.front()]]);
     1064            assert (outerTypeId == TypeId::And || outerTypeId == TypeId::Or);
     1065            const TypeId innerTypeId = (outerTypeId == TypeId::Or) ? TypeId::And : TypeId::Or;
     1066
     1067            const Vertex x = makeVertex(outerTypeId, nullptr, G);
     1068            const Vertex y = makeVertex(innerTypeId, nullptr, G);
     1069
     1070            // Update G to reflect the distributed operations (including removing the subgraph of
     1071            // the to-be distributed edges.)
     1072
     1073            add_edge(nullptr, x, y, G);
     1074
     1075            for (const Vertex i : sources) {
     1076                const auto u = H[i];
     1077                for (const Vertex j : intermediary) {
     1078                    const auto v = H[j];
     1079                    assert (getType(G[v]) == innerTypeId);
     1080                    const auto e = edge(u, v, G); assert (e.second);
     1081                    remove_edge(e.first, G);
     1082                }
     1083                add_edge(nullptr, u, y, G);
     1084            }
     1085
     1086            for (const Vertex i : intermediary) {
     1087
     1088                const auto u = updateIntermediaryDefinition(innerTypeId, H[i], M, G);
     1089
     1090                for (const Vertex j : sinks) {
     1091                    const auto v = H[j];
     1092                    assert (getType(G[v]) == outerTypeId);
     1093                    const auto e = edge(u, v, G); assert (e.second);
     1094                    add_edge(G[e.first], y, v, G);
     1095                    remove_edge(e.first, G);
     1096                }
     1097                add_edge(nullptr, u, x, G);
     1098            }
     1099
     1100            updateSinkDefinition(outerTypeId, x, C, M, G);
     1101
     1102            updateSinkDefinition(innerTypeId, y, C, M, G);
     1103
     1104        }
     1105
     1106        H.clear();
     1107
    8741108    }
    8751109
     
    9301164}
    9311165
    932 
    9331166/** ------------------------------------------------------------------------------------------------------------- *
    9341167 * @brief contractGraph
    9351168 ** ------------------------------------------------------------------------------------------------------------- */
    936 bool BooleanReassociationPass::contractGraph(Graph & G) const {
     1169bool BooleanReassociationPass::contractGraph(VertexMap & M, Graph & G) const {
    9371170
    9381171    bool contracted = false;
     
    9571190                        add_edge(G[ej], v, target(ej, G), G);
    9581191                    }
    959                     removeVertex(u, G);
     1192                    removeVertex(u, M, G);
    9601193                    contracted = true;
    9611194                } else if (LLVM_UNLIKELY(has_unique_target(u, G))) {
     
    9691202                            add_edge(G[ej], source(ej, G), v, G);
    9701203                        }
    971                         removeVertex(u, G);
     1204                        removeVertex(u, M, G);
    9721205                        contracted = true;
    9731206                    }
     
    9751208            }
    9761209        } else if (LLVM_UNLIKELY(isNonEscaping(G[u]))) {
    977             removeVertex(u, G);
     1210            removeVertex(u, M, G);
    9781211            contracted = true;
    9791212        }
    9801213    }
    9811214    return contracted;
    982 }
    983 
    984 /** ------------------------------------------------------------------------------------------------------------- *
    985  * @brief isReducible
    986  ** ------------------------------------------------------------------------------------------------------------- */
    987 inline bool isReducible(const VertexData & data) {
    988     switch (getType(data)) {
    989         case TypeId::Var:
    990         case TypeId::If:
    991         case TypeId::While:
    992             return false;
    993         default:
    994             return true;
    995     }
    996 }
    997 
    998 /** ------------------------------------------------------------------------------------------------------------- *
    999  * @brief reduceGraph
    1000  ** ------------------------------------------------------------------------------------------------------------- */
    1001 BooleanReassociationPass::Reduction BooleanReassociationPass::reduceVertex(const Vertex u, CharacterizationMap & C, VertexMap & M, Graph & G, const bool use_expensive_simplification) {
    1002 
    1003     Reduction reduction = Reduction::NoChange;
    1004 
    1005     assert (isReducible(G[u]));
    1006 
    1007     Z3_ast node = getDefinition(G[u]);
    1008     if (isAssociative(G[u])) {
    1009         const TypeId typeId = getType(G[u]);
    1010         if (node == nullptr) {
    1011             const auto n = in_degree(u, G); assert (n > 1);
    1012             Z3_ast operands[n];
    1013             unsigned i = 0;
    1014             for (auto e : make_iterator_range(in_edges(u, G))) {
    1015                 const Vertex v = source(e, G);
    1016                 assert (getDefinition(G[v]));
    1017                 operands[i++] = getDefinition(G[v]);
    1018             }
    1019             switch (typeId) {
    1020                 case TypeId::And:
    1021                     node = Z3_mk_and(mContext, n, operands);
    1022                     break;
    1023                 case TypeId::Or:
    1024                     node = Z3_mk_or(mContext, n, operands);
    1025                     break;
    1026                 case TypeId::Xor:
    1027                     node = Z3_mk_xor(mContext, operands[0], operands[1]);
    1028                     for (unsigned i = 2; LLVM_UNLIKELY(i < n); ++i) {
    1029                         node = Z3_mk_xor(mContext, node, operands[i]);
    1030                     }
    1031                     break;
    1032                 default: llvm_unreachable("unexpected type id");
    1033             }
    1034             assert (node);
    1035             Z3_inc_ref(mContext, node);
    1036             mRefs.push_back(node);
    1037             getDefinition(G[u]) = node;
    1038         }
    1039 
    1040         graph_traits<Graph>::in_edge_iterator begin, end;
    1041 restart:if (in_degree(u, G) > 1) {
    1042             std::tie(begin, end) = in_edges(u, G);
    1043             for (auto i = begin; ++i != end; ) {
    1044                 const auto v = source(*i, G);
    1045                 for (auto j = begin; j != i; ++j) {
    1046                     const auto w = source(*j, G);
    1047                     Z3_ast operands[2] = { getDefinition(G[v]), getDefinition(G[w]) };
    1048                     Z3_ast test = nullptr;
    1049                     switch (typeId) {
    1050                         case TypeId::And:
    1051                             test = Z3_mk_and(mContext, 2, operands); break;
    1052                         case TypeId::Or:
    1053                             test = Z3_mk_or(mContext, 2, operands); break;
    1054                         case TypeId::Xor:
    1055                             test = Z3_mk_xor(mContext, operands[0], operands[1]); break;
    1056                         default:
    1057                             llvm_unreachable("impossible type id");
    1058                     }
    1059                     assert (test);
    1060                     Z3_inc_ref(mContext, test);
    1061                     test = simplify(test, use_expensive_simplification);
    1062                     bool replacement = false;
    1063                     Vertex x = 0;
    1064                     const auto f = M.find(test);
    1065                     if (LLVM_UNLIKELY(f != M.end())) {
    1066                         x = f->second;
    1067                         assert (getDefinition(G[x]) == test);
    1068                         Z3_dec_ref(mContext, test);
    1069                         replacement = true;
    1070                     } else if (LLVM_LIKELY(C.predecessor() != nullptr)) {
    1071                         PabloAST * const factor = C.predecessor()->findKey(test);
    1072                         if (LLVM_UNLIKELY(factor != nullptr)) {
    1073                             x = makeVertex(TypeId::Var, factor, G, test);
    1074                             M.emplace(test, x);
    1075                             replacement = true;
    1076                             mRefs.push_back(test);
    1077                         }
    1078                     }
    1079 
    1080                     if (LLVM_UNLIKELY(replacement)) {
    1081 
    1082                         // note: unless both edges carry an Pablo AST replacement value, they will converge into a single edge.
    1083                         PabloAST * const r1 = G[*i];
    1084                         PabloAST * const r2 = G[*j];
    1085 
    1086                         remove_edge(*i, G);
    1087                         remove_edge(*j, G);
    1088 
    1089                         if (LLVM_UNLIKELY(r1 && r2)) {
    1090                             add_edge(r1, x, u, G);
    1091                             add_edge(r2, x, u, G);
    1092                         } else {
    1093                             add_edge(r1 ? r1 : r2, x, u, G);
    1094                         }
    1095 
    1096                         reduction = Reduction::Simplified;
    1097 
    1098                         goto restart;
    1099                     }
    1100 
    1101                     Z3_dec_ref(mContext, test);
    1102                 }
    1103             }
    1104         }
    1105     }
    1106 
    1107     if (LLVM_UNLIKELY(node == nullptr)) {
    1108         throw std::runtime_error("No Z3 characterization for vertex " + std::to_string(u));
    1109     }
    1110 
    1111     auto f = M.find(node);
    1112     if (LLVM_LIKELY(f == M.end())) {
    1113         M.emplace(node, u);
    1114     } else if (isAssociative(G[u])) {
    1115         const Vertex v = f->second;
    1116         for (auto e : make_iterator_range(out_edges(u, G))) {
    1117             add_edge(G[e], v, target(e, G), G);
    1118         }
    1119         removeVertex(u, G);
    1120         reduction = Reduction::Removed;
    1121     }
    1122 
    1123     return reduction;
    1124 }
    1125 
    1126 /** ------------------------------------------------------------------------------------------------------------- *
    1127  * @brief reduceGraph
    1128  ** ------------------------------------------------------------------------------------------------------------- */
    1129 bool BooleanReassociationPass::reduceGraph(CharacterizationMap & C, VertexMap & M, Graph & G) {
    1130 
    1131     bool reduced = false;
    1132 
    1133     circular_buffer<Vertex> ordering(num_vertices(G));
    1134 
    1135     topological_sort(G, std::front_inserter(ordering)); // topological ordering
    1136 
    1137     M.clear();
    1138 
    1139     // first contract the graph
    1140     for (const Vertex u : ordering) {
    1141         if (isReducible(G[u])) {
    1142             if (reduceVertex(u, C, M, G, false) != Reduction::NoChange) {
    1143                 reduced = true;
    1144             }
    1145         }
    1146     }
    1147     return reduced;
    11481215}
    11491216
     
    12401307}
    12411308
    1242 
     1309/** ------------------------------------------------------------------------------------------------------------- *
     1310 * @brief isMutable
     1311 ** ------------------------------------------------------------------------------------------------------------- */
    12431312inline bool isMutable(const Vertex u, const Graph & G) {
    12441313    return getType(G[u]) != TypeId::Var;
     
    12481317 * @brief rewriteAST
    12491318 ** ------------------------------------------------------------------------------------------------------------- */
    1250 bool BooleanReassociationPass::rewriteAST(CharacterizationMap & C, VertexMap & M, Graph & G) {
     1319bool BooleanReassociationPass::rewriteAST(Graph & G) {
    12511320
    12521321    using line_t = long long int;
    12531322
    12541323    enum : line_t { MAX_INT = std::numeric_limits<line_t>::max() };
     1324
     1325    // errs() << "---------------------------------------------------------\n";
     1326
     1327    // printGraph(G, "X");
    12551328
    12561329    Z3_config cfg = Z3_mk_config();
     
    13271400    std::vector<line_t> L(num_vertices(G));
    13281401
    1329 
    1330 
    13311402    for (const Vertex u : make_iterator_range(vertices(G))) {
    13321403        line_t line = LoadEarly ? 0 : MAX_INT;
     
    13451416
    13461417    Z3_model_dec_ref(ctx, model);
     1418    Z3_solver_dec_ref(ctx, solver);
     1419    Z3_del_context(ctx);
    13471420
    13481421    std::sort(S.begin(), S.end(), [&L](const Vertex u, const Vertex v){ return L[u] < L[v]; });
     
    13531426
    13541427    line_t count = 1;
    1355 
    1356 //    errs() << "--------------------------------------------------\n";
    1357 
    1358 //    printGraph(G, "G");
    13591428
    13601429    for (auto u : S) {
     
    13641433        assert (L[u] > 0 && L[u] < MAX_INT);
    13651434
     1435        bool append = true;
     1436
    13661437        if (isAssociative(G[u])) {
    13671438
     
    13741445
    13751446            const auto typeId = getType(G[u]);
    1376 
    1377 // retry:
    13781447
    13791448            T.clear();
     
    14291498                            llvm_unreachable("Invalid TypeId!");
    14301499                    }
    1431 
    1432 //                    // If the insertion point isn't the statement we just attempted to create
    1433 //                    // we must have unexpectidly reused a prior statement (or Var.)
    1434 //                    if (LLVM_UNLIKELY(expr != mBlock->getInsertPoint())) {
    1435 //                        const auto reduction = reduceVertex(u, C, M, G, true);
    1436 //                        if (LLVM_UNLIKELY(reduction == Reduction::NoChange)) {
    1437 //                            throw std::runtime_error("Unable to reduce vertex " + std::to_string(u));
    1438 //                        } else if (LLVM_UNLIKELY(reduction == Reduction::Simplified)) {
    1439 //                            goto retry;
    1440 //                        } else { // if (reduction == Reduction::Removed) {
    1441 //                            mBlock->setInsertPoint(ip->getPrevNode());
    1442 //                            goto next_statement;
    1443 //                        }
    1444 //                    }
    14451500                }
    14461501                join = expr;
     
    14521507
    14531508            mBlock->setInsertPoint(ip->getPrevNode());
    1454 
    1455             for (auto e : make_iterator_range(out_edges(u, G))) {
    1456                 if (G[e]) {
    1457                     if (PabloAST * user = getValue(G[target(e, G)])) {
    1458                         cast<Statement>(user)->replaceUsesOfWith(G[e], expr);
    1459                     }
    1460                 }
    1461             }
    14621509
    14631510            stmt = expr;
     
    14791526                    }
    14801527                }
    1481                 continue;
    1482             }
    1483         }
    1484 
    1485         assert (stmt);
    1486 
    1487         if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
     1528                append = false;
     1529            }
     1530        } else if (stmt == nullptr) {
     1531            assert (getType(G[u]) == TypeId::Not);
     1532            assert (in_degree(u, G) == 1);
     1533            PabloAST * op = getValue(G[source(first(in_edges(u, G)), G)]); assert (op);
     1534            stmt = mBlock->createNot(op);
     1535        } else if (LLVM_UNLIKELY(isa<If>(stmt) || isa<While>(stmt))) {
    14881536            for (auto e : make_iterator_range(out_edges(u, G))) {
    14891537                const auto v = target(e, G);
     
    14931541        }
    14941542
    1495 //        PabloPrinter::print(cast<Statement>(stmt), errs()); errs() << "\n";
    1496 
    1497         mBlock->insert(cast<Statement>(stmt));
    1498         L[u] = count++; // update the line count with the actual one.
    1499 //        next_statement: continue;
    1500     }
    1501 
    1502     Z3_solver_dec_ref(ctx, solver);
    1503     Z3_del_context(ctx);
     1543        for (auto e : make_iterator_range(out_edges(u, G))) {
     1544            if (G[e]) {
     1545                if (PabloAST * user = getValue(G[target(e, G)])) {
     1546                    cast<Statement>(user)->replaceUsesOfWith(G[e], stmt);
     1547                }
     1548            }
     1549        }
     1550
     1551        if (LLVM_LIKELY(append)) {
     1552            mBlock->insert(cast<Statement>(stmt));
     1553            L[u] = count++; // update the line count with the actual one.
     1554        }
     1555    }
    15041556
    15051557    Statement * const end = mBlock->getInsertPoint(); assert (end);
     
    15341586
    15351587/** ------------------------------------------------------------------------------------------------------------- *
    1536  * @brief addSummaryVertex
     1588* @brief makeVertex
     1589** ------------------------------------------------------------------------------------------------------------- */
     1590Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & S, VertexMap & M, Graph & G) {
     1591    assert (expr);
     1592    const auto f = S.find(expr);
     1593    if (f != S.end()) {
     1594        assert (getValue(G[f->second]) == expr);
     1595        return f->second;
     1596    }
     1597    const auto node = C.get(expr);   
     1598    const Vertex u = makeVertex(typeId, expr, G, node);
     1599    S.emplace(expr, u);
     1600    if (node) {
     1601        M.emplace(node, u);
     1602    }
     1603    return u;
     1604}
     1605
     1606/** ------------------------------------------------------------------------------------------------------------- *
     1607 * @brief makeVertex
     1608 ** ------------------------------------------------------------------------------------------------------------- */
     1609Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
     1610    assert (expr);
     1611    const auto f = M.find(expr);
     1612    if (f != M.end()) {
     1613        assert (getValue(G[f->second]) == expr);
     1614        return f->second;
     1615    }
     1616    const Vertex u = makeVertex(typeId, expr, G, node);
     1617    M.emplace(expr, u);
     1618    return u;
     1619}
     1620
     1621/** ------------------------------------------------------------------------------------------------------------- *
     1622 * @brief makeVertex
    15371623 ** ------------------------------------------------------------------------------------------------------------- */
    15381624Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, Graph & G, Z3_ast node) {
     
    15481634
    15491635/** ------------------------------------------------------------------------------------------------------------- *
    1550  * @brief addSummaryVertex
    1551  ** ------------------------------------------------------------------------------------------------------------- */
    1552 Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, StatementMap & M, Graph & G, Z3_ast node) {
    1553     assert (expr);
    1554     const auto f = M.find(expr);
    1555     if (f != M.end()) {
    1556         assert (getValue(G[f->second]) == expr);
    1557         return f->second;
    1558     }
    1559     const Vertex u = makeVertex(typeId, expr, G, node);
    1560     M.emplace(expr, u);
    1561     return u;
    1562 }
    1563 
    1564 /** ------------------------------------------------------------------------------------------------------------- *
    1565  * @brief addSummaryVertex
    1566  ** ------------------------------------------------------------------------------------------------------------- */
    1567 Vertex BooleanReassociationPass::makeVertex(const TypeId typeId, PabloAST * const expr, CharacterizationMap & C, StatementMap & M, Graph & G) {
    1568     assert (expr);
    1569     const auto f = M.find(expr);
    1570     if (f != M.end()) {
    1571         assert (getValue(G[f->second]) == expr);
    1572         return f->second;
    1573     }
    1574     const Vertex u = makeVertex(typeId, expr, G, C.get(expr));
    1575     M.emplace(expr, u);
    1576     return u;
    1577 }
    1578 
    1579 /** ------------------------------------------------------------------------------------------------------------- *
    1580  * @brief removeSummaryVertex
    1581  ** ------------------------------------------------------------------------------------------------------------- */
    1582 inline void BooleanReassociationPass::removeVertex(const Vertex u, StatementMap & M, Graph & G) const {
     1636 * @brief removeVertex
     1637 ** ------------------------------------------------------------------------------------------------------------- */
     1638inline void BooleanReassociationPass::removeVertex(const Vertex u, VertexMap & M, Graph & G) const {
    15831639    VertexData & ref = G[u];
    1584     if (std::get<1>(ref)) {
    1585         auto f = M.find(std::get<1>(ref));
    1586         assert (f != M.end());
    1587         M.erase(f);
    1588     }
     1640    Z3_ast def = getDefinition(ref); assert (def);
     1641    auto f = M.find(def); assert (f != M.end());
     1642    M.erase(f);
    15891643    removeVertex(u, G);
    15901644}
    15911645
    15921646/** ------------------------------------------------------------------------------------------------------------- *
    1593  * @brief removeSummaryVertex
     1647 * @brief removeVertex
    15941648 ** ------------------------------------------------------------------------------------------------------------- */
    15951649inline void BooleanReassociationPass::removeVertex(const Vertex u, Graph & G) const {
     
    16141668 * @brief simplify
    16151669 ** ------------------------------------------------------------------------------------------------------------- */
    1616 Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
     1670inline Z3_ast BooleanReassociationPass::simplify(Z3_ast const node, bool use_expensive_minimization) const {
    16171671    assert (node);
    1618     Z3_ast result = Z3_simplify_ex(mContext, node, mParams);
    1619     Z3_inc_ref(mContext, result);
     1672    Z3_inc_ref(mContext, node);
     1673    Z3_ast result = nullptr;
    16201674    if (use_expensive_minimization) {
    1621         Z3_goal g = Z3_mk_goal(mContext, true, false, false);
     1675
     1676        Z3_goal g = Z3_mk_goal(mContext, true, false, false); assert (g);
    16221677        Z3_goal_inc_ref(mContext, g);
    1623         Z3_goal_assert(mContext, g, result);
    1624 
    1625         Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g);
     1678        Z3_goal_assert(mContext, g, node);
     1679
     1680        Z3_apply_result r = Z3_tactic_apply(mContext, mTactic, g); assert (r);
    16261681        Z3_apply_result_inc_ref(mContext, r);
     1682        Z3_goal_dec_ref(mContext, g);
     1683
    16271684        assert (Z3_apply_result_get_num_subgoals(mContext, r) == 1);
    16281685
    1629         Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0);
     1686        Z3_goal h = Z3_apply_result_get_subgoal(mContext, r, 0); assert (h);
    16301687        Z3_goal_inc_ref(mContext, h);
    1631         Z3_goal_dec_ref(mContext, g);
     1688        Z3_apply_result_dec_ref(mContext, r);
    16321689
    16331690        const unsigned n = Z3_goal_size(mContext, h);
    16341691
    1635         Z3_ast optimized = nullptr;
    16361692        if (n == 1) {
    1637             optimized = Z3_goal_formula(mContext, h, 0);
    1638             Z3_inc_ref(mContext, optimized);
    1639 
     1693            result = Z3_goal_formula(mContext, h, 0); assert (result);
     1694            Z3_inc_ref(mContext, result);
    16401695        } else if (n > 1) {
    16411696            Z3_ast operands[n];
     
    16441699                Z3_inc_ref(mContext, operands[i]);
    16451700            }
    1646             optimized = Z3_mk_and(mContext, n, operands);
    1647             Z3_inc_ref(mContext, optimized);
     1701            result = Z3_mk_and(mContext, n, operands); assert (result);
     1702            Z3_inc_ref(mContext, result);
    16481703            for (unsigned i = 0; i < n; ++i) {
    16491704                Z3_dec_ref(mContext, operands[i]);
    16501705            }
     1706        } else {
     1707            result = Z3_mk_true(mContext); assert (result);
    16511708        }
    16521709        Z3_goal_dec_ref(mContext, h);
    1653         Z3_apply_result_dec_ref(mContext, r);
    1654         Z3_dec_ref(mContext, result);
    1655         result = optimized;
    1656     }
     1710
     1711    } else {       
     1712        result = Z3_simplify_ex(mContext, node, mParams); assert (result);
     1713        Z3_inc_ref(mContext, result);
     1714    }   
    16571715    Z3_dec_ref(mContext, node);
     1716    assert (result);
    16581717    return result;
    16591718}
Note: See TracChangeset for help on using the changeset viewer.