@@ -604,7 +604,7 @@ struct CPUIndex <: AbstractCPU end
604604struct GPU <: AbstractDevice end
605605
606606"""
607- device(::Type{T})
607+ device(::Type{T})
608608
609609Indicates the most efficient way to access elements from the collection in low-level code.
610610For `GPUArrays`, will return `ArrayInterface.GPU()`.
@@ -615,20 +615,25 @@ Otherwise, returns `nothing`.
615615device (A) = device (typeof (A))
616616device (:: Type ) = nothing
617617device (:: Type{<:Tuple} ) = CPUIndex ()
618- # Relies on overloading for GPUArrays that have subtyped `StridedArray`.
619- device (:: Type{<:StridedArray} ) = CPUPointer ()
620- device (
621- :: Type{<:SubArray{T,N,A,I}} ,
622- ) where {T,N,A,I<: Tuple{Vararg{Union{Integer,AbstractRange}}} } = device (A)
623- device (:: Type{<:SubArray} ) = CPUIndex ()
624- function device (:: Type{T} ) where {T<: AbstractArray }
625- P = parent_type (T)
626- T === P ? CPUIndex () : device (P)
618+ device (:: Type{T} ) where {T<: Array } = CPUPointer ()
619+ device (:: Type{T} ) where {T<: AbstractArray } = CPUIndex ()
620+ device (:: Type{T} ) where {T<: PermutedDimsArray } = device (parent_type (T))
621+ device (:: Type{T} ) where {T<: Transpose } = device (parent_type (T))
622+ device (:: Type{T} ) where {T<: Adjoint } = device (parent_type (T))
623+ device (:: Type{T} ) where {T<: ReinterpretArray } = device (parent_type (T))
624+ device (:: Type{T} ) where {T<: ReshapedArray } = device (parent_type (T))
625+ function device (:: Type{T} ) where {T<: SubArray }
626+ if defines_strides (T)
627+ return device (parent_type (T))
628+ else
629+ return _not_pointer (device (parent_type (T)))
630+ end
627631end
628-
632+ _not_pointer (:: CPUPointer ) = CPUIndex ()
633+ _not_pointer (x) = x
629634
630635"""
631- defines_strides(::Type{T}) -> Bool
636+ defines_strides(::Type{T}) -> Bool
632637
633638Is strides(::T) defined?
634639"""
@@ -1058,6 +1063,9 @@ function __init__()
10581063 stride_rank (parent_type (A))
10591064 ArrayInterface. axes (A:: OffsetArrays.OffsetArray ) = Base. axes (A)
10601065 ArrayInterface. axes (A:: OffsetArrays.OffsetArray , dim:: Integer ) = Base. axes (A, dim)
1066+ function ArrayInterface. device (:: Type{T} ) where {T<: OffsetArrays.OffsetArray }
1067+ return device (parent_type (T))
1068+ end
10611069 end
10621070end
10631071
0 commit comments