diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 27cc915e0a4..96c00c3eaea 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -60,6 +60,10 @@ struct DAEFunctionInfo {
   // Whether this needs to be recomputed. This begins as true for the first
   // computation, and we reset it every time we touch the function.
   bool stale = true;
+  // Set after we just updated this. That is the case right after we were
+  // stale, and were updated to become not stale, and it indicates that we
+  // contain new data.
+  bool justUpdated = false;
   // The unused parameters, if any.
   SortedVector unusedParams;
   // Maps a function name to the calls going to it.
@@ -160,18 +164,21 @@ struct DAEScanner
     if (!info->stale) {
       // Nothing changed since last time.
+      info->justUpdated = false;
       return;
     }

     // Clear the data, mark us as no longer stale, and recompute everything.
     info->clear();
     info->stale = false;
+    info->justUpdated = true;

-    auto numParams = func->getNumParams();

     PostWalker<DAEScanner, Visitor<DAEScanner>>::doWalkFunction(func);

+    // If there are params, check if they are used.
     // TODO: This work could be avoided if we cannot optimize for other reasons.
     //       That would require deferring this to later and checking that.
+    auto numParams = func->getNumParams();
     if (numParams > 0) {
       auto usedParams = ParamUtils::getUsedParams(func, getModule());
       for (Index i = 0; i < numParams; i++) {
@@ -195,6 +202,17 @@ struct DAE : public Pass {
   // Map of function names to indexes. This lets us use indexes below for
   // speed.
   std::unordered_map<Name, Index> indexes;
+  struct CallInfo {
+    // Store the calls and their origins in parallel vectors (as we need |calls|
+    // by itself for certain APIs). That is, origins[i] is the function index in
+    // which calls[i] appears.
+    std::vector<Call*> calls;
+    std::vector<Index> origins;
+  };
+  // The set of all calls (and their origins) between functions. We compute this
+  // incrementally in later iterations, to avoid repeated work.
+  std::vector<CallInfo> allCalls;
+
   void run(Module* module) override {
     DAEFunctionInfoMap infoMap;
     // Ensure all entries exist so the parallel threads don't modify the data
@@ -211,6 +229,8 @@ struct DAE : public Pass {
       indexes[module->functions[i]->name] = i;
     }

+    allCalls.resize(numFunctions);
+
     // Iterate to convergence.
     while (1) {
       if (!iteration(module, infoMap)) {
@@ -232,6 +252,8 @@ struct DAE : public Pass {
   // of computing this map is significant, so we compute it once at the start
   // and then use that possibly-over-approximating data.
   std::vector<std::vector<Name>> callers;
+  // Reverse data: The list of functions called by a function.
+  std::vector<std::vector<Index>> callees;

   // A count of how many iterations we saw unprofitable removals of parameters.
   // An unprofitable removal is one where we only manage to remove from a single
@@ -274,15 +296,10 @@ struct DAE : public Pass {
     scanner.run(getPassRunner(), module);

     // Combine all the info from the scan.
-    std::vector<std::vector<Call*>> allCalls(numFunctions);
     std::vector<bool> tailCallees(numFunctions);
     std::vector<bool> hasUnseenCalls(numFunctions);
     for (auto& [func, info] : infoMap) {
-      for (auto& [name, calls] : info.calls) {
-        auto& allCallsToName = allCalls[indexes[name]];
-        allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
-      }
       for (auto& callee : info.tailCallees) {
         tailCallees[indexes[callee]] = true;
       }
@@ -300,7 +317,7 @@ struct DAE : public Pass {
       }
     }

-    // See comment above, we compute callers once and never again.
+    // See comment above, we compute callers and callees once and never again.
     if (callers.empty()) {
       // Compute first as sets, to deduplicate.
       std::vector<std::unordered_set<Name>> callersSets(numFunctions);
@@ -311,12 +328,18 @@ struct DAE : public Pass {
       }
       // Copy into efficient vectors.
       callers.resize(numFunctions);
+      callees.resize(numFunctions);
       for (Index i = 0; i < numFunctions; ++i) {
         auto& set = callersSets[i];
         callers[i] = std::vector<Name>(set.begin(), set.end());
+        for (auto& caller : callers[i]) {
+          callees[indexes[caller]].push_back(i);
+        }
       }
     }

+    updateAllCalls(module, infoMap);
+
     // Track which functions we changed that are worth re-optimizing at the end.
     std::unordered_set<Function*> worthOptimizing;
@@ -361,7 +384,7 @@ struct DAE : public Pass {
       if (hasUnseenCalls[index]) {
         continue;
       }
-      auto& calls = allCalls[index];
+      auto& calls = allCalls[index].calls;
       if (calls.empty()) {
         // Nothing calls this, so it is not worth optimizing.
         continue;
       }
@@ -418,7 +441,7 @@ struct DAE : public Pass {
       if (numParams == 0) {
         continue;
       }
-      auto& calls = allCalls[index];
+      auto& calls = allCalls[index].calls;
       if (calls.empty()) {
         continue;
       }
@@ -475,7 +498,7 @@ struct DAE : public Pass {
       if (tailCallees[index]) {
         continue;
       }
-      auto& calls = allCalls[index];
+      auto& calls = allCalls[index].calls;
       if (calls.empty()) {
         continue;
       }
@@ -516,6 +539,79 @@ struct DAE : public Pass {
 private:
   std::unordered_map<Call*, Expression**> allDroppedCalls;

+  void updateAllCalls(Module* module, DAEFunctionInfoMap& infoMap) {
+    // Recompute parts of allCalls as necessary. We know which function infos
+    // were just updated, and start there: If we updated { A, B }, and A calls
+    // C while B calls C and D, then the list of all calls must be updated for
+    // C and D. First, find the functions just updated, and the ones they call.
+    std::unordered_set<Index> justUpdated;
+    std::unordered_set<Index> calledByJustUpdated;
+    for (auto& [func, info] : infoMap) {
+      if (info.justUpdated) {
+        auto index = indexes[func];
+        justUpdated.insert(index);
+        for (auto& callee : callees[index]) {
+          calledByJustUpdated.insert(callee);
+        }
+      }
+    }
+
+    // Add calls from one caller to allCalls.
+    auto addCallsFrom = [&](Index caller) {
+      auto& info = infoMap[module->functions[caller]->name];
+      for (auto& [name, calls] : info.calls) {
+        auto& allCallsToName = allCalls[indexes[name]].calls;
+        allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
+        auto num = calls.size();
+        auto& origins = allCalls[indexes[name]].origins;
+        for (Index i = 0; i < num; i++) {
+          origins.push_back(caller);
+        }
+      }
+    };
+
+    if (justUpdated.size() + calledByJustUpdated.size() >= numFunctions) {
+      // Many functions need to be processed to do an incremental update. Just
+      // do a full recompute from scratch, which may be faster.
+      allCalls.clear();
+      allCalls.resize(numFunctions);
+      for (Index caller = 0; caller < numFunctions; caller++) {
+        addCallsFrom(caller);
+      }
+      return;
+    }
+
+    // Do an incremental update. First, remove all stale calls from allCalls,
+    // that is, remove calls from the just-updated functions.
+    for (auto& called : calledByJustUpdated) {
+      auto& calledCalls = allCalls[called];
+      auto oldSize = calledCalls.calls.size();
+      assert(oldSize == calledCalls.origins.size());
+      Index skip = 0;
+      for (Index i = 0; i < calledCalls.calls.size(); i++) {
+        if (justUpdated.count(calledCalls.origins[i])) {
+          // Remove it by skipping over.
+          skip++;
+        } else if (skip) {
+          // Keep it by writing to the proper place.
+          calledCalls.calls[i - skip] = calledCalls.calls[i];
+          calledCalls.origins[i - skip] = calledCalls.origins[i];
+        }
+      }
+      if (skip > 0) {
+        // Update the sizes after removing things.
+        calledCalls.calls.resize(oldSize - skip);
+        calledCalls.origins.resize(oldSize - skip);
+      }
+    }
+
+    // The just-updated callers' calls have been cleaned out of allCalls. Add
+    // them in, leaving us with fully-updated data.
+    for (auto& caller : justUpdated) {
+      addCallsFrom(caller);
+    }
+  }
+
   // Returns `true` if the caller should be optimized.
   bool removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) {
diff --git a/test/lit/passes/dae-gc.wast b/test/lit/passes/dae-gc.wast
index 6d014d0f8b4..4ced51e1d27 100644
--- a/test/lit/passes/dae-gc.wast
+++ b/test/lit/passes/dae-gc.wast
@@ -195,3 +195,81 @@
     (nop)
   )
 )
+
+;; After the first optimization, where we remove params from the call to
+;; $param, we update the IR incrementally, and must do so properly: we update
+;; $param and its caller $caller. $caller also calls $result, so we must update
+;; some of $result's callers but not all. Ditto with $caller2.
+(module
+  (rec
+    ;; CHECK:      (type $A (sub (struct (field (mut f64)) (field (mut funcref)))))
+    (type $A (sub (struct (field (mut f64)) (field (mut funcref)))))
+  )
+
+  ;; CHECK:      (func $nop (type $0)
+  ;; CHECK-NEXT: )
+  (func $nop
+    ;; Helper.
+  )
+
+  ;; CHECK:      (func $param (type $0)
+  ;; CHECK-NEXT:  (local $0 (ref $A))
+  ;; CHECK-NEXT:  (unreachable)
+  ;; CHECK-NEXT: )
+  (func $param (param $0 (ref $A))
+    ;; Helper with a param and lets us have calls inside it.
+    (unreachable)
+  )
+
+  ;; CHECK:      (func $caller (type $0)
+  ;; CHECK-NEXT:  (local $0 (ref (exact $A)))
+  ;; CHECK-NEXT:  (local.set $0
+  ;; CHECK-NEXT:   (struct.new $A
+  ;; CHECK-NEXT:    (call $result)
+  ;; CHECK-NEXT:    (ref.func $nop)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (call $param)
+  ;; CHECK-NEXT: )
+  (func $caller
+    (call $param
+      (struct.new $A
+        (call $result)
+        (ref.func $nop)
+      )
+    )
+  )
+
+  ;; CHECK:      (func $param2 (type $0)
+  ;; CHECK-NEXT:  (local $0 (ref any))
+  ;; CHECK-NEXT: )
+  (func $param2 (param $0 (ref any))
+  )
+
+  ;; CHECK:      (func $caller2 (type $0)
+  ;; CHECK-NEXT:  (call $param2)
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (call $result)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT: )
+  (func $caller2
+    (call $param2
+      (struct.new $A
+        (f64.const 0)
+        (ref.func $param)
+      )
+    )
+    ;; The second call is not nested in this case, to test another form.
+    (drop
+      (call $result)
+    )
+  )
+
+  ;; CHECK:      (func $result (type $2) (result f64)
+  ;; CHECK-NEXT:  (unreachable)
+  ;; CHECK-NEXT: )
+  (func $result (result f64)
+    (unreachable)
+  )
+)
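
For readers following the incremental logic in updateAllCalls() above, the essential technique is independent of Binaryen's types: keep parallel calls/origins vectors per callee, drop the entries whose origin function was just recomputed by compacting in place, then append the fresh calls from the recomputed callers. Below is a minimal standalone sketch of that strategy under simplified assumptions: plain int call ids stand in for Call* pointers, and the CallGraph type plus the main() harness are illustrative stand-ins, not part of this patch.

// Minimal sketch of the incremental call-map update strategy (simplified,
// hypothetical types; not the Binaryen implementation).
#include <cassert>
#include <cstddef>
#include <iostream>
#include <unordered_set>
#include <utility>
#include <vector>

using Index = std::size_t;

struct CallInfo {
  // Parallel vectors: origins[i] is the function that contains calls[i].
  std::vector<int> calls;
  std::vector<Index> origins;
};

// graph[f] lists the calls made by function f, as (callee, call id) pairs.
using CallGraph = std::vector<std::vector<std::pair<Index, int>>>;

// Append all calls made by |caller| to the per-callee lists in |allCalls|.
void addCallsFrom(Index caller,
                  const CallGraph& graph,
                  std::vector<CallInfo>& allCalls) {
  for (auto& edge : graph[caller]) {
    allCalls[edge.first].calls.push_back(edge.second);
    allCalls[edge.first].origins.push_back(caller);
  }
}

// Refresh only the entries of |allCalls| affected by the functions whose
// bodies were just recomputed, rather than rebuilding the whole map.
void updateAllCalls(const std::unordered_set<Index>& justUpdated,
                    const CallGraph& graph,
                    std::vector<CallInfo>& allCalls) {
  // Only callees of a just-updated function can hold stale entries.
  std::unordered_set<Index> affected;
  for (auto caller : justUpdated) {
    for (auto& edge : graph[caller]) {
      affected.insert(edge.first);
    }
  }
  // Remove stale entries in place by compacting over them, as in the patch.
  for (auto callee : affected) {
    auto& info = allCalls[callee];
    assert(info.calls.size() == info.origins.size());
    Index skip = 0;
    for (Index i = 0; i < info.calls.size(); i++) {
      if (justUpdated.count(info.origins[i])) {
        // Stale: its origin was just recomputed, so drop it.
        skip++;
      } else if (skip) {
        // Still valid: shift it down over the dropped entries.
        info.calls[i - skip] = info.calls[i];
        info.origins[i - skip] = info.origins[i];
      }
    }
    info.calls.resize(info.calls.size() - skip);
    info.origins.resize(info.origins.size() - skip);
  }
  // Re-add the fresh calls from the just-updated callers.
  for (auto caller : justUpdated) {
    addCallsFrom(caller, graph, allCalls);
  }
}

int main() {
  // Three functions: 0 and 1 each make one call to 2 (call ids 100 and 200).
  CallGraph graph = {{{2, 100}}, {{2, 200}}, {}};
  std::vector<CallInfo> allCalls(3);
  for (Index f = 0; f < graph.size(); f++) {
    addCallsFrom(f, graph, allCalls);
  }
  // Function 0 is re-optimized; its call to 2 is now call id 101.
  graph[0] = {{2, 101}};
  updateAllCalls({0}, graph, allCalls);
  for (Index i = 0; i < allCalls[2].calls.size(); i++) {
    // Prints "200 from 1" and then "101 from 0".
    std::cout << allCalls[2].calls[i] << " from " << allCalls[2].origins[i]
              << "\n";
  }
}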