Skip to content

Commit a608541

Browse files
add some more ZygoteRules
1 parent 045cab5 commit a608541

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/init.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ function __init__()
1717
return CuArrays.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
1818
end
1919
Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA)
20-
@adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,)
20+
ZygoteRules.@adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,)
2121
end
22-
22+
2323
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
2424
function CUDA.CuArray(VA::AbstractVectorOfArray)
2525
vecs = vec.(VA.u)
2626
return CUDA.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
2727
end
2828
Base.convert(::Type{<:CUDA.CuArray},VA::AbstractVectorOfArray) = CUDA.CuArray(VA)
29-
@adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,)
29+
ZygoteRules.@adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,)
3030
end
3131

3232
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin

0 commit comments

Comments
 (0)