Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ LinearSolveCUDAExt = "CUDA"
LinearSolveCUDSSExt = "CUDSS"
LinearSolveCUSOLVERRFExt = ["CUSOLVERRF", "SparseArrays"]
LinearSolveCliqueTreesExt = ["CliqueTrees", "SparseArrays"]
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveEnzymeExt = ["EnzymeCore", "SparseArrays"]
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
LinearSolveForwardDiffExt = "ForwardDiff"
Expand Down
104 changes: 92 additions & 12 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,87 @@ using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearP
using LinearSolve.LinearAlgebra
using EnzymeCore
using EnzymeCore: EnzymeRules
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:LinearSolve.SciMLLinearSolveAlgorithm}) = true

# Helper functions for sparse-safe gradient accumulation
# These avoid broadcast operations that can change sparsity patterns
#
# Key insight: Enzyme.make_zero shares structural arrays (rowval, colptr) between
# primal and shadow sparse matrices. Broadcast operations like `dA .-= z * y'` can
# change the sparsity pattern, corrupting both shadow AND primal. We must operate
# directly on nzval to preserve the sparsity pattern.

using SparseArrays: nonzeros, rowvals, getcolptr

"""
_safe_add!(dst, src)

Add `src` to `dst` in a way that preserves the sparsity pattern of sparse matrices.
For sparse matrices with matching sparsity patterns (as with Enzyme shadows),
this operates directly on the nonzeros array.
"""
function _safe_add!(dst::SparseMatrixCSC, src::SparseMatrixCSC)
    # Adding the stored-value buffers entry by entry is only meaningful when both
    # matrices store exactly the same (row, column) positions. Validate that here so
    # a structural mismatch raises an error instead of silently accumulating values
    # into the wrong entries. When the structural arrays are the very same objects
    # the `===` checks short-circuit and the validation costs nothing; otherwise it
    # is a single pass over the structure, the same order of work as the addition.
    #
    # Throws `DimensionMismatch` if the sizes differ and `ArgumentError` if the
    # sparsity patterns differ. Returns `dst`.
    size(dst) == size(src) || throw(DimensionMismatch(
        "cannot accumulate a $(size(src)) sparse matrix into a $(size(dst)) one"))

    dst_colptr, src_colptr = getcolptr(dst), getcolptr(src)
    dst_colptr === src_colptr || dst_colptr == src_colptr || throw(ArgumentError(
        "sparse accumulation requires identical sparsity patterns (colptr differs)"))

    # Number of structurally stored entries; the backing buffers may be longer.
    stored = 1:(Int(dst_colptr[end]) - 1)

    dst_rows, src_rows = rowvals(dst), rowvals(src)
    dst_rows === src_rows || view(dst_rows, stored) == view(src_rows, stored) ||
        throw(ArgumentError(
            "sparse accumulation requires identical sparsity patterns (rowval differs)"))

    # Touch only the value buffer so the structure of `dst` is never modified.
    dst_vals = view(nonzeros(dst), stored)
    src_vals = view(nonzeros(src), stored)
    dst_vals .+= src_vals
    return dst
end

# Generic fallback: element-wise in-place accumulation of `src` into `dst`.
# `broadcast!` returns its destination, so `dst` is handed back to the caller.
function _safe_add!(dst::AbstractArray, src::AbstractArray)
    return broadcast!(+, dst, dst, src)
end

"""
_safe_zero!(A)

Zero out `A` in a way that preserves the sparsity pattern of sparse matrices.
For sparse matrices, this operates directly on the nonzeros array.
"""
function _safe_zero!(A::SparseMatrixCSC)
    # Overwrite every slot of the value buffer with zero; the structural arrays
    # are not touched, so the set of stored positions stays exactly as it was.
    stored = nonzeros(A)
    z = zero(eltype(A))
    for k in eachindex(stored)
        stored[k] = z
    end
    return A
end

# Generic fallback: reset every element of `A` to the zero of its element type.
function _safe_zero!(A::AbstractArray)
    z = zero(eltype(A))
    fill!(A, z)
    return A
