Skip to content

Commit 5013964

Browse files
committed
Add specialized LowerTriangular and UnitLowerTriangular solve
1 parent b5469ef commit 5013964

File tree

1 file changed

+43
-7
lines changed

1 file changed

+43
-7
lines changed

src/array_partition.jl

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,14 @@ common_number(a, b) =
301301
ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(reduce(vcat,vec.(A.x)))
302302

303303
LinearAlgebra.ldiv!(A::Factorization, b::ArrayPartition) = (x = ldiv!(A, Array(b)); copyto!(b, x))
304+
function LinearAlgebra.ldiv!(A::LU, b::ArrayPartition)
305+
LinearAlgebra._ipiv_rows!(A, 1 : length(A.ipiv), b)
306+
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b))
307+
return b
308+
end
304309

305310
# block matrix indexing
306-
function getblock(A, lens, i, j)
311+
@inbounds function getblock(A, lens, i, j)
307312
ii1 = i == 1 ? 0 : sum(ii->lens[ii], 1:i-1)
308313
jj1 = j == 1 ? 0 : sum(ii->lens[ii], 1:j-1)
309314
ij1 = CartesianIndex(ii1, jj1)
@@ -315,18 +320,49 @@ end
315320
# [U11 U12 U13] [ b1 ]
316321
# [ 0 U22 U23] \ [ b2 ]
317322
# [ 0 0 U33] [ b3 ]
318-
function LinearAlgebra.ldiv!(A::T, b::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular}
323+
function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular}
319324
A = A.data
320-
n = npartitions(b)
321-
lens = map(length, b.x)
325+
n = npartitions(bb)
326+
b = bb.x
327+
lens = map(length, b)
322328
@inbounds for j in n:-1:1
323329
Ajj = T(getblock(A, lens, j, j))
324-
xj = ldiv!(Ajj, b.x[j])
330+
xj = ldiv!(Ajj, b[j])
325331
for i in j-1:-1:1
332+
Aij = getblock(A, lens, i, j)
333+
# bi = -Aij * xj + bi
334+
mul!(b[i], Aij, xj, -1, true)
335+
end
336+
end
337+
return bb
338+
end
339+
340+
function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerTriangular,LowerTriangular}
341+
A = A.data
342+
n = npartitions(bb)
343+
b = bb.x
344+
lens = map(length, b)
345+
@inbounds for j in 1:n
346+
Ajj = T(getblock(A, lens, j, j))
347+
xj = ldiv!(Ajj, b[j])
348+
for i in j+1:n
326349
Aij = getblock(A, lens, i, j)
327350
# bi = -Aij * xj + b[i]
328-
mul!(b.x[i], Aij, xj, -1, true)
351+
mul!(b[i], Aij, xj, -1, true)
329352
end
330353
end
331-
return b
354+
return bb
355+
end
356+
# TODO: optimize
357+
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)
358+
for i = order
359+
if i != A.ipiv[i]
360+
LinearAlgebra._swap_rows!(B, i, A.ipiv[i])
361+
end
362+
end
363+
return B
364+
end
365+
function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer)
366+
B[i], B[j] = B[j], B[i]
367+
return B
332368
end

0 commit comments

Comments
 (0)