Skip to content

Commit b9c30c8

Browse files
Update arraypartition_gpu.jl
1 parent 3da6b80 commit b9c30c8

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

test/gpu/arraypartition_gpu.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@ a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
2222
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
2323
@. a + b
2424

25+
# Test adapt from ArrayPartition with CuArrays to ArrayPartition with CPU arrays
26+
27+
a = CuArray(Float64.([1., 2., 3., 4.]))
28+
b = CuArray(Float64.([1., 2., 3., 4.]))
29+
part_a_gpu = ArrayPartition(a, b)
30+
part_a = adapt(Array{Float32}, part_a_gpu)
31+
32+
c = Float32.([1., 2., 3., 4.])
33+
d = Float32.([1., 2., 3., 4.])
34+
part_b = ArrayPartition(c, d)
35+
36+
@test part_a == part_b # Test equality
37+
38+
for i in 1:length(part_a.x)
39+
sub_a = part_a.x[i]
40+
sub_b = part_b.x[i]
41+
@test sub_a == sub_b # Test for value equality in sub-arrays
42+
@test typeof(sub_a) === typeof(sub_b) # Test type equality
43+
end
44+
2545
x = ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))
2646
@test ArrayInterface.zeromatrix(x) isa CuMatrix
2747
@test size(ArrayInterface.zeromatrix(x)) == (4,4)

0 commit comments

Comments
 (0)