Skip to content

Commit 0dfed2f

Browse files
Use GPUArraysCore
Before: julia> @time_imports using RecursiveArrayTools 18.3 ms ┌ MacroTools 35.6 ms ┌ ZygoteRules 2.3 ms ┌ Compat 318.7 ms ┌ FillArrays 1320.3 ms ┌ StaticArrays 80.1 ms ┌ Preferences 86.5 ms ┌ JLLWrappers 373.9 ms ┌ LLVMExtra_jll 17.1 ms ┌ CEnum 485.5 ms ┌ LLVM 3.2 ms ┌ Adapt 2731.4 ms ┌ GPUArrays 11.8 ms ┌ DocStringExtensions 2.1 ms ┌ IfElse 72.5 ms ┌ RecipesBase 100.5 ms ┌ Static 6.0 ms ┌ ArrayInterfaceCore 164.4 ms ┌ ArrayInterface 13.0 ms ┌ ArrayInterfaceStaticArrays 208.8 ms ┌ ChainRulesCore 5503.5 ms RecursiveArrayTools julia> @time_imports using RecursiveArrayTools 28.2 ms ┌ MacroTools 53.8 ms ┌ ZygoteRules 3.2 ms ┌ Compat 503.4 ms ┌ FillArrays 1715.4 ms ┌ StaticArrays 69.1 ms ┌ Preferences 103.9 ms ┌ JLLWrappers 500.3 ms ┌ LLVMExtra_jll 7.0 ms ┌ CEnum 265.1 ms ┌ LLVM 3.1 ms ┌ Adapt 2544.7 ms ┌ GPUArrays 16.1 ms ┌ DocStringExtensions 3.1 ms ┌ IfElse 265.8 ms ┌ RecipesBase 175.6 ms ┌ Static 8.1 ms ┌ ArrayInterfaceCore 222.9 ms ┌ ArrayInterface 3.6 ms ┌ ArrayInterfaceStaticArrays 175.3 ms ┌ ChainRulesCore 6228.5 ms RecursiveArrayTools After: julia> @time_imports using RecursiveArrayTools 30.9 ms ┌ MacroTools 54.5 ms ┌ ZygoteRules 2.3 ms ┌ Compat 299.2 ms ┌ FillArrays 1240.2 ms ┌ StaticArrays 7.3 ms ┌ DocStringExtensions 1.5 ms ┌ IfElse 43.5 ms ┌ RecipesBase 149.6 ms ┌ Static 5.2 ms ┌ ArrayInterfaceCore 195.5 ms ┌ ArrayInterface 3.3 ms ┌ Adapt 5.7 ms ┌ ArrayInterfaceStaticArrays 185.8 ms ┌ ChainRulesCore 10.2 ms ┌ GPUArraysCore 2169.1 ms RecursiveArrayTools julia> @time_imports using RecursiveArrayTools 23.3 ms ┌ MacroTools 44.2 ms ┌ ZygoteRules 2.4 ms ┌ Compat 305.4 ms ┌ FillArrays 1015.0 ms ┌ StaticArrays 12.3 ms ┌ DocStringExtensions 2.0 ms ┌ IfElse 108.6 ms ┌ RecipesBase 134.0 ms ┌ Static 5.8 ms ┌ ArrayInterfaceCore 190.5 ms ┌ ArrayInterface 3.4 ms ┌ Adapt 5.2 ms ┌ ArrayInterfaceStaticArrays 273.1 ms ┌ ChainRulesCore 13.3 ms ┌ GPUArraysCore 2111.9 ms RecursiveArrayTools
1 parent e2de7bf commit 0dfed2f

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
13-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
13+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1616
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -24,7 +24,7 @@ ArrayInterfaceStaticArrays = "0.1"
2424
ChainRulesCore = "0.10.7, 1"
2525
DocStringExtensions = "0.8, 0.9"
2626
FillArrays = "0.11, 0.12, 0.13"
27-
GPUArrays = "8"
27+
GPUArraysCore = "0.1"
2828
RecipesBase = "0.7, 0.8, 1.0"
2929
StaticArrays = "0.12, 1.0"
3030
ZygoteRules = "0.2"

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ include("zygote.jl")
2929

3030
Base.show(io::IO, x::Union{ArrayPartition,AbstractVectorOfArray}) = invoke(show, Tuple{typeof(io), Any}, io, x)
3131

32-
import GPUArrays
33-
Base.convert(T::Type{<:GPUArrays.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
34-
ChainRulesCore.rrule(T::Type{<:GPUArrays.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ)
32+
import GPUArraysCore
33+
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
34+
ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ)
3535

3636
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
3737
AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples

0 commit comments

Comments
 (0)