Skip to content

Commit 6250057

Browse files
fix for ode units
1 parent b72c4b4 commit 6250057

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,18 @@ However, broadcast is overloaded to loop in an efficient manner, meaning that
2121
do not match types. A full array interface is included for completeness, which
2222
allows this array type to be used in place of a standard array in places where
2323
such a type stable broadcast may be needed. One example is in heterogeneous
24-
differential equations for [DifferentialEquations.jl](https://github.com/JuliaDiffEq/DifferentialEquations.jl).
24+
differential equations for [DifferentialEquations.jl](https://github.com/JuliaDiffEq/DifferentialEquations.jl).
25+
26+
An `ArrayPartition` acts like a single array. `A[i]` indexes through the first
27+
array, then the second, etc. all linearly. But `A.x` is where the arrays are stored.
28+
Thus for
29+
30+
```julia
31+
using RecursiveArrayTools
32+
A = ArrayPartition(y,z)
33+
```
34+
35+
We would have `A.x[1]==y` and `A.x[2]==z`. Broadcasting like `f.(A)` is efficient.
2536

2637
### Functions
2738

src/array_partition.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,45 @@ function ArrayPartition{T}(x,::Type{Val{T}}=Val{false})
1111
end
1212
Base.similar(A::ArrayPartition) = ArrayPartition((similar(x) for x in A.x)...)
1313
Base.similar(A::ArrayPartition,dims::Tuple) = ArrayPartition((similar(x,dim) for (x,dim) in zip(A.x,dims))...)
14+
Base.similar(A::ArrayPartition,T,dims::Tuple) = ArrayPartition((similar(x,T,dim) for (x,dim) in zip(A.x,dims))...)
1415
Base.copy(A::ArrayPartition) = Base.similar(A)
1516
Base.zeros(A::ArrayPartition) = ArrayPartition((zeros(x) for x in A.x)...)
1617

18+
# Special to work with units
19+
function Base.ones(A::ArrayPartition)
20+
B = similar(A::ArrayPartition)
21+
for i in eachindex(A.x)
22+
B.x[i] .= eltype(A.x[i])(one(first(A.x[i])))
23+
end
24+
B
25+
end
26+
27+
Base.:+(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x .+ y for (x,y) in zip(A.x,B.x))...)
28+
Base.:+(A::Number, B::ArrayPartition) = ArrayPartition((A .+ x for x in B.x)...)
29+
Base.:+(A::ArrayPartition, B::Number) = ArrayPartition((B .+ x for x in A.x)...)
30+
Base.:-(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x .- y for (x,y) in zip(A.x,B.x))...)
31+
Base.:-(A::Number, B::ArrayPartition) = ArrayPartition((A .- x for x in B.x)...)
32+
Base.:-(A::ArrayPartition, B::Number) = ArrayPartition((x .- B for x in A.x)...)
1733
Base.:*(A::Number, B::ArrayPartition) = ArrayPartition((A .* x for x in B.x)...)
1834
Base.:*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...)
1935
Base.:/(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...)
2036
Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...)
2137

38+
if VERSION < v"0.6-"
39+
Base.:.+(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x .+ y for (x,y) in zip(A.x,B.x))...)
40+
Base.:.+(A::Number, B::ArrayPartition) = ArrayPartition((A .+ x for x in B.x)...)
41+
Base.:.+(A::ArrayPartition, B::Number) = ArrayPartition((B .+ x for x in A.x)...)
42+
Base.:.-(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x .- y for (x,y) in zip(A.x,B.x))...)
43+
Base.:.-(A::Number, B::ArrayPartition) = ArrayPartition((A .- x for x in B.x)...)
44+
Base.:.-(A::ArrayPartition, B::Number) = ArrayPartition((x .- B for x in A.x)...)
45+
Base.:.*(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x .* y for (x,y) in zip(A.x,B.x))...)
46+
Base.:.*(A::Number, B::ArrayPartition) = ArrayPartition((A .* x for x in B.x)...)
47+
Base.:.*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...)
48+
Base.:./(A::ArrayPartition, B::ArrayPartition) = ArrayPartition((x ./ y for (x,y) in zip(A.x,B.x))...)
49+
Base.:./(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...)
50+
Base.:.\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...)
51+
end
52+
2253
@inline function Base.getindex( A::ArrayPartition,i::Int)
2354
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
2455
@inbounds for j in 1:length(A.x)
@@ -52,9 +83,9 @@ recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
5283
Base.zero(A::ArrayPartition) = zero(first(A.x))
5384
Base.first(A::ArrayPartition) = first(A.x)
5485

55-
Base.start(A::ArrayPartition) = chain(A.x...)
56-
Base.next(iter::ArrayPartition,state) = next(state,state)
57-
Base.done(iter::ArrayPartition,state) = done(state,state)
86+
Base.start(A::ArrayPartition) = start(chain(A.x...))
87+
Base.next(A::ArrayPartition,state) = next(chain(A.x...),state)
88+
Base.done(A::ArrayPartition,state) = done(chain(A.x...),state)
5889

5990
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
6091
Base.size(A::ArrayPartition) = (length(A),)
@@ -72,15 +103,12 @@ add_idxs{T<:ArrayPartition}(::Type{T},expr) = :($(expr).x[i])
72103
end
73104

74105
@generated function Base.broadcast(f,B::Union{Number,ArrayPartition}...)
75-
exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
76106
arr_idx = 0
77107
for (i,b) in enumerate(B)
78108
if b <: ArrayPartition
79109
arr_idx = i
80110
break
81111
end
82112
end
83-
:(for i in eachindex(B[$arr_idx].x)
84-
broadcast(f,$(exs...))
85-
end)
113+
:(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
86114
end

0 commit comments

Comments
 (0)