Skip to content

Commit 3db225f

Browse files
AbstractVectorOfArray broadcast overload
```julia [ Info: Precompiling RecursiveArrayTools [731186ca-8d62-57ce-b412-fbd966d074cd] Please submit a bug report with steps to reproduce this fault, and any error messages that follow (in their entirety). Thanks. Exception: EXCEPTION_ACCESS_VIOLATION at 0x6690ec5e -- jl_subtype_env at /cygdrive/d/buildbot/worker/package_win64/build/src\subtype.c:1800 in expression starting at none:0 jl_subtype_env at /cygdrive/d/buildbot/worker/package_win64/build/src\subtype.c:1800 jl_subtype at /cygdrive/d/buildbot/worker/package_win64/build/src\subtype.c:1854 [inlined] jl_isa at /cygdrive/d/buildbot/worker/package_win64/build/src\subtype.c:2056 rewrap at .\compiler\typeutils.jl:8 [inlined] matching_cache_argtypes at .\compiler\inferenceresult.jl:132 InferenceResult at .\compiler\inferenceresult.jl:12 [inlined] InferenceResult at .\compiler\inferenceresult.jl:12 [inlined] typeinf_ext at .\compiler\typeinfer.jl:572 typeinf_ext at .\compiler\typeinfer.jl:605 jfptr_typeinf_ext_1.clone_1 at C:\Users\accou\AppData\Local\Programs\Julia\Julia-1.4.1\lib\julia\sys.dll (unknown line) jl_apply at /cygdrive/d/buildbot/worker/package_win64/build/src\julia.h:1700 [inlined] jl_type_infer at /cygdrive/d/buildbot/worker/package_win64/build/src\gf.c:213 jl_compile_method_internal at /cygdrive/d/buildbot/worker/package_win64/build/src\gf.c:1887 _jl_invoke at /cygdrive/d/buildbot/worker/package_win64/build/src\gf.c:2153 [inlined] jl_apply_generic at /cygdrive/d/buildbot/worker/package_win64/build/src\gf.c:2322 _reformat_bt at .\error.jl:90 #catch_stack#49 at .\error.jl:149 catch_stack at .\error.jl:144 [inlined] catch_stack at .\error.jl:144 [inlined] _start at .\client.jl:486 jfptr__start_2087.clone_1 at C:\Users\accou\AppData\Local\Programs\Julia\Julia-1.4.1\lib\julia\sys.dll (unknown line) unknown function (ip: 00000000004017E1) unknown function (ip: 0000000000401BD6) unknown function (ip: 00000000004013DE) unknown function (ip: 000000000040151A) BaseThreadInitThunk at C:\WINDOWS\System32\KERNEL32.DLL (unknown line) RtlUserThreadStart at C:\WINDOWS\SYSTEM32\ntdll.dll (unknown line) Allocations: 5137387 (Pool: 5136723; Big: 664); GC: 3 ```
1 parent 31bf951 commit 3db225f

File tree

1 file changed

+63
-22
lines changed

1 file changed

+63
-22
lines changed

src/vector_of_array.jl

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,68 @@ end
147147
VA.t,VA.u
148148
end
149149

150-
# Broadcast
151-
152-
#add_idxs(x,expr) = expr
153-
#add_idxs{T<:AbstractVectorOfArray}(::Type{T},expr) = :($(expr)[i])
154-
#add_idxs{T<:AbstractArray}(::Type{Vector{T}},expr) = :($(expr)[i])
155-
#=
156-
@generated function Base.broadcast!(f,A::AbstractVectorOfArray,B...)
157-
exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
158-
:(for i in eachindex(A)
159-
broadcast!(f,A[i],$(exs...))
160-
end)
161-
end
162-
163-
@generated function Base.broadcast(f,B::Union{Number,AbstractVectorOfArray}...)
164-
arr_idx = 0
165-
for (i,b) in enumerate(B)
166-
if b <: ArrayPartition
167-
arr_idx = i
168-
break
150+
## broadcasting
151+
152+
struct VectorOfArrayStyle{Style <: Broadcast.BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
153+
VectorOfArrayStyle(::S) where {S} = VectorOfArrayStyle{S}()
154+
VectorOfArrayStyle(::S, ::Val{N}) where {S,N} = VectorOfArrayStyle(S(Val(N)))
155+
VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{Broadcast.DefaultArrayStyle{N}}()
156+
157+
# promotion rules
158+
@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
159+
VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
160+
end
161+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{Style}, ::Broadcast.DefaultArrayStyle{0}) where Style<:Broadcast.BroadcastStyle = VectorOfArrayStyle{Style}()
162+
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
163+
164+
function Broadcast.BroadcastStyle(::Type{AbstractVectorOfArray{T,S}}) where {T, N}
165+
VectorOfArrayStyle(Broadcast.result_style(Broadcast.BroadcastStyle(T)))
166+
end
167+
168+
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}) where Style
169+
N = narrays(bc)
170+
VectorOfArray(map(1:N) do i
171+
copy(unpack_voa(bc, i))
172+
end)
173+
end
174+
175+
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}) where Style
176+
N = narrays(bc)
177+
@inbounds for i in 1:N
178+
copyto!(dest[i], unpack_voa(bc, i))
169179
end
170-
end
171-
:(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
180+
dest
172181
end
173-
=#
182+
183+
## broadcasting utils
184+
185+
"""
186+
narrays(A...)
187+
188+
Retrieve number of arrays in the AbstractVectorOfArrays of a broadcast
189+
"""
190+
narrays(A) = 0
191+
narrays(A::AbstractVectorOfArray) = length(A)
192+
narrays(bc::Broadcast.Broadcasted) = _narrays(bc.args)
193+
narrays(A, Bs...) = common_length(narrays(A), _narrays(Bs))
194+
195+
common_length(a, b) =
196+
a == 0 ? b :
197+
(b == 0 ? a :
198+
(a == b ? a :
199+
throw(DimensionMismatch("number of arrays must be equal"))))
200+
201+
@inline _narrays(args::Tuple) = common_length(narrays(args[1]), _narrays(Base.tail(args)))
202+
_narrays(args::Tuple{Any}) = _narrays(args[1])
203+
_narrays(args::Tuple{}) = 0
204+
205+
# drop axes because it is easier to recompute
206+
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
207+
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208+
unpack_voa(x,::Any) = x
209+
unpack_voa(x::AbstractVectorOfArray, i) = x[i]
210+
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]
211+
212+
@inline unpack_args_voa(i, args::Tuple) = (unpack_voa(args[1], i), unpack_args_voa(i, Base.tail(args))...)
213+
unpack_args_voa(i, args::Tuple{Any}) = (unpack_voa(args[1], i),)
214+
unpack_args_voa(::Any, args::Tuple{}) = ()

0 commit comments

Comments
 (0)