Skip to content

Commit 68ac004

Browse files
Merge pull request #143 from sharanry/sy/diffeqarray_change
Enable symbol based indexing of interpolated solutions by adding extra fields to DiffEqArray type
2 parents 723647d + 59d1150 commit 68ac004

File tree

6 files changed

+116
-5
lines changed

6 files changed

+116
-5
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ julia = "1.5"
2828
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2929
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
3030
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
31+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3132
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3233
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3334
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3536
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3637

3738
[targets]
38-
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "StructArrays", "Zygote"]
39+
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"]

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ using DocStringExtensions
2121
include("zygote.jl")
2222

2323
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
24-
vecarr_to_arr, vecarr_to_vectors, tuples
24+
AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples
2525

2626
export recursivecopy, recursivecopy!, vecvecapply, copyat_or_push!,
2727
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
2828
recursive_unitless_bottom_eltype, recursive_unitless_eltype
2929

30+
3031
export ArrayPartition
3132

3233

src/vector_of_array.jl

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,30 @@ mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
33
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
44
end
55
# VectorOfArray with an added series for time
6-
mutable struct DiffEqArray{T, N, A, B} <: AbstractDiffEqArray{T, N, A}
6+
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
77
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
88
t::B
9+
syms::C
10+
indepsym::D
11+
observed::E
12+
p::F
913
end
1014

15+
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
16+
parameterless_type(x) = parameterless_type(typeof(x))
17+
parameterless_type(x::Type) = __parameterless_type(x)
18+
19+
### Abstract Interface
20+
struct AllObserved
21+
end
22+
issymbollike(x) = x isa Symbol ||
23+
x isa AllObserved ||
24+
Symbol(parameterless_type(typeof(x))) == :Operation ||
25+
Symbol(parameterless_type(typeof(x))) == :Variable ||
26+
Symbol(parameterless_type(typeof(x))) == :Sym ||
27+
Symbol(parameterless_type(typeof(x))) == :Num ||
28+
Symbol(parameterless_type(typeof(x))) == :Term
29+
1130
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:AbstractVector}} = reduce(hcat,VA.u)
1231
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:Number}} = VA.u
1332
function Base.Array(VA::AbstractVectorOfArray)
@@ -20,10 +39,11 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
2039
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
2140
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
2241

23-
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
42+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
2443
# Assume that the first element is representative of all other elements
2544
DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)))
26-
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts)}(vec, ts)
45+
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
46+
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms::Vector{Symbol}, indepsym::Symbol, observed::Function, p) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
2747

2848
# Interface for the linear indexing. This is just a view of the underlying nested structure
2949
@inline Base.firstindex(VA::AbstractVectorOfArray) = firstindex(VA.u)
@@ -38,6 +58,52 @@ Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Int)
3858
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Colon) where {T, N} = VA.u[I]
3959
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::AbstractArray{Int}) where {T, N} = VectorOfArray(VA.u[I])
4060
Base.@propagate_inbounds Base.getindex(VA::AbstractDiffEqArray{T, N}, I::AbstractArray{Int}) where {T, N} = DiffEqArray(VA.u[I],VA.t[I])
61+
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
62+
if issymbollike(sym) && A.syms !== nothing
63+
i = findfirst(isequal(Symbol(sym)),A.syms)
64+
else
65+
i = sym
66+
end
67+
68+
if i === nothing
69+
if issymbollike(i) && A.indepsym !== nothing && Symbol(i) == A.indepsym
70+
A.t
71+
else
72+
observed(A,sym,:)
73+
end
74+
else
75+
Base.getindex.(A.u, i)
76+
end
77+
end
78+
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
79+
if issymbollike(sym) && A.syms !== nothing
80+
i = findfirst(isequal(Symbol(sym)),A.syms)
81+
else
82+
i = sym
83+
end
84+
85+
if i === nothing
86+
if issymbollike(i) && A.indepsym !== nothing && Symbol(i) == A.indepsym
87+
A.t[args...]
88+
else
89+
observed(A,sym,args...)
90+
end
91+
else
92+
Base.getindex.(A.u, args...)
93+
end
94+
end
95+
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
96+
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int) where {T, N} = A.u[i]
97+
function observed(A::AbstractDiffEqArray{T, N},sym,i::Int) where {T, N}
98+
A.observed(sym,A.u[i],A.p,A.t[i])
99+
end
100+
function observed(A::AbstractDiffEqArray{T, N},sym,i::AbstractArray{Int}) where {T, N}
101+
A.observed.((sym,),A.u[i],(A.p,),A.t[i])
102+
end
103+
function observed(A::AbstractDiffEqArray{T, N},sym,::Colon) where {T, N}
104+
A.observed.((sym,),A.u,(A.p,),A.t)
105+
end
106+
41107
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, i::Int,::Colon) where {T, N} = [VA.u[j][i] for j in 1:length(VA)]
42108
Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray{T,N}, ii::CartesianIndex) where {T, N}
43109
ti = Tuple(ii)
@@ -145,6 +211,8 @@ Base.show(io::IO, m::MIME"text/plain", x::AbstractDiffEqArray) = (print(io,"t: "
145211
convert(Array,VA)
146212
end
147213
@recipe function f(VA::AbstractDiffEqArray)
214+
xguide --> ((VA.indepsym !== nothing) ? string(VA.indepsym) : "")
215+
label --> ((VA.syms !== nothing) ? reshape(string.(VA.syms), 1, :) : "")
148216
VA.t,VA'
149217
end
150218
@recipe function f(VA::DiffEqArray{T,1}) where {T}

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
3+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

test/downstream/symbol_indexing.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@variables t x(t)
4+
@parameters τ
5+
D = Differential(t)
6+
@variables RHS(t)
7+
@named fol_separate = ODESystem([ RHS ~ (1 - x)/τ,
8+
D(x) ~ RHS ])
9+
fol_simplified = structural_simplify(fol_separate)
10+
11+
prob = ODEProblem(fol_simplified, [x => 0.0], (0.0,10.0), [τ => 3.0])
12+
sol = solve(prob, Tsit5())
13+
14+
sol_new = DiffEqArray(
15+
sol.u[1:10],
16+
sol.t[1:10],
17+
sol.prob.f.syms,
18+
sol.prob.f.indepsym,
19+
sol.prob.f.observed,
20+
sol.prob.p
21+
)
22+
23+
@test sol_new[RHS] (1 .- sol_new[x])./3.0

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1+
using Pkg
12
using RecursiveArrayTools
23
using Test
34

5+
const GROUP = get(ENV, "GROUP", "All")
6+
const is_APPVEYOR = ( Sys.iswindows() && haskey(ENV,"APPVEYOR") )
7+
8+
function activate_downstream_env()
9+
Pkg.activate("downstream")
10+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
11+
Pkg.instantiate()
12+
end
13+
414
@time begin
515
@time @testset "Utils Tests" begin include("utils_test.jl") end
616
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@@ -10,4 +20,9 @@ using Test
1020
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
1121
@time @testset "Upstream Tests" begin include("upstream.jl") end
1222
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
23+
24+
if !is_APPVEYOR && GROUP == "Downstream"
25+
activate_downstream_env()
26+
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
27+
end
1328
end

0 commit comments

Comments
 (0)