|
1 | | -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int) |
| 1 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) |
2 | 2 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
3 | 3 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] |
4 | 4 | (NoTangent(),Δ′,NoTangent()) |
5 | 5 | end |
6 | 6 | VA[i],AbstractVectorOfArray_getindex_adjoint |
7 | 7 | end |
8 | 8 |
|
9 | | -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N} |
| 9 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) |
10 | 10 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
11 | 11 | Δ′ = zero(VA) |
12 | 12 | Δ′[indices...] = Δ |
@@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol |
43 | 43 | A.x,literal_ArrayPartition_x_adjoint |
44 | 44 | end |
45 | 45 |
|
46 | | -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) |
| 46 | +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) |
47 | 47 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
48 | 48 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] |
49 | 49 | (Δ′,nothing) |
50 | 50 | end |
| 51 | + @show VA[i] |
51 | 52 | VA[i],AbstractVectorOfArray_getindex_adjoint |
52 | 53 | end |
53 | 54 |
|
54 | | -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) |
| 55 | +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) |
55 | 56 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
56 | 57 | Δ′ = zero(VA) |
57 | 58 | Δ′[i,j...] = Δ |
|
0 commit comments