Skip to content

Commit 48bbb7a

Browse files
Merge pull request #108 from SciML/ChrisRackauckas-patch-1
CUDA.jl compat
2 parents 0167d2a + 65a60ff commit 48bbb7a

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/init.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ function __init__()
1919
Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA)
2020
@adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,)
2121
end
22+
23+
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
24+
function CUDA.CuArray(VA::AbstractVectorOfArray)
25+
vecs = vec.(VA.u)
26+
return CUDA.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u)))
27+
end
28+
Base.convert(::Type{<:CUDA.CuArray},VA::AbstractVectorOfArray) = CUDA.CuArray(VA)
29+
@adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,)
30+
end
2231

2332
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
2433
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:Tracker.TrackedArray,T2<:Tracker.TrackedArray,N}

0 commit comments

Comments
 (0)