@@ -15,7 +15,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j
1515 VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
1616end
1717
18- function ChainRulesCore. rrule (:: ArrayPartition , x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
18+ function ChainRulesCore. rrule (:: Type{<: ArrayPartition} , x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
1919 function ArrayPartition_adjoint (_y)
2020 y = Array (_y)
2121 starts = vcat (0 ,cumsum (reduce (vcat,length .(x))))
@@ -25,11 +25,11 @@ function ChainRulesCore.rrule(::ArrayPartition, x::S, ::Type{Val{copy_x}} = Val{
2525 ArrayPartition (x, Val{copy_x}), ArrayPartition_adjoint
2626end
2727
28- function ChainRulesCore. rrule (:: VectorOfArray ,u)
28+ function ChainRulesCore. rrule (:: Type{<: VectorOfArray} ,u)
2929 VectorOfArray (u),y -> (NoTangent (),[y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]])
3030end
3131
32- function ChainRulesCore. rrule (:: DiffEqArray ,u,t)
32+ function ChainRulesCore. rrule (:: Type{<: DiffEqArray} ,u,t)
3333 DiffEqArray (u,t),y -> (NoTangent (),[y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],NoTangent ())
3434end
3535
0 commit comments