Skip to content

Commit 14e3e11

Browse files
committed
add a sensitive test
1 parent fd72df9 commit 14e3e11

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

test/adjoints.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using RecursiveArrayTools, Zygote, ForwardDiff, Test
2+
using OrdinaryDiffEq
23

34
function loss(x)
45
sum(abs2,Array(VectorOfArray([x .* i for i in 1:5])))
@@ -30,10 +31,17 @@ function loss5(x)
3031
sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...)))
3132
end
3233

34+
function loss6(x)
35+
_x = ArrayPartition([x .* i for i in 1:5]...)
36+
_prob = ODEProblem((u,p,t)->u, _x, (0,1))
37+
sum(abs2, Array(_prob.u0))
38+
end
39+
3340
x = float.(6:10)
3441
loss(x)
3542
@test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x)
3643
@test Zygote.gradient(loss2,x)[1] == ForwardDiff.gradient(loss2,x)
3744
@test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x)
3845
@test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x)
3946
@test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x)
47+
@test Zygote.gradient(loss6,x)[1] == ForwardDiff.gradient(loss6,x)

0 commit comments

Comments
 (0)