Skip to content

Commit 6af53d5

Browse files
fix up a few more overloads
1 parent 5c6d7fd commit 6af53d5

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

src/zygote.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
4-
(NoTangent(),Δ′,NoTangent())
4+
(NoTangent(),VectorOfArray(Δ′),NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
@@ -10,7 +10,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indi
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
1212
Δ′[indices...] = Δ
13-
(NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...)
13+
(NoTangent(), VectorOfArray(Δ′), indices[1],map(_ -> NoTangent(), indices[2:end])...)
1414
end
1515
VA[indices...],AbstractVectorOfArray_getindex_adjoint
1616
end
@@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
1919
function ArrayPartition_adjoint(_y)
2020
y = Array(_y)
2121
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22-
NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
22+
NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent()
2323
end
2424

2525
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
@@ -43,8 +43,6 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4343
A.x,literal_ArrayPartition_x_adjoint
4444
end
4545

46-
#=
47-
4846
# Define a new species of projection operator for this type:
4947
ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
5048

@@ -53,11 +51,6 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ProjectTo{VectorOfArray}()
5351
# Gradient from broadcasting will be another AbstractArray
5452
(::ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
5553

56-
But this may not be necessary?
57-
58-
=#
59-
60-
6154
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
6255
# definition first, and finds its own before finding those.
6356

@@ -73,10 +66,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
7366
function AbstractVectorOfArray_getindex_adjoint(Δ)
7467
Δ′ = zero(VA)
7568
Δ′[i,j...] = Δ
76-
@show Δ′
77-
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
78-
(Δ′, nothing, map(_ -> nothing, j)...)
79-
# (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
69+
(VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
8070
end
8171
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
8272
end
@@ -91,11 +81,11 @@ ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{fal
9181
end
9282

9383
ZygoteRules.@adjoint function VectorOfArray(u)
94-
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
84+
VectorOfArray(u),y -> (VectorOfArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]]),)
9585
end
9686

9787
ZygoteRules.@adjoint function DiffEqArray(u,t)
98-
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
88+
DiffEqArray(u,t),y -> (DiffEqArray([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],t),nothing)
9989
end
10090

10191
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})

0 commit comments

Comments
 (0)