Skip to content

Commit f85bebc

Browse files
Merge pull request #104 from SciML/myb/bcfix
Flatten the broadcast before calculating `narray` and handle edge cases
2 parents a354738 + afe634b commit f85bebc

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "2.4.1"
4+
version = "2.4.2"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/vector_of_array.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ function Base.Array(VA::AbstractVectorOfArray)
1515
Array(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
1616
end
1717

18-
VectorOfArray(vec::AbstractVector{T}, dims::NTuple{N}) where {T, N} = VectorOfArray{eltype(T), N, typeof(vec)}(vec)
18+
VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{eltype(T), N, typeof(vec)}(vec)
1919
# Assume that the first element is representative of all other elements
2020
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
2121
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
2222

23-
DiffEqArray(vec::AbstractVector{T}, ts, dims::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
23+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
2424
# Assume that the first element is representative of all other elements
2525
DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)))
2626
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts)}(vec, ts)
@@ -94,7 +94,11 @@ recursivecopy(VA::VectorOfArray) = VectorOfArray(copy.(VA.u))
9494
# For DiffEqArray it ignores ts and fills only u
9595
function Base.fill!(VA::AbstractVectorOfArray, x)
9696
for i in eachindex(VA)
97-
fill!(VA[i], x)
97+
if VA[i] isa AbstractArray
98+
fill!(VA[i], x)
99+
else
100+
VA[i] = x
101+
end
98102
end
99103
return VA
100104
end
@@ -160,17 +164,22 @@ Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where
160164
Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}()
161165

162166
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
167+
bc = Broadcast.flatten(bc)
163168
N = narrays(bc)
164-
x = unpack_voa(bc, 1)
165169
VectorOfArray(map(1:N) do i
166170
copy(unpack_voa(bc, i))
167171
end)
168172
end
169173

170174
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
175+
bc = Broadcast.flatten(bc)
171176
N = narrays(bc)
172177
@inbounds for i in 1:N
173-
copyto!(dest[i], unpack_voa(bc, i))
178+
if dest[i] isa AbstractArray
179+
copyto!(dest[i], unpack_voa(bc, i))
180+
else
181+
dest[i] = copy(unpack_voa(bc, i))
182+
end
174183
end
175184
dest
176185
end

test/basic_indexing.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ testva = VectorOfArray(recs)
77

88
# broadcast with array
99
X = rand(3, 3)
10-
mulX = testva .* X
11-
ref = mapreduce((x,y)->x.*y, hcat, testva, eachcol(X))
10+
mulX = sqrt.(abs.(testva .* X))
11+
ref = mapreduce((x,y)->sqrt.(abs.(x.*y)), hcat, testva, eachcol(X))
1212
@test mulX == ref
1313
fill!(mulX, 0)
14-
mulX .= testva .* X
14+
mulX .= sqrt.(abs.(testva .* X))
1515
@test mulX == ref
1616

1717
t = [1,2,3]
@@ -107,3 +107,14 @@ x .= v .* v
107107
w = v .+ 1
108108
@test w isa VectorOfArray
109109
@test w.u == map(x -> x .+ 1, v.u)
110+
111+
# edges cases
112+
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
113+
testva = DiffEqArray(x, x)
114+
testvb = DiffEqArray(x, x)
115+
mulX = sqrt.(abs.(testva .* testvb))
116+
ref = sqrt.(abs.(x .* x))
117+
@test mulX == ref
118+
fill!(mulX, 0)
119+
mulX .= sqrt.(abs.(testva .* testvb))
120+
@test mulX == ref

0 commit comments

Comments
 (0)