Skip to content

Commit 2d64664

Browse files
Merge pull request #502 from JoshuaLampert/recursivecopy-arraypartition-voa
Specialize recursivecopy for `ArrayPartition` of `VectorOfArray`
2 parents eb25df4 + 3056492 commit 2d64664

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

src/array_partition.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,19 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
329329
end
330330
recursivecopy(A::ArrayPartition) = ArrayPartition(copy.(A.x))
331331

332+
function recursivecopy(A::ArrayPartition{
333+
T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
334+
return ArrayPartition(map(recursivecopy, A.x))
335+
end
336+
337+
function recursivecopy!(A::ArrayPartition{T, S},
338+
B::ArrayPartition{T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
339+
for i in eachindex(A.x, B.x)
340+
recursivecopy!(A.x[i], B.x[i])
341+
end
342+
return A
343+
end
344+
332345
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
333346

334347
# note: consider only first partition for recursive one and eltype
@@ -475,7 +488,7 @@ end
475488
## Linear Algebra
476489

477490
function ArrayInterface.zeromatrix(A::ArrayPartition)
478-
x = reduce(vcat,vec.(A.x))
491+
x = reduce(vcat, vec.(A.x))
479492
x .* x' .* false
480493
end
481494

test/partitions_test.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,23 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
132132
@inferred recursive_one(x)
133133
@inferred recursive_bottom_eltype(x)
134134

135+
src_voa = VectorOfArray([[1.0, 2.0], [3.0, 4.0]])
136+
src_ap = ArrayPartition(src_voa)
137+
138+
copied_ap = recursivecopy(src_ap)
139+
@test copied_ap.x[1].u[1] == src_ap.x[1].u[1]
140+
@test copied_ap.x[1].u[2] == src_ap.x[1].u[2]
141+
@test copied_ap.x[1].u[1] !== src_ap.x[1].u[1]
142+
@test copied_ap.x[1].u[2] !== src_ap.x[1].u[2]
143+
144+
dest_voa = VectorOfArray([zeros(2), zeros(2)])
145+
dest_ap = ArrayPartition(dest_voa)
146+
recursivecopy!(dest_ap, src_ap)
147+
@test dest_ap.x[1].u[1] == src_ap.x[1].u[1]
148+
@test dest_ap.x[1].u[2] == src_ap.x[1].u[2]
149+
@test dest_ap.x[1].u[1] !== src_ap.x[1].u[1]
150+
@test dest_ap.x[1].u[2] !== src_ap.x[1].u[2]
151+
135152
# mapreduce
136153
@inferred Union{Int, Float64} sum(x)
137154
@inferred sum(ArrayPartition(ArrayPartition(zeros(4, 4))))
@@ -149,7 +166,7 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
149166
@test any(isnan, ArrayPartition([2], [NaN]))
150167
@test any(isnan, ArrayPartition([2], ArrayPartition([NaN])))
151168

152-
# all
169+
# all
153170
@test !all(isnan, ArrayPartition([1, 2], [3.0, 4.0]))
154171
@test !all(isnan, ArrayPartition([3.0, 4.0]))
155172
@test !all(isnan, ArrayPartition([NaN], [3.0, 4.0]))

0 commit comments

Comments
 (0)