Skip to content

Commit 1a0fcfd

Browse files
close
1 parent 229e61c commit 1a0fcfd

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/vector_of_array.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,16 @@ end
161161
Broadcast.BroadcastStyle(::VectorOfArrayStyle{Style}, ::Broadcast.DefaultArrayStyle{0}) where Style<:Broadcast.BroadcastStyle = VectorOfArrayStyle{Style}()
162162
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
163163

164-
function Broadcast.BroadcastStyle(::Type{AbstractVectorOfArray{T,S}}) where {T, S}
164+
function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
165165
VectorOfArrayStyle(Broadcast.result_style(Broadcast.BroadcastStyle(T)))
166166
end
167167

168168
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}) where Style
169169
N = narrays(bc)
170+
@show "here"
171+
x = unpack_voa(bc, 1)
172+
@show x
173+
@show copy(x)
170174
VectorOfArray(map(1:N) do i
171175
copy(unpack_voa(bc, i))
172176
end)
@@ -198,6 +202,7 @@ common_length(a, b) =
198202
(a == b ? a :
199203
throw(DimensionMismatch("number of arrays must be equal"))))
200204

205+
_narrays(args::AbstractVectorOfArray) = length(args)
201206
@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args)))
202207
_narrays(args::Tuple{Any}) = _narrays(args[1])
203208
_narrays(args::Tuple{}) = 0
@@ -206,7 +211,7 @@ _narrays(args::Tuple{}) = 0
206211
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
207212
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208213
unpack_voa(x,::Any) = x
209-
unpack_voa(x::AbstractVectorOfArray, i) = x[i]
214+
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
210215
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]
211216

212217
@inline unpack_args_voa(i, args::Tuple) = (unpack_voa(args[1], i), unpack_args_voa(i, Base.tail(args))...)

test/basic_indexing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools
1+
using RecursiveArrayTools, Test
22

33
# Example Problem
44
recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
@@ -86,3 +86,5 @@ v = VectorOfArray([zeros(20), zeros(10,10), zeros(3,3,3)])
8686
v[CartesianIndex((2, 3, 2, 3))] = 1
8787
@test v[CartesianIndex((2, 3, 2, 3))] == 1
8888
@test v.u[3][2, 3, 2] == 1
89+
90+
v .* v

0 commit comments

Comments
 (0)