Skip to content

Commit f5c3165

Browse files
ZygoteRules -> ChainRules
1 parent 532745d commit f5c3165

File tree

3 files changed

+17
-15
lines changed

3 files changed

+17
-15
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@ version = "2.11.4"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1112
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515

1616
[compat]
1717
ArrayInterface = "2.7, 3.0"
18+
ChainRulesCore = "0.10.7"
1819
DocStringExtensions = "0.8"
1920
RecipesBase = "0.7, 0.8, 1.0"
2021
Requires = "0.5, 1.0"
2122
StaticArrays = "0.12, 1.0"
22-
ZygoteRules = "0.2"
23-
julia = "1.3"
23+
julia = "1.6"
2424

2525
[extras]
2626
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ module RecursiveArrayTools
66

77
using DocStringExtensions
88
using Requires, RecipesBase, StaticArrays, Statistics,
9-
ArrayInterface, ZygoteRules, LinearAlgebra
9+
ArrayInterface, LinearAlgebra
1010

11+
import ChainRulesCore
12+
import ChainRulesCore: NoTangent
1113
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1214
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1315

src/zygote.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,41 @@
1-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
1+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i)
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
4-
(Δ′,nothing)
4+
(NoTangent(),Δ′,NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
88

9-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
9+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...)
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
1212
Δ′[i,j...] = Δ
13-
(Δ′, i,map(_ -> nothing, j)...)
13+
(NoTangent(), Δ′, i,map(_ -> NoTangent(), j)...)
1414
end
1515
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
1616
end
1717

18-
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
18+
function ChainRulesCore.rrule(::ArrayPartition, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
1919
function ArrayPartition_adjoint(_y)
2020
y = Array(_y)
2121
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22-
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing
22+
NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
2323
end
2424

2525
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
2626
end
2727

28-
ZygoteRules.@adjoint function VectorOfArray(u)
29-
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
28+
function ChainRulesCore.rrule(::VectorOfArray,u)
29+
VectorOfArray(u),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]])
3030
end
3131

32-
ZygoteRules.@adjoint function DiffEqArray(u,t)
33-
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
32+
function ChainRulesCore.rrule(::DiffEqArray,u,t)
33+
DiffEqArray(u,t),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],NoTangent())
3434
end
3535

3636
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
3737
function literal_ArrayPartition_x_adjoint(d)
38-
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
38+
(NoTangent(),ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...))
3939
end
4040
A.x,literal_ArrayPartition_x_adjoint
4141
end

0 commit comments

Comments
 (0)