|
1 | | -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i) |
| 1 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i) |
2 | 2 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
3 | 3 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] |
4 | | - (Δ′,nothing) |
| 4 | + (NoTangent(),Δ′,NoTangent()) |
5 | 5 | end |
6 | 6 | VA[i],AbstractVectorOfArray_getindex_adjoint |
7 | 7 | end |
8 | 8 |
|
9 | | -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...) |
| 9 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...) |
10 | 10 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
11 | 11 | Δ′ = zero(VA) |
12 | 12 | Δ′[i,j...] = Δ |
13 | | - (Δ′, i,map(_ -> nothing, j)...) |
| 13 | + (NoTangent(), Δ′, i,map(_ -> NoTangent(), j)...) |
14 | 14 | end |
15 | 15 | VA[i,j...],AbstractVectorOfArray_getindex_adjoint |
16 | 16 | end |
17 | 17 |
|
18 | | -ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x} |
| 18 | +function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x} |
19 | 19 | function ArrayPartition_adjoint(_y) |
20 | 20 | y = Array(_y) |
21 | 21 | starts = vcat(0,cumsum(reduce(vcat,length.(x)))) |
22 | | - ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing |
| 22 | + NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent() |
23 | 23 | end |
24 | 24 |
|
25 | 25 | ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint |
26 | 26 | end |
27 | 27 |
|
28 | | -ZygoteRules.@adjoint function VectorOfArray(u) |
29 | | - VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],) |
| 28 | +function ChainRulesCore.rrule(::Type{<:VectorOfArray},u) |
| 29 | + VectorOfArray(u),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]]) |
30 | 30 | end |
31 | 31 |
|
32 | | -ZygoteRules.@adjoint function DiffEqArray(u,t) |
33 | | - DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing) |
| 32 | +function ChainRulesCore.rrule(::Type{<:DiffEqArray},u,t) |
| 33 | + DiffEqArray(u,t),y -> (NoTangent(),[y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],NoTangent()) |
34 | 34 | end |
35 | 35 |
|
36 | | -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x}) |
37 | | - function literal_ArrayPartition_x_adjoint(d) |
38 | | - (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) |
39 | | - end |
40 | | - A.x,literal_ArrayPartition_x_adjoint |
| 36 | +function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol) |
| 37 | + if s !== :x |
| 38 | + error("$s is not a field of ArrayPartition") |
| 39 | + end |
| 40 | + function literal_ArrayPartition_x_adjoint(d) |
| 41 | + (NoTangent(),ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...)) |
| 42 | + end |
| 43 | + A.x,literal_ArrayPartition_x_adjoint |
41 | 44 | end |
0 commit comments