Skip to content

Commit 8005b83

Browse files
Add new interface for symbolic indexing
- DiffEqArray now stores the new `SymbolCache` struct, which defines and implements its interface to query symbols - DiffEqArray supports symbolically indexing parameters, if provided
1 parent 8f22d84 commit 8005b83

File tree

2 files changed

+52
-31
lines changed

2 files changed

+52
-31
lines changed

src/tabletraits.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
77
N = length(A.u[1])
88
names = [
99
:timestamp,
10-
(A.syms !== nothing ? (A.syms[i] for i in 1:N) :
10+
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
1111
(Symbol("value", i) for i in 1:N))...,
1212
]
1313
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
1414
else
15-
names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value]
15+
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
1616
types = Type[eltype(A.t), VT]
1717
end
1818
return AbstractDiffEqArrayRows(names, types, A.t, A.u)

src/vector_of_array.jl

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,27 @@ mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
3333
end
3434
# VectorOfArray with an added series for time
3535

36+
37+
struct SymbolCache{S,T,U}
38+
syms::S
39+
indepsym::T
40+
paramsyms::U
41+
end
42+
43+
is_indep_sym(sc::SymbolCache, sym) = isequal(sc.indepsym, sym)
44+
is_indep_sym(::SymbolCache{S,Nothing}, _) where {S} = false
45+
state_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.syms)
46+
state_sym_to_index(::SymbolCache{Nothing}, _) = nothing
47+
is_state_sym(sc::SymbolCache, sym) = !isnothing(state_sym_to_index(sc, sym))
48+
param_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.paramsyms)
49+
param_sym_to_index(::SymbolCache{S,T,Nothing}, _) where {S,T} = nothing
50+
is_param_sym(sc::SymbolCache, sym) = !isnothing(param_sym_to_index(sc, sym))
51+
52+
Base.copy(VA::SymbolCache) = typeof(VA)(
53+
(VA.syms===nothing) ? nothing : copy(VA.syms),
54+
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
55+
(VA.paramsyms===nothing) ? nothing : copy(VA.paramsyms),
56+
)
3657
"""
3758
```julia
3859
DiffEqArray(u::AbstractVector,t::AbstractVector)
@@ -53,11 +74,10 @@ A[1,:] # all time periods for f(t)
5374
A.t
5475
```
5576
"""
56-
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
77+
mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
5778
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
5879
t::B
59-
syms::C
60-
indepsym::D
80+
sc::C
6181
observed::E
6282
p::F
6383
end
@@ -94,11 +114,11 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
94114
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
95115
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
96116

97-
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
117+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), SymbolCache{typeof(syms), typeof(indepsym), Nothing}, typeof(observed), typeof(p)}(vec, ts, SymbolCache(syms, indepsym, nothing), observed, p)
98118
# Assume that the first element is representative of all other elements
99119
DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p)
100120
function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}}
101-
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
121+
DiffEqArray{T, N+1, typeof(vec), typeof(ts), SymbolCache{typeof(syms), typeof(indepsym), Nothing}, typeof(observed), typeof(p)}(vec, ts, SymbolCache(syms, indepsym, nothing), observed, p)
102122
end
103123

104124
# Interface for the linear indexing. This is just a view of the underlying nested structure
@@ -138,37 +158,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
138158
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
139159
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]
140160
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
141-
if issymbollike(sym) && A.syms !== nothing
142-
i = findfirst(isequal(Symbol(sym)),A.syms)
143-
else
144-
i = sym
145-
end
146-
147-
if i === nothing
148-
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
149-
A.t
161+
if issymbollike(sym) && !isnothing(A.sc)
162+
if is_indep_sym(A.sc, sym)
163+
return A.t
164+
elseif is_state_sym(A.sc, sym)
165+
return getindex.(A.u, state_sym_to_index(A.sc, sym))
166+
elseif is_param_sym(A.sc, sym)
167+
return A.p[param_sym_to_index(A.sc, sym)]
168+
else
169+
return observed(A, sym, :)
170+
end
171+
elseif all(issymbollike, sym) && !isnothing(A.sc)
172+
if all(Base.Fix1(is_param_sym, A.sc), sym)
173+
return getindex.((A,), sym)
150174
else
151-
observed(A,sym,:)
175+
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
152176
end
153177
else
154-
Base.getindex.(A.u, i)
178+
return getindex.(A.u, sym)
155179
end
156180
end
157181
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
158-
if issymbollike(sym) && A.syms !== nothing
159-
i = findfirst(isequal(Symbol(sym)),A.syms)
160-
else
161-
i = sym
162-
end
163-
164-
if i === nothing
165-
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
166-
A.t[args...]
182+
if issymbollike(sym) && !isnothing(A.sc)
183+
if is_indep_sym(A.sc, sym)
184+
return A.t[args...]
185+
elseif is_state_sym(A.sc, sym)
186+
return A[sym][args...]
167187
else
168-
observed(A,sym,args...)
188+
return observed(A, sym, args...)
169189
end
190+
elseif all(issymbollike, sym) && !isnothing(A.sc)
191+
return reduce(vcat, map(s -> A[s, args...]', sym))
170192
else
171-
Base.getindex.(A.u, i, args...)
193+
return getindex.(A.u, sym)
172194
end
173195
end
174196
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
@@ -230,8 +252,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
230252
Base.copy(VA::AbstractDiffEqArray) = typeof(VA)(
231253
copy(VA.u),
232254
copy(VA.t),
233-
(VA.syms===nothing) ? nothing : copy(VA.syms),
234-
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
255+
(VA.sc===nothing) ? nothing : copy(VA.sc),
235256
(VA.observed===nothing) ? nothing : copy(VA.observed),
236257
(VA.p===nothing) ? nothing : copy(VA.p)
237258
)

0 commit comments

Comments
 (0)