Skip to content

Commit f70e46d

Browse files
committed
specialize recursivecopy for ArrayPartition of VectorOfArray
1 parent eb25df4 commit f70e46d

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

src/array_partition.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,17 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
329329
end
330330
recursivecopy(A::ArrayPartition) = ArrayPartition(copy.(A.x))
331331

332+
function recursivecopy(A::ArrayPartition{T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
333+
return ArrayPartition(map(recursivecopy, A.x))
334+
end
335+
336+
function recursivecopy!(A::ArrayPartition{T, S}, B::ArrayPartition{T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
337+
for i in eachindex(A.x, B.x)
338+
recursivecopy!(A.x[i], B.x[i])
339+
end
340+
return A
341+
end
342+
332343
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
333344

334345
# note: consider only first partition for recursive one and eltype

test/partitions_test.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,23 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
132132
@inferred recursive_one(x)
133133
@inferred recursive_bottom_eltype(x)
134134

135+
src_voa = VectorOfArray([[1.0, 2.0], [3.0, 4.0]])
136+
src_ap = ArrayPartition(src_voa)
137+
138+
copied_ap = recursivecopy(src_ap)
139+
@test copied_ap.x[1].u[1] == src_ap.x[1].u[1]
140+
@test copied_ap.x[1].u[2] == src_ap.x[1].u[2]
141+
@test copied_ap.x[1].u[1] !== src_ap.x[1].u[1]
142+
@test copied_ap.x[1].u[2] !== src_ap.x[1].u[2]
143+
144+
dest_voa = VectorOfArray([zeros(2), zeros(2)])
145+
dest_ap = ArrayPartition(dest_voa)
146+
recursivecopy!(dest_ap, src_ap)
147+
@test dest_ap.x[1].u[1] == src_ap.x[1].u[1]
148+
@test dest_ap.x[1].u[2] == src_ap.x[1].u[2]
149+
@test dest_ap.x[1].u[1] !== src_ap.x[1].u[1]
150+
@test dest_ap.x[1].u[2] !== src_ap.x[1].u[2]
151+
135152
# mapreduce
136153
@inferred Union{Int, Float64} sum(x)
137154
@inferred sum(ArrayPartition(ArrayPartition(zeros(4, 4))))
@@ -149,7 +166,7 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
149166
@test any(isnan, ArrayPartition([2], [NaN]))
150167
@test any(isnan, ArrayPartition([2], ArrayPartition([NaN])))
151168

152-
# all
169+
# all
153170
@test !all(isnan, ArrayPartition([1, 2], [3.0, 4.0]))
154171
@test !all(isnan, ArrayPartition([3.0, 4.0]))
155172
@test !all(isnan, ArrayPartition([NaN], [3.0, 4.0]))

0 commit comments

Comments
 (0)