Skip to content

Commit cbf86f5

Browse files
fully remove ZygoteRules?
1 parent 3da1037 commit cbf86f5

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1212
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1313
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
15-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1615

1716
[compat]
1817
ArrayInterface = "2.7, 3.0"
@@ -21,7 +20,6 @@ DocStringExtensions = "0.8"
2120
RecipesBase = "0.7, 0.8, 1.0"
2221
Requires = "0.5, 1.0"
2322
StaticArrays = "0.12, 1.0"
24-
ZygoteRules = "0.2"
2523
julia = "1.6"
2624

2725
[extras]

src/RecursiveArrayTools.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ using DocStringExtensions
1010

1111
import ChainRulesCore
1212
import ChainRulesCore: NoTangent
13-
import ZygoteRules
1413
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1514
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1615

src/init.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function __init__()
1717
return CuArrays.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
1818
end
1919
Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA)
20-
ZygoteRules.@adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,)
20+
ChainRules.rrule(::Type{<:CuArrays.CuArray},xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (NoTangent(),ȳ)
2121
end
2222

2323
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
@@ -26,7 +26,7 @@ function __init__()
2626
return CUDA.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
2727
end
2828
Base.convert(::Type{<:CUDA.CuArray},VA::AbstractVectorOfArray) = CUDA.CuArray(VA)
29-
ZygoteRules.@adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,)
29+
ChainRules.rrule(::Type{<:CUDA.CuArray},xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (NoTangent(),ȳ)
3030
end
3131

3232
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin

src/zygote.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ function ChainRulesCore.rrule(::Type{<:DiffEqArray},u,t)
3333
DiffEqArray(u,t),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],NoTangent())
3434
end
3535

36-
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
37-
function literal_ArrayPartition_x_adjoint(d)
38-
(NoTangent(),ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...))
39-
end
40-
A.x,literal_ArrayPartition_x_adjoint
36+
function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol)
37+
if s !== :x
38+
error("$s is not a field of ArrayPartition")
39+
end
40+
function literal_ArrayPartition_x_adjoint(d)
41+
(NoTangent(),ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...))
42+
end
43+
A.x,literal_ArrayPartition_x_adjoint
4144
end

0 commit comments

Comments
 (0)