Skip to content

Commit 21408fd

Browse files
committed
Flatten the broadcast before calculating narray
1 parent a354738 commit 21408fd

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/vector_of_array.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ function Base.Array(VA::AbstractVectorOfArray)
1515
Array(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
1616
end
1717

18-
VectorOfArray(vec::AbstractVector{T}, dims::NTuple{N}) where {T, N} = VectorOfArray{eltype(T), N, typeof(vec)}(vec)
18+
VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{eltype(T), N, typeof(vec)}(vec)
1919
# Assume that the first element is representative of all other elements
2020
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
2121
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
2222

23-
DiffEqArray(vec::AbstractVector{T}, ts, dims::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
23+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts)}(vec, ts)
2424
# Assume that the first element is representative of all other elements
2525
DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)))
2626
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts)}(vec, ts)
@@ -160,14 +160,15 @@ Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where
160160
Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}()
161161

162162
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
163+
bc = Broadcast.flatten(bc)
163164
N = narrays(bc)
164-
x = unpack_voa(bc, 1)
165165
VectorOfArray(map(1:N) do i
166166
copy(unpack_voa(bc, i))
167167
end)
168168
end
169169

170170
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
171+
bc = Broadcast.flatten(bc)
171172
N = narrays(bc)
172173
@inbounds for i in 1:N
173174
copyto!(dest[i], unpack_voa(bc, i))

0 commit comments

Comments
 (0)