@@ -301,9 +301,14 @@ common_number(a, b) =
301301ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (reduce (vcat,vec .(A. x)))
302302
303303LinearAlgebra. ldiv! (A:: Factorization , b:: ArrayPartition ) = (x = ldiv! (A, Array (b)); copyto! (b, x))
304+ function LinearAlgebra. ldiv! (A:: LU , b:: ArrayPartition )
305+ LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
306+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
307+ return b
308+ end
304309
305310# block matrix indexing
306- function getblock (A, lens, i, j)
311+ @inbounds function getblock (A, lens, i, j)
307312 ii1 = i == 1 ? 0 : sum (ii-> lens[ii], 1 : i- 1 )
308313 jj1 = j == 1 ? 0 : sum (ii-> lens[ii], 1 : j- 1 )
309314 ij1 = CartesianIndex (ii1, jj1)
@@ -315,18 +320,49 @@ end
315320# [U11 U12 U13] [ b1 ]
316321# [ 0 U22 U23] \ [ b2 ]
317322# [ 0 0 U33] [ b3 ]
318- function LinearAlgebra. ldiv! (A:: T , b :: ArrayPartition ) where T<: Union{UnitUpperTriangular,UpperTriangular}
323+ function LinearAlgebra. ldiv! (A:: T , bb :: ArrayPartition ) where T<: Union{UnitUpperTriangular,UpperTriangular}
319324 A = A. data
320- n = npartitions (b)
321- lens = map (length, b. x)
325+ n = npartitions (bb)
326+ b = bb. x
327+ lens = map (length, b)
322328 @inbounds for j in n: - 1 : 1
323329 Ajj = T (getblock (A, lens, j, j))
324- xj = ldiv! (Ajj, b. x [j])
330+ xj = ldiv! (Ajj, b[j])
325331 for i in j- 1 : - 1 : 1
332+ Aij = getblock (A, lens, i, j)
333+ # bi = -Aij * xj + bi
334+ mul! (b[i], Aij, xj, - 1 , true )
335+ end
336+ end
337+ return bb
338+ end
339+
340+ function LinearAlgebra. ldiv! (A:: T , bb:: ArrayPartition ) where T<: Union{UnitLowerTriangular,LowerTriangular}
341+ A = A. data
342+ n = npartitions (bb)
343+ b = bb. x
344+ lens = map (length, b)
345+ @inbounds for j in 1 : n
346+ Ajj = T (getblock (A, lens, j, j))
347+ xj = ldiv! (Ajj, b[j])
348+ for i in j+ 1 : n
326349 Aij = getblock (A, lens, i, j)
327350 # bi = -Aij * xj + b[i]
328- mul! (b. x [i], Aij, xj, - 1 , true )
351+ mul! (b[i], Aij, xj, - 1 , true )
329352 end
330353 end
331- return b
354+ return bb
355+ end
356+ # TODO : optimize
357+ function LinearAlgebra. _ipiv_rows! (A:: LU , order:: OrdinalRange , B:: ArrayPartition )
358+ for i = order
359+ if i != A. ipiv[i]
360+ LinearAlgebra. _swap_rows! (B, i, A. ipiv[i])
361+ end
362+ end
363+ return B
364+ end
365+ function LinearAlgebra. _swap_rows! (B:: ArrayPartition , i:: Integer , j:: Integer )
366+ B[i], B[j] = B[j], B[i]
367+ return B
332368end
0 commit comments