Skip to content

Commit b1c0c44

Browse files
Merge pull request #144 from SciML/chainrules
ZygoteRules -> ChainRules
2 parents 532745d + cbf86f5 commit b1c0c44

File tree

4 files changed

+27
-22
lines changed

4 files changed

+27
-22
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@ version = "2.11.4"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1112
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515

1616
[compat]
1717
ArrayInterface = "2.7, 3.0"
18+
ChainRulesCore = "0.10.7"
1819
DocStringExtensions = "0.8"
1920
RecipesBase = "0.7, 0.8, 1.0"
2021
Requires = "0.5, 1.0"
2122
StaticArrays = "0.12, 1.0"
22-
ZygoteRules = "0.2"
23-
julia = "1.3"
23+
julia = "1.6"
2424

2525
[extras]
2626
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ module RecursiveArrayTools
66

77
using DocStringExtensions
88
using Requires, RecipesBase, StaticArrays, Statistics,
9-
ArrayInterface, ZygoteRules, LinearAlgebra
9+
ArrayInterface, LinearAlgebra
1010

11+
import ChainRulesCore
12+
import ChainRulesCore: NoTangent
1113
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
1214
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
1315

src/init.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ function __init__()
1717
return CuArrays.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
1818
end
1919
Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA)
20-
@adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,)
20+
ChainRules.rrule(::Type{<:CuArrays.CuArray},xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (NoTangent(),ȳ)
2121
end
22-
22+
2323
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
2424
function CUDA.CuArray(VA::AbstractVectorOfArray)
2525
vecs = vec.(VA.u)
2626
return CUDA.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
2727
end
2828
Base.convert(::Type{<:CUDA.CuArray},VA::AbstractVectorOfArray) = CUDA.CuArray(VA)
29-
@adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,)
29+
ChainRules.rrule(::Type{<:CUDA.CuArray},xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (NoTangent(),ȳ)
3030
end
3131

3232
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin

src/zygote.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,44 @@
1-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
1+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i)
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
4-
(Δ′,nothing)
4+
(NoTangent(),Δ′,NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
88

9-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
9+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...)
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
1212
Δ′[i,j...] = Δ
13-
(Δ′, i,map(_ -> nothing, j)...)
13+
(NoTangent(), Δ′, i,map(_ -> NoTangent(), j)...)
1414
end
1515
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
1616
end
1717

18-
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
18+
function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
1919
function ArrayPartition_adjoint(_y)
2020
y = Array(_y)
2121
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22-
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing
22+
NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
2323
end
2424

2525
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
2626
end
2727

28-
ZygoteRules.@adjoint function VectorOfArray(u)
29-
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
28+
function ChainRulesCore.rrule(::Type{<:VectorOfArray},u)
29+
VectorOfArray(u),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]])
3030
end
3131

32-
ZygoteRules.@adjoint function DiffEqArray(u,t)
33-
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
32+
function ChainRulesCore.rrule(::Type{<:DiffEqArray},u,t)
33+
DiffEqArray(u,t),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],NoTangent())
3434
end
3535

36-
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
37-
function literal_ArrayPartition_x_adjoint(d)
38-
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
39-
end
40-
A.x,literal_ArrayPartition_x_adjoint
36+
function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol)
37+
if s !== :x
38+
error("$s is not a field of ArrayPartition")
39+
end
40+
function literal_ArrayPartition_x_adjoint(d)
41+
(NoTangent(),ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...))
42+
end
43+
A.x,literal_ArrayPartition_x_adjoint
4144
end

0 commit comments

Comments
 (0)