Skip to content

Commit 6a5d121

Browse files
Merge pull request #81 from JuliaDiffEq/zygote
add some Zygote rules
2 parents 4a63457 + ca71ba7 commit 6a5d121

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
12+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1213

1314
[compat]
1415
ArrayInterface = "1.2, 2.0"

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ __precompile__()
33
module RecursiveArrayTools
44

55
using Requires, RecipesBase, StaticArrays, Statistics,
6-
ArrayInterface
6+
ArrayInterface, ZygoteRules
77

88
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
99
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
@@ -12,6 +12,7 @@ module RecursiveArrayTools
1212
include("vector_of_array.jl")
1313
include("array_partition.jl")
1414
include("init.jl")
15+
include("zygote.jl")
1516

1617
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
1718
vecarr_to_arr, vecarr_to_vectors, tuples

src/zygote.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
2+
function AbstractVectorOfArray_getindex_adjoint(Δ)
3+
Δ′ = Union{Nothing, eltype(VA.u)}[nothing for x in VA.u]
4+
Δ′[i] = Δ
5+
(Δ′,nothing)
6+
end
7+
VA[i],AbstractVectorOfArray_getindex_adjoint
8+
end
9+
10+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
11+
function AbstractVectorOfArray_getindex_adjoint(Δ)
12+
Δ′ = zero(VA)
13+
Δ′[i,j...] = Δ
14+
(Δ′, map(_ -> nothing, i)...)
15+
end
16+
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
17+
end

0 commit comments

Comments
 (0)