end

"""
_sparse_outer_sub!(dA, z, y)

Compute `dA .-= z * transpose(y)` in a sparsity-preserving manner.

For sparse matrices, only accumulates gradients into existing non-zero positions.
This is mathematically correct for sparse matrix AD: gradients are only meaningful
at positions where the matrix can be modified.

Note: SparseMatrixCSC is a CPU-only type. GPU sparse matrices (CuSparseMatrixCSC, etc.)
have their own types and would need handling in their respective extensions.
"""
function _sparse_outer_sub!(dA::SparseMatrixCSC, z::AbstractVector, y::AbstractVector)
    # The loop below reads `z` by row index and `y` by column index under
    # `@inbounds`, so both vectors must be checked against the shape of `dA`
    # first; otherwise a short vector would be read out of bounds instead of
    # raising an error. Comparing axes (rather than lengths) also rejects
    # vectors that are not indexed from 1.
    axes(z, 1) == axes(dA, 1) || throw(DimensionMismatch(
        "z has axes $(axes(z, 1)) but dA has row axes $(axes(dA, 1))"))
    axes(y, 1) == axes(dA, 2) || throw(DimensionMismatch(
        "y has axes $(axes(y, 1)) but dA has column axes $(axes(dA, 2))"))

    rows = rowvals(dA)
    vals = nonzeros(dA)
    colptr = getcolptr(dA)

    # Walk the CSC structure column by column and update only the stored entries,
    # writing straight into the value buffer; no temporary matrix is built and the
    # structural arrays are never modified.
    @inbounds for col in 1:size(dA, 2)
        y_col = y[col]
        for idx in colptr[col]:(colptr[col + 1] - 1)
            vals[idx] -= z[rows[idx]] * y_col
        end
    end

    return dA
end

# Generic fallback: subtract the rank-one outer product of `z` and `y` from `dA`
# in place and return `dA`.
function _sparse_outer_sub!(dA::AbstractArray, z::AbstractVector, y::AbstractVector)
    # `z .* transpose(y)` yields the same entries as `z * transpose(y)`, but as a
    # broadcast it fuses with the in-place subtraction, so the full outer-product
    # matrix is never materialised as a temporary.
    dA .-= z .* transpose(y)
    return dA
end

function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand All @@ -25,10 +103,10 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
dres = func.val(prob.dval, alg.val; kwargs...)

if dres.b == res.b
dres.b .= false
_safe_zero!(dres.b)
end
if dres.A == res.A
dres.A .= false
_safe_zero!(dres.A)
end

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
Expand Down Expand Up @@ -125,22 +203,23 @@ function EnzymeRules.reverse(

if EnzymeRules.width(config) == 1
if d_A !== prob_d_A
prob_d_A .+= d_A
d_A .= 0
# Use sparse-safe addition to preserve sparsity pattern
_safe_add!(prob_d_A, d_A)
_safe_zero!(d_A)
end
if d_b !== prob_d_b
prob_d_b .+= d_b
d_b .= 0
_safe_add!(prob_d_b, d_b)
_safe_zero!(d_b)
end
else
for (_prob_d_A, _d_A, _prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
if _d_A !== _prob_d_A
_prob_d_A .+= _d_A
_d_A .= 0
_safe_add!(_prob_d_A, _d_A)
_safe_zero!(_d_A)
end
if _d_b !== _prob_d_b
_prob_d_b .+= _d_b
_d_b .= 0
_safe_add!(_prob_d_b, _d_b)
_safe_zero!(_d_b)
end
end
end
Expand All @@ -149,7 +228,7 @@ function EnzymeRules.reverse(
end

# y=inv(A) B
# dA −= z y^T
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.solve!)},
Expand Down Expand Up @@ -254,7 +333,8 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end

dA .-= z * transpose(y)
# Use sparse-safe outer product subtraction to preserve sparsity pattern
_sparse_outer_sub!(dA, z, y)
db .+= z
dy .= eltype(dy)(0)
end
Expand Down
Loading