File tree Expand file tree Collapse file tree 4 files changed +26
-6
lines changed
Expand file tree Collapse file tree 4 files changed +26
-6
lines changed Original file line number Diff line number Diff line change @@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323MonteCarloMeasurements = " 0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2424ReverseDiff = " 37e2e3b7-166d-5795-8a7a-e32c996b4267"
2525SparseArrays = " 2f01184e-e22b-5df5-ae63-d93ebab69eaf"
26+ StructArrays = " 09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2627Tracker = " 9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2728Zygote = " e88e6eb3-aa80-5325-afca-941959d7151f"
2829
@@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements"
3334RecursiveArrayToolsMonteCarloMeasurementsExt = " MonteCarloMeasurements"
3435RecursiveArrayToolsReverseDiffExt = [" ReverseDiff" , " Zygote" ]
3536RecursiveArrayToolsSparseArraysExt = [" SparseArrays" ]
37+ RecursiveArrayToolsStructArraysExt = " StructArrays"
3638RecursiveArrayToolsTrackerExt = " Tracker"
3739RecursiveArrayToolsZygoteExt = " Zygote"
3840
Original file line number Diff line number Diff line change 1+ module RecursiveArrayToolsStructArraysExt
2+
3+ import RecursiveArrayTools, StructArrays
4+ RecursiveArrayTools. rewrap (:: StructArrays.StructArray , u) = StructArrays. StructArray (u)
5+
6+ end
Original file line number Diff line number Diff line change @@ -849,28 +849,33 @@ end
849849
850850@inline function Base. copy (bc:: Broadcast.Broadcasted{<:VectorOfArrayStyle} )
851851 bc = Broadcast. flatten (bc)
852-
853852 parent = find_VoA_parent (bc. args)
854853
855- if parent isa AbstractVector
854+ u = if parent isa AbstractVector
856855 # this is the default behavior in v3.15.0
857856 N = narrays (bc)
858- return VectorOfArray ( map (1 : N) do i
857+ map (1 : N) do i
859858 copy (unpack_voa (bc, i))
860- end )
859+ end
861860 else # if parent isa AbstractArray
862- return VectorOfArray ( map (enumerate (Iterators. product (axes (parent)... ))) do (i, _)
861+ map (enumerate (Iterators. product (axes (parent)... ))) do (i, _)
863862 copy (unpack_voa (bc, i))
864- end )
863+ end
865864 end
865+ VectorOfArray (rewrap (parent, u))
866866end
867867
868+ rewrap (:: Array ,u) = u
869+ rewrap (parent, u) = convert (typeof (parent), u)
870+
868871for (type, N_expr) in [
869872 (Broadcast. Broadcasted{<: VectorOfArrayStyle }, :(narrays (bc))),
870873 (Broadcast. Broadcasted{<: Broadcast.DefaultArrayStyle }, :(length (dest. u)))
871874]
872875 @eval @inline function Base. copyto! (dest:: AbstractVectorOfArray ,
873876 bc:: $type )
877+ @show typeof (dest)
878+ error ()
874879 bc = Broadcast. flatten (bc)
875880 N = $ N_expr
876881 @inbounds for i in 1 : N
Original file line number Diff line number Diff line change @@ -114,3 +114,10 @@ a_voa = VectorOfArray(a)
114114a_voa .= 1.0
115115@test a_voa[1 ] == SVector (1.0 , 1.0 )
116116@test a_voa[2 ] == SVector (1.0 , 1.0 )
117+
118+ # Broadcast Copy of StructArray
119+ x = StructArray {SVector{2, Float64}} ((randn (2 ), randn (2 )))
120+ vx = VectorOfArray (x)
121+ vx2 = copy (vx) .+ 1
122+ ans = vx .+ vx2
123+ @test ans. u isa StructArray
You can’t perform that action at this time.
0 commit comments