|
14 | 14 | ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() |
15 | 15 |
|
16 | 16 | function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray}, |
17 | | - xs::AbstractVectorOfArray) |
| 17 | + xs::AbstractVectorOfArray) |
18 | 18 | T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ) |
19 | 19 | end |
20 | 20 |
|
|
28 | 28 | end |
29 | 29 |
|
30 | 30 | @adjoint function getindex(VA::AbstractVectorOfArray, |
31 | | - i::Union{BitArray, AbstractArray{Bool}}) |
| 31 | + i::Union{BitArray, AbstractArray{Bool}}) |
32 | 32 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
33 | 33 | Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x))) |
34 | 34 | for (x, j) in zip(VA.u, 1:length(VA))] |
|
48 | 48 | end |
49 | 49 |
|
50 | 50 | @adjoint function getindex(VA::AbstractVectorOfArray, |
51 | | - i::Union{Int, AbstractArray{Int}}) |
| 51 | + i::Union{Int, AbstractArray{Int}}) |
52 | 52 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
53 | 53 | Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x))) |
54 | 54 | for (x, j) in zip(VA.u, 1:length(VA))] |
|
65 | 65 | end |
66 | 66 |
|
67 | 67 | @adjoint function getindex(VA::AbstractVectorOfArray, i::Int, |
68 | | - j::Union{Int, AbstractArray{Int}, CartesianIndex, |
69 | | - Colon, BitArray, AbstractArray{Bool}}...) |
| 68 | + j::Union{Int, AbstractArray{Int}, CartesianIndex, |
| 69 | + Colon, BitArray, AbstractArray{Bool}}...) |
70 | 70 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
71 | 71 | Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) |
72 | 72 | Δ′[i, j...] = Δ |
|
76 | 76 | end |
77 | 77 |
|
78 | 78 | @adjoint function ArrayPartition(x::S, |
79 | | - ::Type{Val{copy_x}} = Val{false}) where { |
80 | | - S <: |
81 | | - Tuple, |
82 | | - copy_x |
83 | | - } |
| 79 | + ::Type{Val{copy_x}} = Val{false}) where { |
| 80 | + S <: |
| 81 | + Tuple, |
| 82 | + copy_x, |
| 83 | +} |
84 | 84 | function ArrayPartition_adjoint(_y) |
85 | 85 | y = Array(_y) |
86 | 86 | starts = vcat(0, cumsum(reduce(vcat, length.(x)))) |
|
93 | 93 |
|
94 | 94 | @adjoint function VectorOfArray(u) |
95 | 95 | VectorOfArray(u), |
96 | | - y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] |
97 | | - for i in 1:size(y)[end]]),) |
| 96 | + y -> begin |
| 97 | + y isa Ref && (y = VectorOfArray(y[].u)) |
| 98 | + (VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i] |
| 99 | + for i in 1:size(y.u)[end]]),) |
| 100 | + end |
98 | 101 | end |
99 | 102 |
|
100 | 103 | @adjoint function DiffEqArray(u, t) |
101 | 104 | DiffEqArray(u, t), |
102 | | - y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], |
103 | | - t), nothing) |
| 105 | + y -> begin |
| 106 | + y isa Ref && (y = VectorOfArray(y[].u)) |
| 107 | + (DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i] |
| 108 | + for i in 1:size(y.u)[end]], |
| 109 | + t), nothing) |
| 110 | + end |
104 | 111 | end |
105 | 112 |
|
106 | 113 | @adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x}) |
|
0 commit comments