Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
2098245
start
kripken Nov 20, 2025
468be27
work
kripken Nov 20, 2025
c43bd44
work
kripken Nov 20, 2025
b5a7f44
test
kripken Dec 1, 2025
36fe492
work
kripken Dec 1, 2025
6eb2b7e
work
kripken Dec 1, 2025
f67c5c2
comment
kripken Dec 1, 2025
d759cb9
work
kripken Dec 1, 2025
6007366
comment
kripken Dec 1, 2025
bb1ec77
waka?
kripken Dec 2, 2025
bad7b84
clean
kripken Dec 2, 2025
efdf8e8
timer
kripken Dec 2, 2025
765950a
timer
kripken Dec 2, 2025
caf506d
timer
kripken Dec 2, 2025
4fb8477
timer
kripken Dec 2, 2025
553186e
timer
kripken Dec 2, 2025
6dcf2ad
timer
kripken Dec 2, 2025
9863333
timer
kripken Dec 2, 2025
5cf0f0c
timer
kripken Dec 2, 2025
35b31b7
timer
kripken Dec 2, 2025
3b2e5db
timer
kripken Dec 2, 2025
1df1294
timer
kripken Dec 2, 2025
3f90d40
timer
kripken Dec 2, 2025
cfa5891
timer
kripken Dec 2, 2025
bf3eb42
timer
kripken Dec 2, 2025
20a5743
fix
kripken Dec 2, 2025
a1a6413
fix
kripken Dec 2, 2025
4640c3a
fix
kripken Dec 2, 2025
2d77d4a
fix
kripken Dec 2, 2025
e142546
fix
kripken Dec 2, 2025
fe41327
fix
kripken Dec 2, 2025
5008b43
fix
kripken Dec 2, 2025
3d41a87
:wq
kripken Dec 2, 2025
7418cbc
nicer
kripken Dec 2, 2025
d9203ef
nicer
kripken Dec 2, 2025
956e778
nicer
kripken Dec 2, 2025
006cb6b
nicer
kripken Dec 2, 2025
f7e565b
nicer
kripken Dec 2, 2025
b905ad7
nicer
kripken Dec 2, 2025
56430a0
nicer
kripken Dec 2, 2025
dfda490
nicer
kripken Dec 2, 2025
6e15aa4
nicer
kripken Dec 2, 2025
a9d46ef
nicer
kripken Dec 2, 2025
c7708f5
nicer
kripken Dec 2, 2025
347e5df
nicer
kripken Dec 2, 2025
6896345
merge
kripken Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 106 additions & 10 deletions src/passes/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ struct DAEFunctionInfo {
// Whether this needs to be recomputed. This begins as true for the first
// computation, and we reset it every time we touch the function.
bool stale = true;
// Set after we just updated this. That is the case right after we were stale,
// and were updated to become not stale, and it indicates that we contain new
// data.
bool justUpdated = false;
// The unused parameters, if any.
SortedVector unusedParams;
// Maps a function name to the calls going to it.
Expand Down Expand Up @@ -160,18 +164,21 @@ struct DAEScanner

if (!info->stale) {
// Nothing changed since last time.
info->justUpdated = false;
return;
}

// Clear the data, mark us as no longer stale, and recompute everything.
info->clear();
info->stale = false;
info->justUpdated = true;

auto numParams = func->getNumParams();
PostWalker<DAEScanner, Visitor<DAEScanner>>::doWalkFunction(func);

// If there are params, check if they are used.
// TODO: This work could be avoided if we cannot optimize for other reasons.
// That would require deferring this to later and checking that.
auto numParams = func->getNumParams();
if (numParams > 0) {
auto usedParams = ParamUtils::getUsedParams(func, getModule());
for (Index i = 0; i < numParams; i++) {
Expand All @@ -195,6 +202,17 @@ struct DAE : public Pass {
// Map of function names to indexes. This lets us use indexes below for speed.
std::unordered_map<Name, Index> indexes;

struct CallInfo {
// Store the calls and their origins in parallel vectors (as we need |calls|
// by itself for certain APIs). That is, origins[i] is the function index in
// which calls[i] appears.
std::vector<Call*> calls;
std::vector<Index> origins;
};
// The set of all calls (and their origins) between functions. We compute this
// incrementally in later iterations, to avoid repeated work.
std::vector<CallInfo> allCalls;

void run(Module* module) override {
DAEFunctionInfoMap infoMap;
// Ensure all entries exist so the parallel threads don't modify the data
Expand All @@ -211,6 +229,8 @@ struct DAE : public Pass {
indexes[module->functions[i]->name] = i;
}

allCalls.resize(numFunctions);

// Iterate to convergence.
while (1) {
if (!iteration(module, infoMap)) {
Expand All @@ -232,6 +252,8 @@ struct DAE : public Pass {
// of computing this map is significant, so we compute it once at the start
// and then use that possibly-over-approximating data.
std::vector<std::vector<Name>> callers;
// Reverse data: The list of functions called by a function.
std::vector<std::vector<Index>> callees;

// A count of how many iterations we saw unprofitable removals of parameters.
// An unprofitable removal is one where we only manage to remove from a single
Expand Down Expand Up @@ -274,15 +296,10 @@ struct DAE : public Pass {
scanner.run(getPassRunner(), module);

// Combine all the info from the scan.
std::vector<std::vector<Call*>> allCalls(numFunctions);
std::vector<bool> tailCallees(numFunctions);
std::vector<bool> hasUnseenCalls(numFunctions);

for (auto& [func, info] : infoMap) {
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[indexes[name]];
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
}
for (auto& callee : info.tailCallees) {
tailCallees[indexes[callee]] = true;
}
Expand All @@ -300,7 +317,7 @@ struct DAE : public Pass {
}
}

// See comment above, we compute callers once and never again.
// See comment above, we compute callers and callees once and never again.
if (callers.empty()) {
// Compute first as sets, to deduplicate.
std::vector<std::unordered_set<Name>> callersSets(numFunctions);
Expand All @@ -311,12 +328,18 @@ struct DAE : public Pass {
}
// Copy into efficient vectors.
callers.resize(numFunctions);
callees.resize(numFunctions);
for (Index i = 0; i < numFunctions; ++i) {
auto& set = callersSets[i];
callers[i] = std::vector<Name>(set.begin(), set.end());
for (auto& caller : callers[i]) {
callees[indexes[caller]].push_back(i);
}
}
}

updateAllCalls(module, infoMap);

// Track which functions we changed that are worth re-optimizing at the end.
std::unordered_set<Function*> worthOptimizing;

Expand Down Expand Up @@ -361,7 +384,7 @@ struct DAE : public Pass {
if (hasUnseenCalls[index]) {
continue;
}
auto& calls = allCalls[index];
auto& calls = allCalls[index].calls;
if (calls.empty()) {
// Nothing calls this, so it is not worth optimizing.
continue;
Expand Down Expand Up @@ -418,7 +441,7 @@ struct DAE : public Pass {
if (numParams == 0) {
continue;
}
auto& calls = allCalls[index];
auto& calls = allCalls[index].calls;
if (calls.empty()) {
continue;
}
Expand Down Expand Up @@ -475,7 +498,7 @@ struct DAE : public Pass {
if (tailCallees[index]) {
continue;
}
auto& calls = allCalls[index];
auto& calls = allCalls[index].calls;
if (calls.empty()) {
continue;
}
Expand Down Expand Up @@ -516,6 +539,79 @@ struct DAE : public Pass {
private:
std::unordered_map<Call*, Expression**> allDroppedCalls;

void updateAllCalls(Module* module, DAEFunctionInfoMap& infoMap) {
// Recompute parts of allCalls as necessary. We know which function infos
// were just updated, and start there: If we updated { A, B }, and A calls
// C while B calls C and D, then the list of all calls must be updated for
// C and D. First, find the functions just updated, and the ones they call.
std::unordered_set<Index> justUpdated;
std::unordered_set<Index> calledByJustUpdated;
for (auto& [func, info] : infoMap) {
if (info.justUpdated) {
auto index = indexes[func];
justUpdated.insert(index);
for (auto& callee : callees[index]) {
calledByJustUpdated.insert(callee);
}
}
}

// Add calls from one caller to allCalls.
auto addCallsFrom = [&](Index caller) {
auto& info = infoMap[module->functions[caller]->name];
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[indexes[name]].calls;
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
auto num = calls.size();
auto& origins = allCalls[indexes[name]].origins;
for (Index i = 0; i < num; i++) {
origins.push_back(caller);
}
}
};

if (justUpdated.size() + calledByJustUpdated.size() >= numFunctions) {
// Many functions need to be processed to do an incremental update. Just
// do a full recompute from scratch, which may be faster.
allCalls.clear();
allCalls.resize(numFunctions);
for (Index caller = 0; caller < numFunctions; caller++) {
addCallsFrom(caller);
}
return;
}

// Do an incremental update. First, remove all stale calls from allCalls,
// that is, remove calls from the just-updated functions.
for (auto& called : calledByJustUpdated) {
auto& calledCalls = allCalls[called];
auto oldSize = calledCalls.calls.size();
assert(oldSize == calledCalls.origins.size());
Index skip = 0;
for (Index i = 0; i < calledCalls.calls.size(); i++) {
if (justUpdated.count(calledCalls.origins[i])) {
// Remove it by skipping over.
skip++;
} else if (skip) {
// Keep it by writing to the proper place.
calledCalls.calls[i - skip] = calledCalls.calls[i];
calledCalls.origins[i - skip] = calledCalls.origins[i];
}
}
if (skip > 0) {
// Update the sizes after removing things.
calledCalls.calls.resize(oldSize - skip);
calledCalls.origins.resize(oldSize - skip);
}
}

// The just-updated callers' calls have been cleaned out of allCalls. Add
// them in, leaving us with fully-updated data.
for (auto& caller : justUpdated) {
addCallsFrom(caller);
}
}

// Returns `true` if the caller should be optimized.
bool
removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) {
Expand Down
78 changes: 78 additions & 0 deletions test/lit/passes/dae-gc.wast
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,81 @@
(nop)
)
)

;; After the first optimization, where we remove params from the call to $param,
;; we update the IR incrementally, and must do so properly: we update $param and
;; its callers $caller. $caller also calls $result, so we must update some of
;; $result's callers but not all. Ditto with $caller2.
(module
(rec
;; CHECK: (type $A (sub (struct (field (mut f64)) (field (mut funcref)))))
(type $A (sub (struct (field (mut f64)) (field (mut funcref)))))
)

;; CHECK: (func $nop (type $0)
;; CHECK-NEXT: )
(func $nop
;; Helper.
)

;; CHECK: (func $param (type $0)
;; CHECK-NEXT: (local $0 (ref $A))
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
(func $param (param $0 (ref $A))
;; Helper with a param and lets us have calls inside it.
(unreachable)
)

;; CHECK: (func $caller (type $0)
;; CHECK-NEXT: (local $0 (ref (exact $A)))
;; CHECK-NEXT: (local.set $0
;; CHECK-NEXT: (struct.new $A
;; CHECK-NEXT: (call $result)
;; CHECK-NEXT: (ref.func $nop)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (call $param)
;; CHECK-NEXT: )
(func $caller
(call $param
(struct.new $A
(call $result)
(ref.func $nop)
)
)
)

;; CHECK: (func $param2 (type $0)
;; CHECK-NEXT: (local $0 (ref any))
;; CHECK-NEXT: )
(func $param2 (param $0 (ref any))
)

;; CHECK: (func $caller2 (type $0)
;; CHECK-NEXT: (call $param2)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (call $result)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $caller2
(call $param2
(struct.new $A
(f64.const 0)
(ref.func $param)
)
)
;; The second call is not nested in this case, to test another form.
(drop
(call $result)
)
)

;; CHECK: (func $result (type $2) (result f64)
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
(func $result (result f64)
(unreachable)
)
)

Loading