@@ -33,6 +33,27 @@ mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
3333end
3434# VectorOfArray with an added series for time
3535
36+
37+ struct SymbolCache{S,T,U}
38+ syms:: S
39+ indepsym:: T
40+ paramsyms:: U
41+ end
42+
43+ is_indep_sym (sc:: SymbolCache , sym) = isequal (sc. indepsym, sym)
44+ is_indep_sym (:: SymbolCache{S,Nothing} , _) where {S} = false
45+ state_sym_to_index (sc:: SymbolCache , sym) = findfirst (isequal (sym), sc. syms)
46+ state_sym_to_index (:: SymbolCache{Nothing} , _) = nothing
47+ is_state_sym (sc:: SymbolCache , sym) = ! isnothing (state_sym_to_index (sc, sym))
48+ param_sym_to_index (sc:: SymbolCache , sym) = findfirst (isequal (sym), sc. paramsyms)
49+ param_sym_to_index (:: SymbolCache{S,T,Nothing} , _) where {S,T} = nothing
50+ is_param_sym (sc:: SymbolCache , sym) = ! isnothing (param_sym_to_index (sc, sym))
51+
52+ Base. copy (VA:: SymbolCache ) = typeof (VA)(
53+ (VA. syms=== nothing ) ? nothing : copy (VA. syms),
54+ (VA. indepsym=== nothing ) ? nothing : copy (VA. indepsym),
55+ (VA. paramsyms=== nothing ) ? nothing : copy (VA. paramsyms),
56+ )
3657"""
3758```julia
3859DiffEqArray(u::AbstractVector,t::AbstractVector)
@@ -53,11 +74,10 @@ A[1,:] # all time periods for f(t)
5374A.t
5475```
5576"""
56- mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
77+ mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
5778 u:: A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
5879 t:: B
59- syms:: C
60- indepsym:: D
80+ sc:: C
6181 observed:: E
6282 p:: F
6383end
@@ -94,11 +114,11 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
94114VectorOfArray (vec:: AbstractVector ) = VectorOfArray (vec, (size (vec[1 ])... , length (vec)))
95115VectorOfArray (vec:: AbstractVector{VT} ) where {T, N, VT<: AbstractArray{T, N} } = VectorOfArray {T, N+1, typeof(vec)} (vec)
96116
97- DiffEqArray (vec:: AbstractVector{T} , ts, :: NTuple{N} , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N} = DiffEqArray {eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)} (vec, ts, syms, indepsym, observed, p)
117+ DiffEqArray (vec:: AbstractVector{T} , ts, :: NTuple{N} , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N} = DiffEqArray {eltype(T), N, typeof(vec), typeof(ts), SymbolCache{ typeof(syms), typeof(indepsym), Nothing}, typeof(observed), typeof(p)} (vec, ts, SymbolCache ( syms, indepsym, nothing ) , observed, p)
98118# Assume that the first element is representative of all other elements
99119DiffEqArray (vec:: AbstractVector ,ts:: AbstractVector , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) = DiffEqArray (vec, ts, (size (vec[1 ])... , length (vec)), syms, indepsym, observed, p)
100120function DiffEqArray (vec:: AbstractVector{VT} ,ts:: AbstractVector , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N, VT<: AbstractArray{T, N} }
101- DiffEqArray {T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)} (vec, ts, syms, indepsym, observed, p)
121+ DiffEqArray {T, N+1, typeof(vec), typeof(ts), SymbolCache{ typeof(syms), typeof(indepsym), Nothing}, typeof(observed), typeof(p)} (vec, ts, SymbolCache ( syms, indepsym, nothing ) , observed, p)
102122end
103123
104124# Interface for the linear indexing. This is just a view of the underlying nested structure
@@ -138,37 +158,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
138158Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , :: Colon ,i:: Int ) where {T, N} = A. u[i]
139159Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , i:: Int ,II:: AbstractArray{Int} ) where {T, N} = [A. u[j][i] for j in II]
140160Base. @propagate_inbounds function Base. getindex (A:: AbstractDiffEqArray{T, N} ,sym) where {T, N}
141- if issymbollike (sym) && A. syms != = nothing
142- i = findfirst (isequal (Symbol (sym)),A. syms)
143- else
144- i = sym
145- end
146-
147- if i === nothing
148- if issymbollike (sym) && A. indepsym != = nothing && Symbol (sym) == A. indepsym
149- A. t
161+ if issymbollike (sym) && ! isnothing (A. sc)
162+ if is_indep_sym (A. sc, sym)
163+ return A. t
164+ elseif is_state_sym (A. sc, sym)
165+ return getindex .(A. u, state_sym_to_index (A. sc, sym))
166+ elseif is_param_sym (A. sc, sym)
167+ return A. p[param_sym_to_index (A. sc, sym)]
168+ else
169+ return observed (A, sym, :)
170+ end
171+ elseif all (issymbollike, sym) && ! isnothing (A. sc)
172+ if all (Base. Fix1 (is_param_sym, A. sc), sym)
173+ return getindex .((A,), sym)
150174 else
151- observed (A, sym,:)
175+ return [ getindex .((A,), sym, i) for i in eachindex (A . t)]
152176 end
153177 else
154- Base . getindex .(A. u, i )
178+ return getindex .(A. u, sym )
155179 end
156180end
157181Base. @propagate_inbounds function Base. getindex (A:: AbstractDiffEqArray{T, N} ,sym,args... ) where {T, N}
158- if issymbollike (sym) && A. syms != = nothing
159- i = findfirst (isequal (Symbol (sym)),A. syms)
160- else
161- i = sym
162- end
163-
164- if i === nothing
165- if issymbollike (sym) && A. indepsym != = nothing && Symbol (sym) == A. indepsym
166- A. t[args... ]
182+ if issymbollike (sym) && ! isnothing (A. sc)
183+ if is_indep_sym (A. sc, sym)
184+ return A. t[args... ]
185+ elseif is_state_sym (A. sc, sym)
186+ return A[sym][args... ]
167187 else
168- observed (A,sym,args... )
188+ return observed (A, sym, args... )
169189 end
190+ elseif all (issymbollike, sym) && ! isnothing (A. sc)
191+ return reduce (vcat, map (s -> A[s, args... ]' , sym))
170192 else
171- Base . getindex .(A. u, i, args ... )
193+ return getindex .(A. u, sym )
172194 end
173195end
174196Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , I:: Int... ) where {T, N} = A. u[I[end ]][Base. front (I)... ]
@@ -230,8 +252,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
230252Base. copy (VA:: AbstractDiffEqArray ) = typeof (VA)(
231253 copy (VA. u),
232254 copy (VA. t),
233- (VA. syms=== nothing ) ? nothing : copy (VA. syms),
234- (VA. indepsym=== nothing ) ? nothing : copy (VA. indepsym),
255+ (VA. sc=== nothing ) ? nothing : copy (VA. sc),
235256 (VA. observed=== nothing ) ? nothing : copy (VA. observed),
236257 (VA. p=== nothing ) ? nothing : copy (VA. p)
237258 )
0 commit comments