Skip to content

Commit 093e0ff

Browse files
committed
Added linalg mul! overloads for ArrayPartition with tests
1 parent c24e54f commit 093e0ff

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/array_partition.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,18 @@ function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer)
370370
B[i], B[j] = B[j], B[i]
371371
return B
372372
end
373+
374+
# linalg mul! overloads for ArrayPartition
375+
function LinearAlgebra.mul!(C::T, A::T, B::AbstractArray) where T<:ArrayPartition
376+
@assert length(C.x) == length(A.x)
377+
for index = 1:length(C.x)
378+
mul!(C.x[index], A.x[index], B)
379+
end
380+
end
381+
382+
function LinearAlgebra.mul!(C::T, A::T, B::T) where T<:ArrayPartition
383+
@assert length(C.x) == length(A.x) == length(B.x)
384+
for index = 1:length(C.x)
385+
mul!(C.x[index], A.x[index], B.x[index])
386+
end
387+
end

test/linalg.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using RecursiveArrayTools, Test, Random
2-
using LinearAlgebra
2+
using LinearAlgebra, SparseArrays
33

44
n, m = 5, 6
55
bb = rand(n), rand(m)
@@ -26,3 +26,24 @@ for ff in (lu, svd, qr)
2626
@test ldiv!(FF, bbb) === bbb
2727
@test A*bbb b
2828
end
29+
30+
# linalg mul! overloads
31+
n, m, l = 5, 6, 7
32+
bb = rand(n, n), rand(m, n), rand(l, n)
33+
cc = rand(n), rand(n), rand(n)
34+
dd = rand(n), rand(m), rand(l)
35+
b = ArrayPartition(bb)
36+
c = ArrayPartition(cc)
37+
d = ArrayPartition(dd)
38+
A = rand(n)
39+
for T in (Array{Float64}, Array{ComplexF64}, sparse, )
40+
B = T(A)
41+
mul!(d, b, A)
42+
for i = 1:length(c.x)
43+
@test d.x[i] == b.x[i] * A
44+
end
45+
mul!(d, b, c)
46+
for i = 1:length(d.x)
47+
@test d.x[i] == b.x[i] * c.x[i]
48+
end
49+
end

0 commit comments

Comments
 (0)