Skip to content

Commit 1bd8882

Browse files
support linear algebra on ArrayPartition
1 parent 05a5a49 commit 1bd8882

File tree

5 files changed

+40
-1
lines changed

5 files changed

+40
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ version = "2.2.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
10+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
811
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
912
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1013
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ __precompile__()
33
module RecursiveArrayTools
44

55
using Requires, RecipesBase, StaticArrays, Statistics,
6-
ArrayInterface, ZygoteRules
6+
ArrayInterface, ZygoteRules, LinearAlgebra
77

88
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
99
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end

src/array_partition.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,9 @@ common_number(a, b) =
293293
(b == 0 ? a :
294294
(a == b ? a :
295295
throw(DimensionMismatch("number of partitions must be equal"))))
296+
297+
## Linear Algebra
298+
299+
LinearAlgebra.ldiv!(A::LinearAlgebra.LU,b::ArrayPartition) = ldiv!(A,Array(b))
300+
LinearAlgebra.ldiv!(A::LinearAlgebra.QR,b::ArrayPartition) = ldiv!(A,Array(b))
301+
LinearAlgebra.ldiv!(A::LinearAlgebra.SVD,b::ArrayPartition) = ldiv!(A,Array(b))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ using Test
77
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
88
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
99
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
10+
@time @testset "Upstream Tests" begin include("upstream.jl") end
1011
end

test/upstream.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test
2+
function lorenz(du,u,p,t)
3+
du[1] = 10.0*(u[2]-u[1])
4+
du[2] = u[1]*(28.0-u[3]) - u[2]
5+
du[3] = u[1]*u[2] - (8/3)*u[3]
6+
end
7+
u0 = ArrayPartition([1.0,0.0],[0.0])
8+
@test u0 .* u0' .* false isa Matrix
9+
tspan = (0.0,100.0)
10+
prob = ODEProblem(lorenz,u0,tspan)
11+
sol = solve(prob,Tsit5());
12+
sol = solve(prob,AutoTsit5(Rosenbrock23(autodiff=false)));
13+
14+
function mymodel(F, vars)
15+
x = vars.x[1]
16+
F.x[1][1] = (x[1]+3)*(x[2]^3-7)+18
17+
F.x[1][2] = sin(x[2]*exp(x[1])-1)
18+
y=vars.x[2]
19+
F.x[2][1] = (y[1]+3)*(y[2]^3-7)+18
20+
F.x[2][2] = sin(y[2]*exp(y[1])-1)
21+
end
22+
23+
# To show that the function works
24+
F = ArrayPartition([0.0 0.0],[0.0, 0.0])
25+
u0= ArrayPartition([0.1; 1.2], [0.1; 1.2])
26+
result = mymodel(F, u0)
27+
28+
# To show the NLsolve error that results with ArrayPartitions:
29+
nlsolve(f!, ArrayPartition([0.1; 1.2], [0.1; 1.2]))

0 commit comments

Comments
 (0)